#!/usr/bin/env python
# -*- mode: python; c-basic-offset: 4; tab-width: 8; indent-tabs-mode: nil -*-
# vi: set shiftwidth=4 tabstop=8 softtabstop=4 expandtab:
# :indentSize=4:tabSize=8:noTabs=true:
# vim: filetype=python

""" Simple wrapper for qsub which provides the functionality of the -sync
option"""

from __future__ import absolute_import, division, print_function

import sys
import os
import re
import subprocess
import time
import logging
import tempfile


from xml.etree import ElementTree

__version__ = 1.0

# CONFIGURATION

# Number of *consecutive* qstat calls that may return an unexpected exit
# status before the wrapper gives up (the counter resets on every good call).
MAX_QSTAT_FAILURES = 3
# time between calls to qstat (in sec)
POLLING_INTERVAL = 5
# Max time (sec) to keep polling for the exit-code file: 90 seconds
MAX_POLLING_TIME = 60 * 1.5

# Regexes matching the output of qsub; group 1 is the numeric job id.
# Raw strings are used so that the regex escapes (\d, \(, \.) are not
# interpreted as (invalid) string escape sequences.
SGE_JOB_ID_DECODER = r'Your job (\d+) \(".+"\) has been submitted'
PBS_JOB_ID_DECODER = r"(\d+)\..+"

# qstat exit statuses. *_NO_JOB_FOUND_CODE is treated as "the job is done",
# *_SUCCESS_CODE as "qstat answered normally"; any other status counts as a
# qstat failure.
SGE_NO_JOB_FOUND_CODE = 1
SGE_SUCCESS_CODE = 0

PBS_NO_JOB_FOUND_CODE = 153
PBS_SUCCESS_CODE = 0

# Unique str (built from a palindrome) to extract exitcode from wrapper script.
# The encoder is written into the generated shell script, where ${?} expands
# to the exit status of the wrapped command; the decoder reads it back.
EXIT_CODE_ENCODER = "PB_EXIT_CODE_${?}_EDOC_TIXE_BP"
EXIT_CODE_DECODER = r"PB_EXIT_CODE_(\d+)_EDOC_TIXE_BP"

# Internal debugging flag for testing your configuration; set to True to get
# log output when the file is run as a script.
DEBUG = False

# NOTE: logging.info() on the root logger implicitly calls
# logging.basicConfig() when the root logger has no handlers yet, so this
# line also installs the default root handler.
logging.info('Starting logger for...')
log = logging.getLogger(__name__)


def _setupLog(file_name=None):
    """Attach a formatted handler to the module logger and enable DEBUG level.

    :param file_name: path of the log file to write to; when None the
                      records are written to stdout instead.
    """
    destination = (logging.FileHandler(file_name)
                   if file_name is not None
                   else logging.StreamHandler(sys.stdout))
    destination.setFormatter(logging.Formatter(
        '[%(levelname)s] %(asctime)-15s [%(name)s %(funcName)s %(lineno)d] %(message)s'))

    log.setLevel(logging.DEBUG)
    log.addHandler(destination)


class QSubError(Exception):
    """Error raised by the wrapper when submitting, polling or deleting a
    scheduler job fails, or when the job's exit code cannot be read."""


