from __future__ import absolute_import, division, print_function


def inc(x):
    """Return ``x`` plus one (a tiny task function used by the graph tests)."""
    incremented = x + 1
    return incremented


def dec(x):
    """Return ``x`` minus one (a tiny task function used by the graph tests)."""
    decremented = x - 1
    return decremented


def add(x, y):
    """Return ``x + y`` (a tiny two-argument task function for graph tests)."""
    total = x + y
    return total


class GetFunctionTestMixin(object):
    """
    The GetFunctionTestMixin class can be imported and used to test foreign
    implementations of the `get` function specification. It aims to enforce all
    known expectations of `get` functions.

    To use the class, inherit from it and override the `get` function. For
    example:

    > from dask.utils_test import GetFunctionTestMixin
    > class TestCustomGet(GetFunctionTestMixin):
    >     get = staticmethod(myget)

    Note that the foreign `myget` function has to be explicitly decorated as a
    staticmethod.
    """
    def _assert_raises_keyerror(self, dsk, keys):
        """Check that ``self.get(dsk, keys)`` raises ``KeyError``.

        Raises
        ------
        AssertionError
            If ``get`` returns normally instead. The message names the
            ``get`` implementation and shows the value it returned.
        """
        try:
            result = self.get(dsk, keys)
        except KeyError:
            return
        # Callables such as functools.partial have no ``__name__``.
        name = getattr(self.get, '__name__', repr(self.get))
        # Interpolate both placeholders in a single ``format`` call: the text
        # of ``result`` may itself contain braces (e.g. a dict), which would
        # break a second ``format`` pass over the already-interpolated string.
        msg = ('Expected `{}` with badkey to raise KeyError.\n'
               "Obtained '{}' instead.").format(name, result)
        # Raise explicitly rather than ``assert False`` so the check still
        # fires when Python runs with -O.
        raise AssertionError(msg)

    def test_get(self):
        d = {':x': 1,
             ':y': (inc, ':x'),
             ':z': (add, ':x', ':y')}

        assert self.get(d, ':x') == 1
        assert self.get(d, ':y') == 2
        assert self.get(d, ':z') == 3

    def test_badkey(self):
        # Requesting a key that is absent from the graph must raise KeyError.
        d = {':x': 1,
             ':y': (inc, ':x'),
             ':z': (add, ':x', ':y')}
        self._assert_raises_keyerror(d, 'badkey')

    def test_nested_badkey(self):
        # A missing key nested inside a list of requested keys must also
        # raise KeyError.
        d = {'x': 1, 'y': 2, 'z': (sum, ['x', 'y'])}
        self._assert_raises_keyerror(d, [['badkey'], 'y'])

    def test_data_not_in_dict_is_ok(self):
        # The literal 10 is not a key in the graph and must pass through as data.
        d = {'x': 1, 'y': (add, 'x', 10)}
        assert self.get(d, 'y') == 11

    def test_get_with_list(self):
        # A list of requested keys yields a tuple of results.
        d = {'x': 1, 'y': 2, 'z': (sum, ['x', 'y'])}

        assert self.get(d, ['x', 'y']) == (1, 2)
        assert self.get(d, 'z') == 3

    def test_get_with_list_top_level(self):
        # Graph values that are lists are evaluated element-wise, recursively.
        d = {'a': [1, 2, 3],
             'b': 'a',
             'c': [1, (inc, 1)],
             'd': [(sum, 'a')],
             'e': ['a', 'b'],
             'f': [[[(sum, 'a'), 'c'], (sum, 'b')], 2]}
        assert self.get(d, 'a') == [1, 2, 3]
        assert self.get(d, 'b') == [1, 2, 3]
        assert self.get(d, 'c') == [1, 2]
        assert self.get(d, 'd') == [6]
        assert self.get(d, 'e') == [[1, 2, 3], [1, 2, 3]]
        assert self.get(d, 'f') == [[[6, [1, 2]], 6], 2]

    def test_get_with_nested_list(self):
        # Nested lists of requested keys yield nested tuples of results.
        d = {'x': 1, 'y': 2, 'z': (sum, ['x', 'y'])}

        assert self.get(d, [['x'], 'y']) == ((1,), 2)
        assert self.get(d, 'z') == 3

    def test_get_works_with_unhashables_in_values(self):
        # A task argument may be an unhashable object (here a set), so the
        # scheduler must not try to use it as a graph key.
        def f(x, y):
            return x + len(y)

        d = {'x': 1, 'y': (f, 'x', {1})}

        assert self.get(d, 'y') == 2

    def test_nested_tasks(self):
        # Tasks may appear inline as arguments of other tasks.
        d = {'x': 1,
             'y': (inc, 'x'),
             'z': (add, (inc, 'x'), 'y')}

        assert self.get(d, 'z') == 4

    def test_get_stack_limit(self):
        # A 10000-deep dependency chain that can overflow the default Python
        # recursion limit (1000) if traversed recursively.
        d = {'x%d' % (i + 1): (inc, 'x%d' % i) for i in range(10000)}
        d['x0'] = 0
        assert self.get(d, 'x10000') == 10000

    def test_with_HighLevelGraph(self):
        from .highlevelgraph import HighLevelGraph

        layers = {'a': {'x': 1,
                        'y': (inc, 'x')},
                  'b': {'z': (add, (inc, 'x'), 'y')}}
        dependencies = {'a': (), 'b': {'a'}}
        graph = HighLevelGraph(layers, dependencies)
        assert self.get(graph, 'z') == 4
