"""Tests for kernel connection utilities"""

# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import json
import os
import re
import stat
import tempfile
import shutil

from traitlets.config import Config
from jupyter_core.application import JupyterApp
from jupyter_core.paths import jupyter_runtime_dir
from ipython_genutils.tempdir import TemporaryDirectory, TemporaryWorkingDirectory
from ipython_genutils.py3compat import str_to_bytes
from jupyter_client import connect, KernelClient
from jupyter_client.consoleapp import JupyterConsoleApp
from jupyter_client.session import Session
from jupyter_client.connect import secure_write


class DummyConsoleApp(JupyterApp, JupyterConsoleApp):
    def initialize(self, argv=[]):
        JupyterApp.initialize(self, argv=argv)
        self.init_connection_file()

class DummyConfigurable(connect.ConnectionFileMixin):
    def initialize(self):
        pass

sample_info = dict(ip='1.2.3.4', transport='ipc',
        shell_port=1, hb_port=2, iopub_port=3, stdin_port=4, control_port=5,
        key=b'abc123', signature_scheme='hmac-md5', kernel_name='python'
    )

sample_info_kn = dict(ip='1.2.3.4', transport='ipc',
        shell_port=1, hb_port=2, iopub_port=3, stdin_port=4, control_port=5,
        key=b'abc123', signature_scheme='hmac-md5', kernel_name='test'
    )


def test_write_connection_file():
    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info)
        assert os.path.exists(cf)
        with open(cf, 'r') as f:
            info = json.load(f)
    info['key'] = str_to_bytes(info['key'])
    assert info == sample_info


def test_load_connection_file_session():
    """test load_connection_file() after """
    session = Session()
    app = DummyConsoleApp(session=Session())
    app.initialize(argv=[])
    session = app.session

    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info)
        app.connection_file = cf
        app.load_connection_file()

    assert session.key == sample_info['key']
    assert session.signature_scheme == sample_info['signature_scheme']


def test_load_connection_file_session_with_kn():
    """test load_connection_file() after """
    session = Session()
    app = DummyConsoleApp(session=Session())
    app.initialize(argv=[])
    session = app.session

    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info_kn)
        app.connection_file = cf
        app.load_connection_file()

    assert session.key == sample_info_kn['key']
    assert session.signature_scheme == sample_info_kn['signature_scheme']


def test_app_load_connection_file():
    """test `ipython console --existing` loads a connection file"""
    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info)
        app = DummyConsoleApp(connection_file=cf)
        app.initialize(argv=[])

    for attr, expected in sample_info.items():
        if attr in ('key', 'signature_scheme'):
            continue
        value = getattr(app, attr)
        assert value == expected, "app.%s = %s != %s" % (attr, value, expected)


def test_load_connection_info():
    client = KernelClient()
    info = {
        'control_port': 53702,
        'hb_port': 53705,
        'iopub_port': 53703,
        'ip': '0.0.0.0',
        'key': 'secret',
        'shell_port': 53700,
        'signature_scheme': 'hmac-sha256',
        'stdin_port': 53701,
        'transport': 'tcp',
    }
    client.load_connection_info(info)
    assert client.control_port == info['control_port']
    assert client.session.key.decode('ascii') == info['key']
    assert client.ip == info['ip']


def test_find_connection_file():
    with TemporaryDirectory() as d:
        cf = 'kernel.json'
        app = DummyConsoleApp(runtime_dir=d, connection_file=cf)
        app.initialize()

        security_dir = app.runtime_dir
        profile_cf = os.path.join(security_dir, cf)

        with open(profile_cf, 'w') as f:
            f.write("{}")

        for query in (
            'kernel.json',
            'kern*',
            '*ernel*',
            'k*',
            ):
            assert connect.find_connection_file(query, path=security_dir) == profile_cf


def test_find_connection_file_local():
    with TemporaryWorkingDirectory() as d:
        cf = 'test.json'
        abs_cf = os.path.abspath(cf)
        with open(cf, 'w') as f:
            f.write('{}')

        for query in (
            'test.json',
            'test',
            abs_cf,
            os.path.join('.', 'test.json'),
        ):
            assert connect.find_connection_file(query, path=['.', jupyter_runtime_dir()]) == abs_cf


def test_find_connection_file_relative():
    with TemporaryWorkingDirectory() as d:
        cf = 'test.json'
        os.mkdir('subdir')
        cf = os.path.join('subdir', 'test.json')
        abs_cf = os.path.abspath(cf)
        with open(cf, 'w') as f:
            f.write('{}')

        for query in (
            os.path.join('.', 'subdir', 'test.json'),
            os.path.join('subdir', 'test.json'),
            abs_cf,
        ):
            assert connect.find_connection_file(query, path=['.', jupyter_runtime_dir()]) == abs_cf


def test_find_connection_file_abspath():
    with TemporaryDirectory() as d:
        cf = 'absolute.json'
        abs_cf = os.path.abspath(cf)
        with open(cf, 'w') as f:
            f.write('{}')
        assert connect.find_connection_file(abs_cf, path=jupyter_runtime_dir()) == abs_cf
        os.remove(abs_cf)


def test_mixin_record_random_ports():
    with TemporaryDirectory() as d:
        dc = DummyConfigurable(data_dir=d, kernel_name='via-tcp', transport='tcp')
        dc.write_connection_file()

        assert dc._connection_file_written
        assert os.path.exists(dc.connection_file)
        assert dc._random_port_names == connect.port_names


def test_mixin_cleanup_random_ports():
    with TemporaryDirectory() as d:
        dc = DummyConfigurable(data_dir=d, kernel_name='via-tcp', transport='tcp')
        dc.write_connection_file()
        filename = dc.connection_file
        dc.cleanup_random_ports()

        assert not os.path.exists(filename)
        for name in dc._random_port_names:
            assert getattr(dc, name) == 0


def test_secure_write():
    def fetch_win32_permissions(filename):
        '''Extracts file permissions on windows using icacls'''
        role_permissions = {}
        for index, line in enumerate(os.popen("icacls %s" % filename).read().splitlines()):
            if index == 0:
                line = line.split(filename)[-1].strip().lower()
            match = re.match(r'\s*([^:]+):\(([^\)]*)\)', line)
            if match:
                usergroup, permissions = match.groups()
                usergroup = usergroup.lower().split('\\')[-1]
                permissions = set(p.lower() for p in permissions.split(','))
                role_permissions[usergroup] = permissions
            elif not line.strip():
                break
        return role_permissions

    def check_user_only_permissions(fname):
        # Windows has it's own permissions ACL patterns
        if os.name == 'nt':
            import win32api
            username = win32api.GetUserName().lower()
            permissions = fetch_win32_permissions(fname)
            assert username in permissions
            assert permissions[username] == set(['r', 'w'])
            assert 'administrators' in permissions
            assert permissions['administrators'] == set(['f'])
            assert 'everyone' not in permissions
            assert len(permissions) == 2
        else:
            mode = os.stat(fname).st_mode
            assert '0600' == oct(stat.S_IMODE(mode)).replace('0o', '0')

    directory = tempfile.mkdtemp()
    fname = os.path.join(directory, 'check_perms')
    try:
        with secure_write(fname) as f:
            f.write('test 1')
        check_user_only_permissions(fname)
        with open(fname, 'r') as f:
            assert f.read() == 'test 1'

        # Try changing file permissions ahead of time
        os.chmod(fname, 0o755)
        with secure_write(fname) as f:
            f.write('test 2')
        check_user_only_permissions(fname)
        with open(fname, 'r') as f:
            assert f.read() == 'test 2'
    finally:
        shutil.rmtree(directory)
