from __future__ import print_function, division, absolute_import

from contextlib import contextmanager
import sys

try:
    import ssl
except ImportError:
    ssl = None

import pytest
from tornado import gen

from distributed.comm import connect, listen
from distributed.security import Security
from distributed.utils_test import new_config, get_cert, gen_test


# TLS material used throughout these tests, resolved through ``get_cert``.
ca_file = get_cert('tls-ca-cert.pem')

cert1, key1, keycert1 = map(get_cert, ('tls-cert.pem',
                                       'tls-key.pem',
                                       'tls-key-cert.pem'))

# Note this cipher uses RSA auth as this matches our test certs
FORCED_CIPHER = 'ECDHE-RSA-AES128-GCM-SHA256'

# TLS 1.3 cipher suite names. When a connection negotiates TLS 1.3, one of
# these may be reported instead of FORCED_CIPHER, so functional tests accept
# either.
TLS_13_CIPHERS = [
    'TLS_AES_128_GCM_SHA256',
    'TLS_AES_256_GCM_SHA384',
    'TLS_CHACHA20_POLY1305_SHA256',
    'TLS_AES_128_CCM_SHA256',
    'TLS_AES_128_CCM_8_SHA256',
]


def test_defaults():
    """An empty configuration leaves every TLS setting unset."""
    with new_config({}):
        sec = Security()
    assert sec.require_encryption in (None, False)
    unset_attributes = (
        'tls_ca_file',
        'tls_ciphers',
        'tls_client_key',
        'tls_client_cert',
        'tls_scheduler_key',
        'tls_scheduler_cert',
        'tls_worker_key',
        'tls_worker_cert',
    )
    for attr in unset_attributes:
        assert getattr(sec, attr) is None


def test_attribute_error():
    """Known tls_* attributes exist; unknown ones fail on read and on write."""
    security = Security()
    assert hasattr(security, 'tls_ca_file')
    with pytest.raises(AttributeError):
        getattr(security, 'tls_foobar')
    with pytest.raises(AttributeError):
        setattr(security, 'tls_foobar', "")


def test_from_config():
    """Security() picks its settings up from the active configuration."""
    config = {
        'tls': {
            'ca-file': 'ca.pem',
            'scheduler': {
                'key': 'skey.pem',
                'cert': 'scert.pem',
            },
            'worker': {
                'cert': 'wcert.pem',
            },
            'ciphers': FORCED_CIPHER,
        },
        'require-encryption': True,
    }
    with new_config(config):
        sec = Security()

    assert sec.require_encryption is True

    configured = [
        ('tls_ca_file', 'ca.pem'),
        ('tls_ciphers', FORCED_CIPHER),
        ('tls_scheduler_key', 'skey.pem'),
        ('tls_scheduler_cert', 'scert.pem'),
        ('tls_worker_cert', 'wcert.pem'),
    ]
    for attr, expected in configured:
        assert getattr(sec, attr) == expected

    # Settings absent from the config stay unset.
    for attr in ('tls_client_key', 'tls_client_cert', 'tls_worker_key'):
        assert getattr(sec, attr) is None


def test_kwargs():
    """Keyword arguments override config values, except when passed as None."""
    config = {
        'tls': {
            'ca-file': 'ca.pem',
            'scheduler': {
                'key': 'skey.pem',
                'cert': 'scert.pem',
            },
        },
    }
    with new_config(config):
        sec = Security(tls_scheduler_cert='newcert.pem',
                       require_encryption=True,
                       tls_ca_file=None)

    assert sec.require_encryption is True

    # tls_ca_file=None did not clobber the configured value.
    overridden_or_kept = [
        ('tls_ca_file', 'ca.pem'),
        ('tls_scheduler_key', 'skey.pem'),
        ('tls_scheduler_cert', 'newcert.pem'),
    ]
    for attr, expected in overridden_or_kept:
        assert getattr(sec, attr) == expected

    for attr in ('tls_ciphers', 'tls_client_key', 'tls_client_cert',
                 'tls_worker_key', 'tls_worker_cert'):
        assert getattr(sec, attr) is None


def test_repr():
    """repr() lists only the settings that were explicitly given."""
    expected = "Security(tls_ca_file='ca.pem', tls_scheduler_cert='scert.pem')"
    with new_config({}):
        security = Security(tls_ca_file='ca.pem', tls_scheduler_cert='scert.pem')
        assert repr(security) == expected


