from distributed.compatibility import MutableMapping
from distributed.utils import log_errors, tokey


class PublishExtension(object):
    """ An extension for the scheduler to manage collections

    On construction the following handlers are registered on the scheduler:

    *  publish_list
    *  publish_put
    *  publish_get
    *  publish_delete
    """

    def __init__(self, scheduler):
        """Attach the extension to ``scheduler``.

        Adds the ``publish_*`` handlers to ``scheduler.handlers`` and stores
        this instance under ``scheduler.extensions["publish"]``.
        """
        self.scheduler = scheduler
        # Maps dataset name -> {"data": ..., "keys": ...} as stored by ``put``.
        self.datasets = dict()

        handlers = {
            "publish_list": self.list,
            "publish_put": self.put,
            "publish_get": self.get,
            "publish_delete": self.delete,
        }

        self.scheduler.handlers.update(handlers)
        self.scheduler.extensions["publish"] = self

    def put(self, stream=None, keys=None, data=None, name=None, client=None):
        """Store a new dataset under ``name``.

        Parameters
        ----------
        stream, client :
            Accepted for handler-signature compatibility; not used here.
        keys :
            Passed to ``scheduler.client_desires_keys`` and stored with the
            dataset so that ``delete`` can release them again.
        data :
            Payload stored alongside ``keys``.
        name :
            Name the dataset is registered under.

        Returns
        -------
        dict
            ``{"status": "OK", "name": name}``

        Raises
        ------
        KeyError
            If a dataset called ``name`` is already stored.
        """
        with log_errors():
            if name in self.datasets:
                raise KeyError("Dataset %s already exists" % name)
            # The keys are requested on behalf of a synthetic
            # "published-<name>" identity; ``delete`` releases them under
            # the same identity.
            self.scheduler.client_desires_keys(keys, "published-%s" % tokey(name))
            self.datasets[name] = {"data": data, "keys": keys}
            return {"status": "OK", "name": name}

    def delete(self, stream=None, name=None):
        """Remove the dataset ``name`` and release the keys stored with it.

        A name that is not stored is not an error: the release call is then
        made with an empty key list.
        """
        with log_errors():
            out = self.datasets.pop(name, {"keys": []})
            self.scheduler.client_releases_keys(
                out["keys"], "published-%s" % tokey(name)
            )

    def list(self, *args):
        """Return the stored dataset names as a list, ordered by ``str(name)``.

        Positional arguments are accepted and ignored.
        """
        with log_errors():
            # ``sorted`` already returns a new list, and iterating the dict
            # yields its keys, so no extra ``list()`` / ``.keys()`` is needed.
            return sorted(self.datasets, key=str)

    def get(self, stream=None, name=None, client=None):
        """Return the stored entry for ``name``, or ``None`` if it is absent.

        ``stream`` defaults to ``None`` like in ``put`` and ``delete`` so the
        method can also be called directly with keyword arguments only.
        ``stream`` and ``client`` are not used here.
        """
        with log_errors():
            return self.datasets.get(name, None)


class Datasets(MutableMapping):
    """Mapping-style view over the dataset methods of a :class:`Client`.

    Item lookup, assignment, deletion, iteration and ``len`` are each
    forwarded to the matching dataset method of the wrapped client; the
    remaining mapping methods come from ``MutableMapping``.

    Parameters
    ----------
    client : distributed.client.Client
        Client whose dataset methods back this mapping.

    """

    def __init__(self, client):
        self.__client = client

    def __getitem__(self, key):
        """Return ``client.get_dataset(key)``."""
        dataset = self.__client.get_dataset(key)
        return dataset

    def __setitem__(self, key, value):
        """Publish ``value`` under the name ``key`` via the client."""
        client = self.__client
        client.publish_dataset(value, name=key)

    def __delitem__(self, key):
        """Unpublish the dataset named ``key`` via the client."""
        client = self.__client
        client.unpublish_dataset(key)

    def __iter__(self):
        """Yield each name reported by ``client.list_datasets()``."""
        # Kept as a generator: the client is only queried once iteration
        # actually starts.
        names = self.__client.list_datasets()
        for name in names:
            yield name

    def __len__(self):
        """Return how many names ``client.list_datasets()`` reports."""
        names = self.__client.list_datasets()
        return len(names)
