from __future__ import print_function, division, absolute_import

import logging
import socket
import os
import sys
import time
import traceback

try:
    from queue import Queue
except ImportError:  # Python 2.7 fix
    from Queue import Queue

from threading import Thread

from toolz import merge

from tornado import gen


logger = logging.getLogger(__name__)


# These are handy for creating colorful terminal output to enhance readability
# of the output generated by dask-ssh.
class bcolors:
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def async_ssh(cmd_dict):
    import paramiko
    from paramiko.buffered_pipe import PipeTimeout
    from paramiko.ssh_exception import SSHException, PasswordRequiredException

    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    retries = 0
    while True:  # Be robust to transient SSH failures.
        try:
            # Set paramiko logging to WARN or higher to squelch INFO messages.
            logging.getLogger("paramiko").setLevel(logging.WARN)

            ssh.connect(
                hostname=cmd_dict["address"],
                username=cmd_dict["ssh_username"],
                port=cmd_dict["ssh_port"],
                key_filename=cmd_dict["ssh_private_key"],
                compress=True,
                timeout=20,
                banner_timeout=20,
            )  # Helps prevent timeouts when many concurrent ssh connections are opened.
            # Connection successful, break out of while loop
            break

        except (SSHException, PasswordRequiredException) as e:

            print(
                "[ dask-ssh ] : "
                + bcolors.FAIL
                + "SSH connection error when connecting to {addr}:{port}"
                "to run '{cmd}'".format(
                    addr=cmd_dict["address"],
                    port=cmd_dict["ssh_port"],
                    cmd=cmd_dict["cmd"],
                )
                + bcolors.ENDC
            )

            print(
                bcolors.FAIL
                + "               SSH reported this exception: "
                + str(e)
                + bcolors.ENDC
            )

            # Print an exception traceback
            traceback.print_exc()

            # Transient SSH errors can occur when many SSH connections are
            # simultaneously opened to the same server. This makes a few
            # attempts to retry.
            retries += 1
            if retries >= 3:
                print(
                    "[ dask-ssh ] : "
                    + bcolors.FAIL
                    + "SSH connection failed after 3 retries. Exiting."
                    + bcolors.ENDC
                )

                # Connection failed after multiple attempts.  Terminate this thread.
                os._exit(1)

            # Wait a moment before retrying
            print(
                "               "
                + bcolors.FAIL
                + "Retrying... (attempt {n}/{total})".format(n=retries, total=3)
                + bcolors.ENDC
            )

            time.sleep(1)

    # Execute the command, and grab file handles for stdout and stderr. Note
    # that we run the command using the user's default shell, but force it to
    # run in an interactive login shell, which hopefully ensures that all of the
    # user's normal environment variables (via the dot files) have been loaded
    # before the command is run. This should help to ensure that important
    # aspects of the environment like PATH and PYTHONPATH are configured.

    print("[ {label} ] : {cmd}".format(label=cmd_dict["label"], cmd=cmd_dict["cmd"]))
    stdin, stdout, stderr = ssh.exec_command(
        "$SHELL -i -c '" + cmd_dict["cmd"] + "'", get_pty=True
    )

    # Set up channel timeout (which we rely on below to make readline() non-blocking)
    channel = stdout.channel
    channel.settimeout(0.1)

    def read_from_stdout():
        """
        Read stdout stream, time out if necessary.
        """
        try:
            line = stdout.readline()
            while len(line) > 0:  # Loops until a timeout exception occurs
                line = line.rstrip()
                logger.debug("stdout from ssh channel: %s", line)
                cmd_dict["output_queue"].put(
                    "[ {label} ] : {output}".format(
                        label=cmd_dict["label"], output=line
                    )
                )
                line = stdout.readline()
        except (PipeTimeout, socket.timeout):
            pass

    def read_from_stderr():
        """
        Read stderr stream, time out if necessary.
        """
        try:
            line = stderr.readline()
            while len(line) > 0:
                line = line.rstrip()
                logger.debug("stderr from ssh channel: %s", line)
                cmd_dict["output_queue"].put(
                    "[ {label} ] : ".format(label=cmd_dict["label"])
                    + bcolors.FAIL
                    + "{output}".format(output=line)
                    + bcolors.ENDC
                )
                line = stderr.readline()
        except (PipeTimeout, socket.timeout):
            pass

    def communicate():
        """
        Communicate a little bit, without blocking too long.
        Return True if the command ended.
        """
        read_from_stdout()
        read_from_stderr()

        # Check to see if the process has exited. If it has, we let this thread
        # terminate.
        if channel.exit_status_ready():
            exit_status = channel.recv_exit_status()
            cmd_dict["output_queue"].put(
                "[ {label} ] : ".format(label=cmd_dict["label"])
                + bcolors.FAIL
                + "remote process exited with exit status "
                + str(exit_status)
                + bcolors.ENDC
            )
            return True

    # Wait for a message on the input_queue. Any message received signals this
    # thread to shut itself down.
    while cmd_dict["input_queue"].empty():
        # Kill some time so that this thread does not hog the CPU.
        time.sleep(1.0)
        if communicate():
            break

    # Ctrl-C the executing command and wait a bit for command to end cleanly
    start = time.time()
    while time.time() < start + 5.0:
        channel.send(b"\x03")  # Ctrl-C
        if communicate():
            break
        time.sleep(1.0)

    # Shutdown the channel, and close the SSH connection
    channel.close()
    ssh.close()