def test_tls_config_for_role():
    """get_tls_config_for_role() merges shared and per-role TLS settings."""
    config = {
        'tls': {
            'ca-file': 'ca.pem',
            'scheduler': {
                'key': 'skey.pem',
                'cert': 'scert.pem',
            },
            'worker': {
                'cert': 'wcert.pem',
            },
            'ciphers': FORCED_CIPHER,
        },
    }
    with new_config(config):
        sec = Security()

    expectations = [
        ('scheduler', {
            'ca_file': 'ca.pem',
            'key': 'skey.pem',
            'cert': 'scert.pem',
            'ciphers': FORCED_CIPHER,
        }),
        ('worker', {
            'ca_file': 'ca.pem',
            'key': None,
            'cert': 'wcert.pem',
            'ciphers': FORCED_CIPHER,
        }),
        ('client', {
            'ca_file': 'ca.pem',
            'key': None,
            'cert': None,
            'ciphers': FORCED_CIPHER,
        }),
    ]
    for role, expected in expectations:
        assert sec.get_tls_config_for_role(role) == expected

    # Roles outside scheduler/worker/client are rejected.
    with pytest.raises(ValueError):
        sec.get_tls_config_for_role('supervisor')


def test_connection_args():
    """Check the connection args (outgoing side) that Security builds per role."""
    def basic_checks(ctx):
        assert ctx.verify_mode == ssl.CERT_REQUIRED
        assert ctx.check_hostname is False

    def many_ciphers(ctx):
        if sys.version_info >= (3, 6):
            assert len(ctx.get_ciphers()) > 2  # Most likely

    c = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {
                'key': key1,
                'cert': cert1,
            },
            'worker': {
                'cert': keycert1,
            },
        },
    }
    with new_config(c):
        sec = Security()

    d = sec.get_connection_args('scheduler')
    assert not d['require_encryption']
    ctx = d['ssl_context']
    basic_checks(ctx)
    many_ciphers(ctx)

    d = sec.get_connection_args('worker')
    ctx = d['ssl_context']
    basic_checks(ctx)
    many_ciphers(ctx)

    # No cert defined => no TLS
    d = sec.get_connection_args('client')
    assert d.get('ssl_context') is None

    # With more settings
    c['tls']['ciphers'] = FORCED_CIPHER
    c['require-encryption'] = True

    with new_config(c):
        sec = Security()

    # Exercise the connection side here; test_listen_args covers the listen
    # side. (This previously called get_listen_args by mistake.)
    d = sec.get_connection_args('scheduler')
    assert d['require_encryption']
    ctx = d['ssl_context']
    basic_checks(ctx)
    if sys.version_info >= (3, 6):
        supported_ciphers = ctx.get_ciphers()
        # Use a name other than `c` so the config dict is not shadowed (and,
        # on Python 2, not rebound by the leaking comprehension variable).
        tls_12_ciphers = [cipher for cipher in supported_ciphers
                          if cipher['protocol'] == 'TLSv1.2']
        assert len(tls_12_ciphers) == 1
        tls_13_ciphers = [cipher for cipher in supported_ciphers
                          if cipher['protocol'] == 'TLSv1.3']
        # TLS 1.3 suites are only checked when present (presumably depends on
        # the OpenSSL build's TLS 1.3 support).
        if tls_13_ciphers:
            assert len(tls_13_ciphers) == 3


