from __future__ import print_function, division, absolute_import

from collections import defaultdict, deque
import logging
from math import log
from time import time

import dask
from .core import CommClosedError
from .diagnostics.plugin import SchedulerPlugin
from .utils import log_errors, PeriodicCallback

try:
    from cytoolz import topk
except ImportError:
    from toolz import topk

# Assumed network bandwidth: ``WorkStealing.steal_time_ratio`` divides the total
# size of a task's dependencies by this value to estimate transfer time
# (presumably bytes per second -- confirm against ``get_nbytes``).
BANDWIDTH = 100e6
# Constant overhead added to every estimated transfer time (presumably seconds).
LATENCY = 10e-3
# Natural log of 2, cached so that ``log(x) / log_2`` yields a base-2 logarithm.
log_2 = log(2)

# Module-level logger used by the stealing code paths below.
logger = logging.getLogger(__name__)


# Config flag: when truthy, the generic exception handlers in this module drop
# into ``pdb`` after logging the error.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")


class WorkStealing(SchedulerPlugin):
    def __init__(self, scheduler):
        """Create the stealing state and hook it into ``scheduler``.

        Builds the per-level and per-worker indexes of stealable tasks,
        registers every worker the scheduler already knows, installs a
        periodic callback that runs ``balance`` (``callback_time=100``) and
        registers this object as plugin, extension, event log and handler of
        the ``"steal-response"`` stream message.
        """
        self.scheduler = scheduler

        # stealable_all[level] -> set of task states at that level
        self.stealable_all = [set() for _ in range(15)]
        # worker -> list (indexed by level) of sets of task states
        self.stealable = {}
        # task state -> (worker, level)
        self.key_stealable = {}
        # prefix -> set of task states
        self.stealable_unknown_durations = defaultdict(set)

        # One multiplier per level: 1 + 2 ** (level - 6), level 0 pinned to 1.
        multipliers = [1 + 2 ** (level - 6) for level in range(15)]
        multipliers[0] = 1
        self.cost_multipliers = multipliers

        for address in scheduler.workers:
            self.add_worker(worker=address)

        balancer = PeriodicCallback(
            callback=self.balance, callback_time=100, io_loop=scheduler.loop
        )
        self._pc = balancer
        scheduler.periodic_callbacks["stealing"] = balancer
        scheduler.plugins.append(self)
        scheduler.extensions["stealing"] = self
        scheduler.events["stealing"] = deque(maxlen=100000)

        self.count = 0
        # task state -> dict describing the pending steal request
        self.in_flight = {}
        # worker state -> occupancy adjustment caused by pending steals
        self.in_flight_occupancy = defaultdict(int)

        scheduler.stream_handlers["steal-response"] = self.move_task_confirm

    @property
    def log(self):
        return self.scheduler.events["stealing"]

    def add_worker(self, scheduler=None, worker=None):
        self.stealable[worker] = [set() for i in range(15)]

    def remove_worker(self, scheduler=None, worker=None):
        del self.stealable[worker]

    def teardown(self):
        self._pc.stop()

    def transition(
        self, key, start, finish, compute_start=None, compute_stop=None, *args, **kwargs
    ):
        ts = self.scheduler.tasks[key]
        if finish == "processing":
            self.put_key_in_stealable(ts)

        if start == "processing":
            self.remove_key_from_stealable(ts)
            if finish == "memory":
                for tts in self.stealable_unknown_durations.pop(ts.prefix, ()):
                    if tts not in self.in_flight and tts.state == "processing":
                        self.put_key_in_stealable(tts)
            else:
                self.in_flight.pop(ts, None)

    def put_key_in_stealable(self, ts):
        ws = ts.processing_on
        worker = ws.address
        cost_multiplier, level = self.steal_time_ratio(ts)
        self.log.append(("add-stealable", ts.key, worker, level))
        if cost_multiplier is not None:
            self.stealable_all[level].add(ts)
            self.stealable[worker][level].add(ts)
            self.key_stealable[ts] = (worker, level)

    def remove_key_from_stealable(self, ts):
        result = self.key_stealable.pop(ts, None)
        if result is None:
            return

        worker, level = result
        self.log.append(("remove-stealable", ts.key, worker, level))
        try:
            self.stealable[worker][level].remove(ts)
        except KeyError:
            pass
        try:
            self.stealable_all[level].remove(ts)
        except KeyError:
            pass

    def steal_time_ratio(self, ts):
        """ The compute to communication time ratio of a key

        Returns
        -------

        cost_multiplier: The increased cost from moving this task as a factor.
        For example a result of zero implies a task without dependencies.
        level: The location within a stealable list to place this value
        """
        if not ts.loose_restrictions and (
            ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions
        ):
            return None, None  # don't steal

        if not ts.dependencies:  # no dependencies fast path
            return 0, 0

        nbytes = sum(dep.get_nbytes() for dep in ts.dependencies)

        transfer_time = nbytes / BANDWIDTH + LATENCY
        split = ts.prefix
        if split in fast_tasks:
            return None, None
        ws = ts.processing_on
        if ws is None:
            self.stealable_unknown_durations[split].add(ts)
            return None, None
        else:
            compute_time = ws.processing[ts]
            if compute_time < 0.005:  # 5ms, just give up
                return None, None
            cost_multiplier = transfer_time / compute_time
            if cost_multiplier > 100:
                return None, None

            level = int(round(log(cost_multiplier) / log_2 + 6, 0))
            level = max(1, level)
            return cost_multiplier, level

    def move_task_request(self, ts, victim, thief):
        """Ask ``victim`` to give up ``ts`` so that ``thief`` can run it.

        The task is removed from the stealable indexes, a ``"steal-request"``
        message is sent on the victim's stream, and the request is recorded in
        ``in_flight`` together with the occupancy it is expected to move from
        the victim to the thief. A closed victim comm is logged and otherwise
        ignored; any other exception is logged and re-raised.
        """
        try:
            if self.scheduler.validate and victim is not ts.processing_on:
                import pdb

                pdb.set_trace()

            key = ts.key
            self.remove_key_from_stealable(ts)
            logger.debug(
                "Request move %s, %s: %2f -> %s: %2f",
                key,
                victim,
                victim.occupancy,
                thief,
                thief.occupancy,
            )

            scheduler = self.scheduler
            victim_duration = victim.processing[ts]
            # Expected cost on the thief: compute time plus fetching inputs.
            compute_cost = scheduler.get_task_duration(ts)
            thief_duration = compute_cost + scheduler.get_comm_cost(ts, thief)

            request = {"op": "steal-request", "key": key}
            scheduler.stream_comms[victim.address].send(request)

            self.in_flight[ts] = dict(
                victim=victim,
                thief=thief,
                victim_duration=victim_duration,
                thief_duration=thief_duration,
            )

            occupancy = self.in_flight_occupancy
            occupancy[victim] -= victim_duration
            occupancy[thief] += thief_duration
        except CommClosedError:
            logger.info("Worker comm closed while stealing: %s", victim)
        except Exception as e:
            logger.exception(e)
            if LOG_PDB:
                import pdb

                pdb.set_trace()
            raise

    def move_task_confirm(self, key=None, worker=None, state=None):
        try:
            try:
                ts = self.scheduler.tasks[key]
            except KeyError:
                logger.debug("Key released between request and confirm: %s", key)
                return
            try:
                d = self.in_flight.pop(ts)
            except KeyError:
                return
            thief = d["thief"]
            victim = d["victim"]
            logger.debug(
                "Confirm move %s, %s -> %s.  State: %s", key, victim, thief, state
            )

            self.in_flight_occupancy[thief] -= d["thief_duration"]
            self.in_flight_occupancy[victim] += d["victim_duration"]

            if not self.in_flight:
                self.in_flight_occupancy = defaultdict(lambda: 0)

            if ts.state != "processing" or ts.processing_on is not victim:
                old_thief = thief.occupancy
                new_thief = sum(thief.processing.values())
                old_victim = victim.occupancy
                new_victim = sum(victim.processing.values())
                thief.occupancy = new_thief
                victim.occupancy = new_victim
                self.scheduler.total_occupancy += (
                    new_thief - old_thief + new_victim - old_victim
                )
                return

            # One of the pair has left, punt and reschedule
            if (
                thief.address not in self.scheduler.workers
                or victim.address not in self.scheduler.workers
            ):
                self.scheduler.reschedule(key)
                return

            # Victim had already started execution, reverse stealing
            if state in ("memory", "executing", "long-running", None):
                self.log.append(
                    ("already-computing", key, victim.address, thief.address)
                )
                self.scheduler.check_idle_saturated(thief)
                self.scheduler.check_idle_saturated(victim)

            # Victim was waiting, has given up task, enact steal
            elif state in ("waiting", "ready"):
                self.remove_key_from_stealable(ts)
                ts.processing_on = thief
                duration = victim.processing.pop(ts)
                victim.occupancy -= duration
                self.scheduler.total_occupancy -= duration
                if not victim.processing:
                    self.scheduler.total_occupancy -= victim.occupancy
                    victim.occupancy = 0
                thief.processing[ts] = d["thief_duration"]
                thief.occupancy += d["thief_duration"]
                self.scheduler.total_occupancy += d["thief_duration"]
                self.put_key_in_stealable(ts)

                try:
                    self.scheduler.send_task_to_worker(thief.address, key)
                except CommClosedError:
                    self.scheduler.remove_worker(thief.address)
                self.log.append(("confirm", key, victim.address, thief.address))
            else:
                raise ValueError("Unexpected task state: %s" % state)
        except Exception as e:
            logger.exception(e)
            if LOG_PDB:
                import pdb

                pdb.set_trace()
            raise
        finally:
            try:
                self.scheduler.check_idle_saturated(thief)
            except Exception:
                pass
            try:
                self.scheduler.check_idle_saturated(victim)
            except Exception:
                pass

    def balance(self):
        """Try to move stealable tasks from busy workers to idle workers.

        Returns immediately when the scheduler has no idle workers or when
        every worker is idle. Otherwise, for each cost level in increasing
        order, it first walks the stealable tasks of each saturated worker
        and then, for levels whose multiplier is below 20, the global
        per-level set. Each candidate is paired with an idle worker chosen
        round-robin and a steal request is issued when the estimated
        occupancies justify it.

        Requests issued in one call are appended to the stealing log as a
        single list, and the elapsed time is recorded in the
        ``"steal-duration"`` digest when digests are enabled.
        """
        s = self.scheduler

        def combined_occupancy(ws):
            # Occupancy including the effect of steals still in flight.
            return ws.occupancy + self.in_flight_occupancy[ws]

        def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
            occ_idl = combined_occupancy(idl)
            occ_sat = combined_occupancy(sat)

            if occ_idl + cost_multiplier * duration <= occ_sat - duration / 2:
                self.move_task_request(ts, sat, idl)
                moves.append(
                    (
                        start,
                        level,
                        ts.key,
                        duration,
                        sat.address,
                        occ_sat,
                        idl.address,
                        occ_idl,
                    )
                )
                # Note: these are the occupancies sampled before the request.
                s.check_idle_saturated(sat, occ=occ_sat)
                s.check_idle_saturated(idl, occ=occ_idl)

        with log_errors():
            i = 0
            idle = s.idle
            saturated = s.saturated
            if not idle or len(idle) == len(s.workers):
                return

            # Named ``moves`` rather than ``log`` so that the ``math.log``
            # import is not shadowed inside this method.
            moves = []
            start = time()

            if not s.saturated:
                # No worker is flagged saturated: fall back to the ten most
                # occupied workers that hold more tasks than they have cores.
                saturated = topk(10, s.workers.values(), key=combined_occupancy)
                saturated = [
                    ws
                    for ws in saturated
                    if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.ncores
                ]
            elif len(s.saturated) < 20:
                saturated = sorted(saturated, key=combined_occupancy, reverse=True)
            if len(idle) < 20:
                idle = sorted(idle, key=combined_occupancy)

            for level, cost_multiplier in enumerate(self.cost_multipliers):
                if not idle:
                    break
                for sat in list(saturated):
                    stealable = self.stealable[sat.address][level]
                    if not stealable or not idle:
                        continue

                    for ts in list(stealable):
                        if ts not in self.key_stealable or ts.processing_on is not sat:
                            # Stale entry: prune it lazily.
                            stealable.discard(ts)
                            continue
                        i += 1
                        if not idle:
                            break
                        idl = idle[i % len(idle)]

                        duration = sat.processing.get(ts)
                        if duration is None:
                            stealable.discard(ts)
                            continue

                        maybe_move_task(level, ts, sat, idl, duration, cost_multiplier)

                if cost_multiplier < 20:  # don't steal from public at cost
                    stealable = self.stealable_all[level]
                    for ts in list(stealable):
                        if not idle:
                            break
                        if ts not in self.key_stealable:
                            stealable.discard(ts)
                            continue

                        sat = ts.processing_on
                        if sat is None:
                            stealable.discard(ts)
                            continue
                        if combined_occupancy(sat) < 0.2:
                            continue
                        if len(sat.processing) <= sat.ncores:
                            continue

                        i += 1
                        idl = idle[i % len(idle)]

                        # Same tolerance as the per-worker pass above: prune a
                        # stale entry instead of aborting the round on KeyError.
                        duration = sat.processing.get(ts)
                        if duration is None:
                            stealable.discard(ts)
                            continue

                        maybe_move_task(level, ts, sat, idl, duration, cost_multiplier)

            if moves:
                self.log.append(moves)
                self.count += 1
            stop = time()
            if s.digests:
                s.digests["steal-duration"].add(stop - start)

    def restart(self, scheduler):
        for stealable in self.stealable.values():
            for s in stealable:
                s.clear()

        for s in self.stealable_all:
            s.clear()
        self.key_stealable.clear()
        self.stealable_unknown_durations.clear()

    def story(self, *keys):
        keys = set(keys)
        out = []
        for L in self.log:
            if not isinstance(L, list):
                L = [L]
            for t in L:
                if any(x in keys for x in t):
                    out.append(t)
        return out


# Task prefixes that ``WorkStealing.steal_time_ratio`` never offers for stealing
# once the task has dependencies (it returns ``(None, None)`` for them).
fast_tasks = {"shuffle-split"}