def start_scheduler(
    logdir, addr, port, ssh_username, ssh_port, ssh_private_key, remote_python=None
):
    cmd = "{python} -m distributed.cli.dask_scheduler --port {port}".format(
        python=remote_python or sys.executable, port=port, logdir=logdir
    )

    # Optionally re-direct stdout and stderr to a logfile
    if logdir is not None:
        cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd
        cmd += "&> {logdir}/dask_scheduler_{addr}:{port}.log".format(
            addr=addr, port=port, logdir=logdir
        )

    # Format output labels we can prepend to each line of output, and create
    # a 'status' key to keep track of jobs that terminate prematurely.
    label = (
        bcolors.BOLD
        + "scheduler {addr}:{port}".format(addr=addr, port=port)
        + bcolors.ENDC
    )

    # Create a command dictionary, which contains everything we need to run and
    # interact with this command.
    input_queue = Queue()
    output_queue = Queue()
    cmd_dict = {
        "cmd": cmd,
        "label": label,
        "address": addr,
        "port": port,
        "input_queue": input_queue,
        "output_queue": output_queue,
        "ssh_username": ssh_username,
        "ssh_port": ssh_port,
        "ssh_private_key": ssh_private_key,
    }

    # Start the thread
    thread = Thread(target=async_ssh, args=[cmd_dict])
    thread.daemon = True
    thread.start()

    return merge(cmd_dict, {"thread": thread})


def start_worker(
    logdir,
    scheduler_addr,
    scheduler_port,
    worker_addr,
    nthreads,
    nprocs,
    ssh_username,
    ssh_port,
    ssh_private_key,
    nohost,
    memory_limit,
    worker_port,
    nanny_port,
    remote_python=None,
    remote_dask_worker="distributed.cli.dask_worker",
):

    cmd = (
        "{python} -m {remote_dask_worker} "
        "{scheduler_addr}:{scheduler_port} "
        "--nthreads {nthreads}" + (" --nprocs {nprocs}" if nprocs != 1 else "")
    )

    if not nohost:
        cmd += " --host {worker_addr}"

    if memory_limit:
        cmd += " --memory-limit {memory_limit}"

    if worker_port:
        cmd += " --worker-port {worker_port}"

    if nanny_port:
        cmd += " --nanny-port {nanny_port}"

    cmd = cmd.format(
        python=remote_python or sys.executable,
        remote_dask_worker=remote_dask_worker,
        scheduler_addr=scheduler_addr,
        scheduler_port=scheduler_port,
        worker_addr=worker_addr,
        nthreads=nthreads,
        nprocs=nprocs,
        memory_limit=memory_limit,
        worker_port=worker_port,
        nanny_port=nanny_port,
    )

    # Optionally redirect stdout and stderr to a logfile
    if logdir is not None:
        cmd = "mkdir -p {logdir} && ".format(logdir=logdir) + cmd
        cmd += "&> {logdir}/dask_scheduler_{addr}.log".format(
            addr=worker_addr, logdir=logdir
        )

    label = "worker {addr}".format(addr=worker_addr)

    # Create a command dictionary, which contains everything we need to run and
    # interact with this command.
    input_queue = Queue()
    output_queue = Queue()
    cmd_dict = {
        "cmd": cmd,
        "label": label,
        "address": worker_addr,
        "input_queue": input_queue,
        "output_queue": output_queue,
        "ssh_username": ssh_username,
        "ssh_port": ssh_port,
        "ssh_private_key": ssh_private_key,
    }

    # Start the thread
    thread = Thread(target=async_ssh, args=[cmd_dict])
    thread.daemon = True
    thread.start()

    return merge(cmd_dict, {"thread": thread})