def test_listen_args():
    """Check the listen args (accepting side) that Security builds per role."""
    def basic_checks(ctx):
        assert ctx.verify_mode == ssl.CERT_REQUIRED
        assert ctx.check_hostname is False

    def many_ciphers(ctx):
        if sys.version_info >= (3, 6):
            assert len(ctx.get_ciphers()) > 2  # Most likely

    c = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {
                'key': key1,
                'cert': cert1,
            },
            'worker': {
                'cert': keycert1,
            },
        },
    }
    with new_config(c):
        sec = Security()

    d = sec.get_listen_args('scheduler')
    assert not d['require_encryption']
    ctx = d['ssl_context']
    basic_checks(ctx)
    many_ciphers(ctx)

    d = sec.get_listen_args('worker')
    ctx = d['ssl_context']
    basic_checks(ctx)
    many_ciphers(ctx)

    # No cert defined => no TLS
    d = sec.get_listen_args('client')
    assert d.get('ssl_context') is None

    # With more settings
    c['tls']['ciphers'] = FORCED_CIPHER
    c['require-encryption'] = True

    with new_config(c):
        sec = Security()

    d = sec.get_listen_args('scheduler')
    assert d['require_encryption']
    ctx = d['ssl_context']
    basic_checks(ctx)
    if sys.version_info >= (3, 6):
        supported_ciphers = ctx.get_ciphers()
        # Use a name other than `c` so the config dict is not shadowed (and,
        # on Python 2, not rebound by the leaking comprehension variable).
        tls_12_ciphers = [cipher for cipher in supported_ciphers
                          if cipher['protocol'] == 'TLSv1.2']
        assert len(tls_12_ciphers) == 1
        tls_13_ciphers = [cipher for cipher in supported_ciphers
                          if cipher['protocol'] == 'TLSv1.3']
        # TLS 1.3 suites are only checked when present (presumably depends on
        # the OpenSSL build's TLS 1.3 support).
        if tls_13_ciphers:
            assert len(tls_13_ciphers) == 3


@gen_test()
def test_tls_listen_connect():
    """
    Functional test for TLS connection args.
    """
    @gen.coroutine
    def handle_comm(comm):
        # The server side must see the peer over TLS.
        assert comm.peer_address.startswith('tls://')
        yield comm.write('hello')
        yield comm.close()

    config = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {
                'key': key1,
                'cert': cert1,
            },
            'worker': {
                'cert': keycert1,
            },
        },
    }
    with new_config(config):
        sec = Security()

    # Same settings, plus a single permitted (pre-TLS-1.3) cipher.
    config['tls']['ciphers'] = FORCED_CIPHER
    with new_config(config):
        forced_cipher_sec = Security()

    accepted_ciphers = [FORCED_CIPHER] + TLS_13_CIPHERS

    with listen('tls://', handle_comm,
                connection_args=sec.get_listen_args('scheduler')) as listener:
        # Worker credentials complete the handshake and receive the greeting.
        worker_comm = yield connect(listener.contact_address,
                                    connection_args=sec.get_connection_args('worker'))
        greeting = yield worker_comm.read()
        assert greeting == 'hello'
        worker_comm.abort()

        # No SSL context for client
        with pytest.raises(TypeError):
            yield connect(listener.contact_address,
                          connection_args=sec.get_connection_args('client'))

        # Check forced cipher
        forced_comm = yield connect(
            listener.contact_address,
            connection_args=forced_cipher_sec.get_connection_args('worker'))
        negotiated, _, _ = forced_comm.extra_info['cipher']
        assert negotiated in accepted_ciphers
        forced_comm.abort()
        comm.abort()


@gen_test()
def test_require_encryption():
    """
    Functional test for "require_encryption" setting.
    """
    @gen.coroutine
    def handle_comm(comm):
        comm.abort()

    @contextmanager
    def check_encryption_error():
        # Stricter than a bare pytest.raises(RuntimeError): the error must be
        # the encryption refusal, not an unrelated RuntimeError.
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    c = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {
                'key': key1,
                'cert': cert1,
            },
            'worker': {
                'cert': keycert1,
            },
        },
    }
    with new_config(c):
        sec = Security()
    c['require-encryption'] = True
    with new_config(c):
        sec2 = Security()

    # These transports are accepted whether or not encryption is required
    # on either end.
    for listen_addr in ['inproc://', 'tls://']:
        with listen(listen_addr, handle_comm,
                    connection_args=sec.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec2.get_connection_args('worker'))
            comm.abort()

        with listen(listen_addr, handle_comm,
                    connection_args=sec2.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec2.get_connection_args('worker'))
            comm.abort()

    # Plain TCP works only while neither side requires encryption.
    for listen_addr in ['tcp://']:
        with listen(listen_addr, handle_comm,
                    connection_args=sec.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec.get_connection_args('worker'))
            comm.abort()

            with check_encryption_error():
                yield connect(listener.contact_address,
                              connection_args=sec2.get_connection_args('worker'))

        with check_encryption_error():
            listen(listen_addr, handle_comm,
                   connection_args=sec2.get_listen_args('scheduler'))
