from __future__ import print_function, division, absolute_import

try:
    import ssl
except ImportError:
    ssl = None

import dask


# Roles for which per-role TLS settings are looked up.
_roles = ["client", "scheduler", "worker"]

# TLS settings that exist once per role (combined as ``tls_<role>_<field>``).
_tls_per_role_fields = ["key", "cert"]

# TLS settings that are not role-specific (combined as ``tls_<field>``).
_tls_fields = ["ca_file", "ciphers"]

# Settings that carry no category prefix.
_misc_fields = ["require_encryption"]

# The complete set of attribute names built from the lists above.
_fields = set(_misc_fields)
_fields.update("tls_" + name for name in _tls_fields)
_fields.update(
    "tls_" + role + "_" + name for role in _roles for name in _tls_per_role_fields
)


def _field_to_config_key(field):
    """
    Return *field* with every underscore replaced by a dash
    (e.g. ``"ca_file"`` becomes ``"ca-file"``).
    """
    pieces = field.split("_")
    return "-".join(pieces)


class Security(object):
    """
    An object to gather and pass around security configuration.
    Default values are gathered from the global ``config`` object and
    can be overridden by constructor args.

    Supported fields:
        - require_encryption
        - tls_ca_file
        - tls_ciphers
        - tls_client_key
        - tls_client_cert
        - tls_scheduler_key
        - tls_scheduler_cert
        - tls_worker_key
        - tls_worker_cert
    """

    __slots__ = tuple(_fields)

    def __init__(self, **kwargs):
        """
        Populate the fields from ``dask.config.config``, then apply *kwargs*.

        Keyword arguments whose value is None are ignored, so they never
        override a value read from the configuration.  Any field left unset
        afterwards is set to None.
        """
        self._init_from_dict(dask.config.config)
        for k, v in kwargs.items():
            if v is not None:
                setattr(self, k, v)
        # __slots__ attributes do not exist until assigned; give every field
        # a value so that later getattr() calls cannot fail.
        for k in _fields:
            if not hasattr(self, k):
                setattr(self, k, None)

    def _init_from_dict(self, d):
        """
        Initialize Security from nested dict.
        """
        self._init_fields_from_dict(d, "", _misc_fields, {})
        self._init_fields_from_dict(d, "tls", _tls_fields, _tls_per_role_fields)

    def _init_fields_from_dict(self, d, category, fields, per_role_fields):
        """
        Copy the settings of one category from the config dict *d*.

        If *category* is non-empty, settings are read from the ``d[category]``
        sub-dict and stored under attributes prefixed with ``<category>_``;
        otherwise they are read from *d* itself with no prefix.
        *fields* are looked up directly in that dict, *per_role_fields* in its
        per-role sub-dicts.  Keys absent from the config are left unset.
        """
        if category:
            # A section that is present but null (e.g. an empty key in a
            # YAML file) is treated the same as a missing section.
            d = d.get(category) or {}
            category_prefix = category + "_"
        else:
            category_prefix = ""
        for field in fields:
            k = _field_to_config_key(field)
            if k in d:
                setattr(self, "%s%s" % (category_prefix, field), d[k])
        for role in _roles:
            # Same null-tolerance for the per-role sub-sections.
            dd = d.get(role) or {}
            for field in per_role_fields:
                k = _field_to_config_key(field)
                if k in dd:
                    setattr(self, "%s%s_%s" % (category_prefix, role, field), dd[k])

    def __repr__(self):
        # Only fields that are set (not None) are shown, sorted by name.
        items = sorted((k, getattr(self, k)) for k in _fields)
        return (
            "Security("
            + ", ".join("%s=%r" % (k, v) for k, v in items if v is not None)
            + ")"
        )

    def get_tls_config_for_role(self, role):
        """
        Return the TLS configuration for the given role, as a flat dict.
        """
        return self._get_config_for_role("tls", role, _tls_fields, _tls_per_role_fields)

    def _get_config_for_role(self, category, role, fields, per_role_fields):
        """
        Return a flat dict of the *category* settings that apply to *role*.

        The dict is keyed by bare field name; per-role fields take their
        value from the ``<category>_<role>_<field>`` attribute.

        Raises ValueError if *role* is not a known role.
        """
        if role not in _roles:
            raise ValueError("unknown role %r" % (role,))
        d = {}
        for field in fields:
            k = "%s_%s" % (category, field)
            d[field] = getattr(self, k)
        for field in per_role_fields:
            k = "%s_%s_%s" % (category, role, field)
            d[field] = getattr(self, k)
        return d

    def _get_tls_context(self, tls, purpose):
        """
        Build an SSL context from the flat TLS config dict *tls*.

        Returns None unless both ``ca_file`` and ``cert`` are set in *tls*.
        Raises RuntimeError if ``ssl.create_default_context`` is unavailable.
        """
        if not (tls.get("ca_file") and tls.get("cert")):
            return None
        try:
            ctx = ssl.create_default_context(purpose=purpose, cafile=tls["ca_file"])
        except AttributeError:
            raise RuntimeError("TLS functionality requires Python 2.7.9+")
        ctx.verify_mode = ssl.CERT_REQUIRED
        # We expect a dedicated CA for the cluster and people using
        # IP addresses rather than hostnames
        ctx.check_hostname = False
        ctx.load_cert_chain(tls["cert"], tls.get("key"))
        if tls.get("ciphers"):
            ctx.set_ciphers(tls.get("ciphers"))
        return ctx

    def _get_comm_args(self, role, purpose_name):
        """
        Shared implementation of get_connection_args() and get_listen_args().

        *purpose_name* is the name of the ``ssl.Purpose`` member to use.
        """
        tls = self.get_tls_config_for_role(role)
        # Ensure backwards compatibility (ssl.Purpose is Python 2.7.9+ only)
        if hasattr(ssl, "Purpose"):
            purpose = getattr(ssl.Purpose, purpose_name)
        else:
            purpose = None
        return {
            "ssl_context": self._get_tls_context(tls, purpose),
            "require_encryption": self.require_encryption,
        }

    def get_connection_args(self, role):
        """
        Get the *connection_args* argument for a connect() call with
        the given *role*.
        """
        return self._get_comm_args(role, "SERVER_AUTH")

    def get_listen_args(self, role):
        """
        Get the *connection_args* argument for a listen() call with
        the given *role*.
        """
        return self._get_comm_args(role, "CLIENT_AUTH")