class QSubWrapper(object):

    def __init__(self):
        """Create the wrapper and configure it from the command line.

        All attributes start out unset; ``__parseArgs`` then fills them in
        from ``sys.argv`` (and may terminate the process on bad input).
        """
        # User script and the generated wrapper script around it
        self.scriptToRun = self.tmpScriptPath = None
        # Options forwarded to qsub (a single, space-joined string)
        self.qsubArgs = None
        # Scheduler specific settings (SGE or PBS)
        self.jobIdDecoder = self.qstatCmd = None
        self.noJobFoundCode = self.successCode = None
        self.debug = False

        # Replace the placeholders above with the real values
        self.__parseArgs()

    def __parseArgs(self):
        """Populate the instance from ``sys.argv``.

        Expected command line: ``qsw.py script.sh [qsub options]``. The
        following wrapper-only flags are consumed here and not forwarded to
        qsub:

        * ``-SGE`` / ``-PBS``: scheduler selection (one of them is required)
        * ``-DEBUG``: turn on debug logging
        * ``-no-qstat``: (SGE only) select the "noqstat -j" polling mode

        ``-S /bin/bash`` is appended unless ``-S`` is already given, and for
        SGE any ``-sync`` option is forced to ``-sync n``.

        Terminates the process with a non-zero exit status when too few
        arguments are given, the script does not exist, or no supported
        scheduler flag is present.
        """
        # -DEBUG must flip the module-level flag read by the polling loop;
        # without this declaration the assignment below would only create a
        # local variable.
        global DEBUG

        args = sys.argv[1:]

        log.debug("Parsing args {a}".format(a=args))

        if len(args) < 2:
            sys.stderr.write("qsw.py script.sh [qsub options]\n")
            sys.exit(-1)

        self.scriptToRun = args[0]

        if not os.path.exists(self.scriptToRun):
            sys.stderr.write("Unable to find {f}\n".format(f=self.scriptToRun))
            # Exit statuses are truncated to 8 bits, so 256 would reach the
            # parent process as 0 (success). 2 is the value of ENOENT.
            sys.exit(2)

        self.tmpScriptPath = os.path.splitext(self.scriptToRun)[0] + ".wrap.sh"

        if '-DEBUG' in args:
            args.remove('-DEBUG')
            DEBUG = True
            self.debug = True
            log.setLevel(logging.DEBUG)

        # We need this option in order to function correctly.
        if '-S' not in args:
            args.extend(['-S', '/bin/bash'])

        # Check if we are wrapping PBS or SGE
        if '-PBS' in args:
            args.remove('-PBS')
            # Use xml output of qstat
            self.jobIdDecoder = PBS_JOB_ID_DECODER
            self.noJobFoundCode = PBS_NO_JOB_FOUND_CODE
            self.successCode = PBS_SUCCESS_CODE
            self.qstatCmd = "qstat -x"
        elif '-SGE' in args:
            args.remove('-SGE')
            # always use -sync n: the wrapper does the waiting itself
            if '-sync' in args:
                valueAt = args.index('-sync') + 1
                if valueAt >= len(args) or args[valueAt].startswith('-'):
                    # -sync was given without a value (it is the last
                    # argument or directly followed by another option).
                    args.insert(valueAt, 'n')
                elif args[valueAt].lower().startswith('y'):
                    args[valueAt] = "n"
                    log.debug("Overriding -sync y")
            self.jobIdDecoder = SGE_JOB_ID_DECODER
            self.noJobFoundCode = SGE_NO_JOB_FOUND_CODE
            self.successCode = SGE_SUCCESS_CODE
            if '-no-qstat' in args:
                args.remove('-no-qstat')
                self.qstatCmd = "noqstat -j"
            else:
                self.qstatCmd = "qstat -j"
        else:
            sys.stderr.write("unsupported scheduler")
            # Must be non-zero modulo 256 (see above).
            sys.exit(1)

        # Everything after the script name is handed to qsub verbatim.
        self.qsubArgs = " ".join(args[1:])

    def _writeWrapperScript(self, childScript, parentScript):
        """
        Generate the wrapper (parent) script that runs the child script with
        /bin/bash and records the child's exit status in a file.

        The wrapper redirects the child's stdout/stderr to
        ``<base>.stdout`` / ``<base>.stderr`` where ``<base>`` is the child
        script path without extension and with '/workflow/' replaced by
        '/log/'. The exit status is written to ``<child base>.exit`` encoded
        with EXIT_CODE_ENCODER.

        :param childScript: path of the script to execute
        :param parentScript: path where the wrapper script is written
        :return: path of the exit-code file the wrapper script will create
        """
        # Function-scope import keeps the block self-contained; shlex.quote
        # is the Python 3 name, pipes.quote the Python 2 one.
        try:
            from shlex import quote
        except ImportError:
            from pipes import quote

        base, _ = os.path.splitext(childScript)

        exitCodePath = base + '.exit'
        # this will be the actual stdout and stderr of the task
        # this will write the task log directory
        log_base = base.replace('/workflow/', '/log/')
        std_out = log_base + '.stdout'
        std_err = log_base + '.stderr'

        # The presence of the exit file is what signals job completion. A
        # file left over from an earlier run of the same script would make
        # the new run look finished immediately and report the old status.
        if os.path.exists(exitCodePath):
            log.warning("Removing stale exit file {x}".format(x=exitCodePath))
            os.remove(exitCodePath)

        with open(parentScript, 'w') as out:
            out.write("# Wrapper script for %s\n" % childScript)
            # Paths are shell-quoted so that spaces or shell metacharacters
            # in them cannot break (or alter) the generated command.
            out.write("/bin/bash {c} > {o} 2> {e}\n".format(c=quote(childScript),
                                                            o=quote(std_out),
                                                            e=quote(std_err)))
            # ${?} inside EXIT_CODE_ENCODER expands to the status of the
            # /bin/bash call above, so this line must directly follow it.
            # MK: writing a bare "echo $?" might be a simpler alternative.
            out.write("echo \"%s\" > %s;\n" % (EXIT_CODE_ENCODER, quote(exitCodePath)))

        log.debug("Completed writing wrapper script to {x}".format(x=parentScript))
        return exitCodePath

    def _waitForJobTermination(self, jobId, exitCodePath):
        """
        Block until the job is considered finished.

        The loop ends as soon as one of the following holds:

        * the exit-code file ``exitCodePath`` exists,
        * qstat exits with ``self.noJobFoundCode`` (the scheduler no longer
          knows the job),
        * (``qstat -x`` only) the reported job state is 'C' in the XML
          output or 'F' in the PBSPro table output.

        In "noqstat" mode qstat is only called until it answers once without
        failure; after that only the exit-code file is polled.

        :param jobId: id of the submitted job
        :param exitCodePath: file the wrapper script writes when it is done
        :raises QSubError: after MAX_QSTAT_FAILURES consecutive qstat calls
                           that returned an unexpected exit status
        """
        log.info("waiting for jobId {i} to complete".format(i=jobId))

        noQstat = "noqstat" in self.qstatCmd
        # "noqstat" is only a mode marker; the program to run is always qstat
        # (e.g. "qstat -j 1234").
        cmd = "%s %s" % (self.qstatCmd.replace("noqstat", "qstat"), jobId)
        # The wrapper script redirects the job's stdout to the '/log/' twin
        # of the script path (see _writeWrapperScript); look there to tell
        # "running" from "not started yet".
        stdoutPath = (os.path.splitext(exitCodePath)[0]
                      .replace('/workflow/', '/log/') + '.stdout')

        consecutiveFailures = 0
        # Starts as True so that qstat is called at least once, and retried
        # until it answers, even in noqstat mode.
        failed = True

        while True:
            if os.path.isfile(exitCodePath):
                log.info("Found exit file at '{s}'".format(s=exitCodePath))
                break

            if noQstat and not failed:
                if os.path.isfile(stdoutPath):
                    log.debug("job %s is running" % jobId)
                else:
                    log.debug("Waiting for job %s to start" % jobId)
                time.sleep(POLLING_INTERVAL)
                continue

            # Check job status via qstat/qstat -j 1234
            log.debug("calling cmd {c}".format(c=cmd))
            retCode, rawOut, rawErr = self._callQstat(cmd)

            errorStr = rawErr.decode('utf-8', 'replace')
            if errorStr:
                # this should just be qstat related errors
                # (e.g., unable to get server, timeout)
                # Most of the time this will catch Errors like:
                # PBS: 'qstat: Unknown Job Id 38976.localhost.localdomain'
                # SGE: Following jobs do not exist:\n1234
                for line in errorStr.splitlines():
                    log.warning(line)
                log.warning("The job might have ended.")

            if DEBUG:
                log.debug(rawOut.decode('utf-8', 'replace'))

            failed = retCode not in (self.noJobFoundCode, self.successCode)
            consecutiveFailures = consecutiveFailures + 1 if failed else 0

            if consecutiveFailures >= MAX_QSTAT_FAILURES:
                msg = "Unable to query qstat for job status (job %s). Failed %d times." % (jobId, consecutiveFailures)
                log.error(msg)
                raise QSubError(msg)

            if retCode == self.noJobFoundCode:
                # Job is not identified by qstat (qstat returns non-zero exit
                # code). Hence the job is done.
                log.info("Breaking. Found returncode {r}".format(r=retCode))
                break

            if self.qstatCmd == "qstat -x" and self._pbsJobFinished(jobId, rawOut):
                break

            time.sleep(POLLING_INTERVAL)

        log.info("Completed waiting for termination of jobId {i}".format(i=jobId))

    def _callQstat(self, cmd):
        """Run ``cmd`` through the shell and return a tuple of
        (exit status, raw stdout bytes, raw stderr bytes)."""
        proc = subprocess.Popen(cmd, shell=True,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
        rawOut, rawErr = proc.communicate()
        return proc.returncode, rawOut, rawErr

    def _pbsJobFinished(self, jobId, rawOut):
        """Return True when the ``qstat -x`` output in ``rawOut`` (bytes)
        reports job ``jobId`` as finished, False otherwise."""
        try:
            root = ElementTree.fromstring(rawOut)
            state = root[0].find('job_state').text
        except (ElementTree.ParseError, IndexError, AttributeError):
            # Not XML, or XML without a job_state element for the first
            # child: fall through to the PBSPro table format below.
            pass
        else:
            if state == 'C':
                # job completed
                log.info("Breaking.  Found return job_state {s}".format(s=state))
                return True
            return False

        # This is a hack to handle the PBSPro "qstat -x" output case. PBSPro
        # does not treat -x as the 'xml output' flag, but rather prints the
        # output in standard form.
        #
        # From: http://www.pbsworks.com/documentation/support/PBSProUserGuide11.2.pdf
        #   qstat -x
        #      Displays information for queued, running, finished,
        #      and moved jobs, in standard format.
        #   qstat -x <job ID>
        #      Displays information for a job, regardless of its
        #      state, in standard format
        #   Example Output:
        #      % qstat -x
        #      Job id        Name       User  Time Use  S  Queue
        #      ------------- ---------- ----- ------   --- -----
        #      101.server1   STDIN      user1 00:00:00  F  workq
        #      102.server1   STDIN      user1 00:00:00  M  destq@server2
        #      103.server1   STDIN      user1 00:00:00  R  workq
        #      104.server1   STDIN      user1 00:00:00  Q  workq
        #
        # Where the Job State 'S' field is defined as:
        # From: http://www.pbsworks.com/pdfs/PBSReferenceGuide13.0.pdf
        #      B Array job has at least one subjob running.
        #      E Job is exiting after having run.
        #      F Job is finished.
        #      H Job is held.
        #      M Job was moved to another server.
        #      Q Job is queued.
        #      R Job is running.
        #      S Job is suspended.
        #      T Job is being moved to new location.
        #      U Cycle-harvesting job is suspended due to keyboard activity.
        #      W Job is waiting for its submitter-assigned start time to be
        #        reached.
        #      X Subjob has completed execution or has been deleted.
        #
        # For now, this hack just determines if the first word of the output
        # is 'Job' and if so, tries to determine if the job is finished by
        # looking for the 'F' status in the 'S' field. This may not be
        # sufficient in the long run (we may need to handle the exit codes
        # differently, we may need to handle other state possibilities, we
        # may want to detect we are running PBSPro and use the -f (full)
        # output option,...).
        #
        # We should probably change this handling and exit with non-zero exit
        # status if we get unexpected output (instead of potentially
        # hanging), in both the PBS xml output and the PBSPro output case.
        outStr = rawOut.decode('utf-8', 'replace')
        if not re.search(r"^Job\s", outStr):
            return False

        match = re.search(r"^{s}[\.\s].*\s(F)\s+\S+\s*$".format(s=jobId), outStr, re.MULTILINE)
        if match:
            log.info("Breaking.  Found return job_state '{s}'".format(s=match.group(1)))
            return True
        return False

    def _extractExitCode(self, fileName):
        """
        Extracts the exit code written to the specified path of the
        wrapper script.
        """
        # fileName might still not show up immediately, so wait for the file to
        # appear.

        log.debug("extracting exit code from {f}".format(f=fileName))
        nAttempts = 0
        startedAt = time.time()
        runTime = time.time() - startedAt

        while runTime < MAX_POLLING_TIME:
            # need to careful when this is computed due to try/catch
            runTime = time.time() - startedAt
            nAttempts += 1

            # Attempt to force an NFS Refresh
            os.listdir(os.path.dirname(fileName))
            try:
                with open(fileName, 'r') as inFile:
                    match = re.search(EXIT_CODE_DECODER, inFile.read())

                if match:
                    exitCode = int(match.group(1))
                else:
                    msg = "Unable to extract exit code from file %s" % fileName
                    raise QSubError(msg)
                return exitCode

            except Exception as e:
                print("Attempt {n} still waiting for {f}. Got error {e}. runtime {r}".format(n=nAttempts, f=fileName, e=str(e), r=runTime))

            time.sleep(POLLING_INTERVAL)

        # if we haven't been able to see the file by now, raise an Exception
        raise QSubError("Unable to extract exit code from file {f}".format(f=fileName))

    def run(self):
        """
        Submits the command using qsub. Monitors progress using qstat.
        Returns with the exit code of the job.

        On Ctrl-C the submitted job is removed with qdel and a non-zero
        code is returned; if the interrupt arrives before a job id is known
        the KeyboardInterrupt is re-raised.

        :return: (int) Returncode
        :raises QSubError: if the job id cannot be parsed from the qsub
                           output, or if qdel fails after an interrupt
        """

        exitCodePath = self._writeWrapperScript(self.scriptToRun, self.tmpScriptPath)

        # Must be bound before the try block: the KeyboardInterrupt handler
        # reads it, and the interrupt can arrive before qsub has returned.
        jobId = None

        try:
            cmd = "qsub %s %s" % (self.qsubArgs, self.tmpScriptPath)
            log.info("calling cmd : {c}".format(c=cmd))
            output = subprocess.check_output(cmd, shell=True)
            # check_output returns bytes on Python 3; the decoder patterns
            # are text patterns.
            if not isinstance(output, str):
                output = output.decode('utf-8', 'replace')

            match = re.search(self.jobIdDecoder, output)
            if not match:
                msg = "Unable to derive jobId from qsub output ({s}) using pattern ({x})".format(s=output, x=self.jobIdDecoder)
                raise QSubError(msg)
            jobId = int(match.group(1))

            # This will block
            self._waitForJobTermination(jobId, exitCodePath)

            exitCode = self._extractExitCode(exitCodePath)
            # propagate the exitCode to sys.exit()
            log.info("exiting {f} with returncode {r}".format(f=__file__,
                                                              r=exitCode))
            return exitCode

        except KeyboardInterrupt:
            if jobId is None:
                # Nothing was submitted (or the id is unknown): nothing to
                # clean up, let the interrupt propagate unchanged.
                raise
            cmd = "qdel %s" % jobId
            retCode = subprocess.call(cmd, shell=True)

            if retCode != self.successCode:
                msg = "Unable to qdel running job %s. You may have to kill it manually" % jobId
                raise QSubError(msg)

            # The job was cancelled, so do not report success: the return
            # value is passed to sys.exit() and None would mean status 0.
            # 130 is the conventional shell status for "killed by SIGINT".
            return 130


if __name__ == "__main__":
    if DEBUG:
        # Debug logging goes to stdout (the _setupLog default).
        # NOTE(review): a 'wrapper.log' path in the current directory used to
        # be computed here but was never passed to _setupLog; pass a
        # file_name if logging to a file is what is actually wanted.
        _setupLog()
        log.info("Running {f} version {v}".format(f=__file__, v=__version__))
    # The constructor parses sys.argv; run() returns the job's exit code,
    # which becomes the exit status of this process.
    app = QSubWrapper()
    sys.exit(app.run())