class SSHCluster(object):
    def __init__(
        self,
        scheduler_addr,
        scheduler_port,
        worker_addrs,
        nthreads=0,
        nprocs=1,
        ssh_username=None,
        ssh_port=22,
        ssh_private_key=None,
        nohost=False,
        logdir=None,
        remote_python=None,
        memory_limit=None,
        worker_port=None,
        nanny_port=None,
        remote_dask_worker="distributed.cli.dask_worker",
    ):

        self.scheduler_addr = scheduler_addr
        self.scheduler_port = scheduler_port
        self.nthreads = nthreads
        self.nprocs = nprocs

        self.ssh_username = ssh_username
        self.ssh_port = ssh_port
        self.ssh_private_key = ssh_private_key

        self.nohost = nohost

        self.remote_python = remote_python

        self.memory_limit = memory_limit
        self.worker_port = worker_port
        self.nanny_port = nanny_port
        self.remote_dask_worker = remote_dask_worker

        # Generate a universal timestamp to use for log files
        import datetime

        if logdir is not None:
            logdir = os.path.join(
                logdir,
                "dask-ssh_" + datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
            )
            print(
                bcolors.WARNING + "Output will be redirected to logfiles "
                'stored locally on individual worker nodes under "{logdir}".'.format(
                    logdir=logdir
                )
                + bcolors.ENDC
            )
        self.logdir = logdir

        # Keep track of all running threads
        self.threads = []

        # Start the scheduler node
        self.scheduler = start_scheduler(
            logdir,
            scheduler_addr,
            scheduler_port,
            ssh_username,
            ssh_port,
            ssh_private_key,
            remote_python,
        )

        # Start worker nodes
        self.workers = []
        for i, addr in enumerate(worker_addrs):
            self.add_worker(addr)

    @gen.coroutine
    def _start(self):
        pass

    @property
    def scheduler_address(self):
        return "%s:%d" % (self.scheduler_addr, self.scheduler_port)

    def monitor_remote_processes(self):

        # Form a list containing all processes, since we treat them equally from here on out.
        all_processes = [self.scheduler] + self.workers

        try:
            while True:
                for process in all_processes:
                    while not process["output_queue"].empty():
                        print(process["output_queue"].get())

                # Kill some time and free up CPU before starting the next sweep
                # through the processes.
                time.sleep(0.1)

            # end while true

        except KeyboardInterrupt:
            pass  # Return execution to the calling process

    def add_worker(self, address):
        self.workers.append(
            start_worker(
                self.logdir,
                self.scheduler_addr,
                self.scheduler_port,
                address,
                self.nthreads,
                self.nprocs,
                self.ssh_username,
                self.ssh_port,
                self.ssh_private_key,
                self.nohost,
                self.memory_limit,
                self.worker_port,
                self.nanny_port,
                self.remote_python,
                self.remote_dask_worker,
            )
        )

    def shutdown(self):
        all_processes = [self.scheduler] + self.workers

        for process in all_processes:
            process["input_queue"].put("shutdown")
            process["thread"].join()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.shutdown()
