from collections import namedtuple

import pytest
import pickle

from dask.utils_test import GetFunctionTestMixin, inc, add
from dask import core
from dask.core import (istask, get_dependencies, get_deps, flatten, subs,
                       preorder_traversal, literal, quote, has_tasks)


def contains(a, b):
    """Return True if every key/value pair of ``b`` is also present in ``a``.

    >>> contains({'x': 1, 'y': 2}, {'x': 1})
    True
    >>> contains({'x': 1, 'y': 2}, {'z': 3})
    False

    A key that is missing from ``a`` never matches, even when the expected
    value is ``None``:

    >>> contains({'x': 1}, {'y': None})
    False
    """
    # Check membership explicitly: ``a.get(k)`` alone returns None for a
    # missing key, which would wrongly compare equal to an expected None.
    return all(k in a and a[k] == v for k, v in b.items())


def test_istask():
    """Check which objects ``istask`` accepts and which it rejects."""
    # A namedtuple whose first field is callable must still be rejected.
    record_type = namedtuple('f', ['x', 'y'])
    rejected = [1, (1, 2), record_type(sum, 2)]

    assert istask((inc, 1))
    for candidate in rejected:
        assert not istask(candidate)


def test_has_tasks():
    """Only the plain list of literals ('a') is reported as task-free."""
    graph = {
        'a': [1, 2, 3],
        'b': 'a',
        'c': [1, (inc, 1)],
        'd': [(sum, 'a')],
        'e': ['a', 'b'],
        'f': [['a', 'b'], 2, 3],
    }

    assert not has_tasks(graph, graph['a'])
    # Every other entry contains a task or refers to another key.
    for key in ('b', 'c', 'd', 'e', 'f'):
        assert has_tasks(graph, graph[key])


def test_preorder_traversal():
    """``preorder_traversal`` yields the expected flat sequence per task."""
    # Each case pairs a task with the sequence the traversal should produce.
    cases = [
        ((add, 1, 2),
         [add, 1, 2]),
        ((add, (add, 1, 2), (add, 3, 4)),
         [add, add, 1, 2, add, 3, 4]),
        # A list argument shows up as the ``list`` type followed by its items.
        ((add, (sum, [1, 2]), 3),
         [add, sum, list, 1, 2, 3]),
    ]
    for task, expected in cases:
        assert list(preorder_traversal(task)) == expected


class TestGet(GetFunctionTestMixin):
    """Run the tests inherited from ``GetFunctionTestMixin`` against ``core.get``."""
    # Wrapped in staticmethod so ``core.get`` is not bound as an instance method.
    get = staticmethod(core.get)


def test_GetFunctionTestMixin_class():
    """The mixin's ``test_get`` fails for a bogus ``get`` and passes for ``core.get``."""
    class TestCustomGetFail(GetFunctionTestMixin):
        # A ``get`` that ignores its inputs and always returns 1.
        get = staticmethod(lambda x, y: 1)

    class TestCustomGetPass(GetFunctionTestMixin):
        get = staticmethod(core.get)

    failing = TestCustomGetFail()
    with pytest.raises(AssertionError):
        failing.test_get()

    passing = TestCustomGetPass()
    passing.test_get()


def test_get_dependencies_nested():
    """Keys buried inside nested tasks and nested lists are still found."""
    graph = {
        'x': 1,
        'y': 2,
        'z': (add, (inc, [['x']]), 'y'),
    }

    as_set = get_dependencies(graph, 'z')
    assert as_set == {'x', 'y'}

    # The list form has no guaranteed order, so compare it sorted.
    as_list = get_dependencies(graph, 'z', as_list=True)
    assert sorted(as_list) == ['x', 'y']


def test_get_dependencies_empty():
    """A task with no arguments has no dependencies in either output form."""
    graph = {'x': (inc,)}

    as_set = get_dependencies(graph, 'x')
    assert as_set == set()

    as_list = get_dependencies(graph, 'x', as_list=True)
    assert as_list == []


def test_get_dependencies_list():
    """Dependencies are collected from a list value, including nested tasks."""
    graph = {
        'x': 1,
        'y': 2,
        'z': ['x', [(inc, 'y')]],
    }

    as_set = get_dependencies(graph, 'z')
    assert as_set == {'x', 'y'}

    as_list = get_dependencies(graph, 'z', as_list=True)
    assert sorted(as_list) == ['x', 'y']


def test_get_dependencies_task():
    """An ad-hoc task passed via ``task=`` is resolved against the graph."""
    graph = {'x': 1, 'y': 2, 'z': ['x', [(inc, 'y')]]}
    adhoc_task = (inc, 'x')

    as_set = get_dependencies(graph, task=adhoc_task)
    assert as_set == {'x'}

    as_list = get_dependencies(graph, task=adhoc_task, as_list=True)
    assert as_list == ['x']


def test_get_dependencies_nothing():
    """Calling ``get_dependencies`` with neither a key nor a task raises ValueError."""
    pytest.raises(ValueError, get_dependencies, {})


def test_get_dependencies_many():
    """``task=`` may be a list of several graph values, or an empty list."""
    graph = {
        'a': [1, 2, 3],
        'b': 'a',
        'c': [1, (inc, 1)],
        'd': [(sum, 'c')],
        'e': ['a', 'b', 'zzz'],
        'f': [['a', 'b'], 2, 3],
    }
    selected = [graph['d'], graph['f']]

    # Dependencies of 'd' and 'f' combined.
    combined = get_dependencies(graph, task=selected)
    assert combined == {'a', 'b', 'c'}
    combined = get_dependencies(graph, task=selected, as_list=True)
    assert sorted(combined) == ['a', 'b', 'c']

    # An empty task list yields an empty result of the requested type.
    nothing = get_dependencies(graph, task=[])
    assert nothing == set()
    nothing = get_dependencies(graph, task=[], as_list=True)
    assert nothing == []


def test_get_deps():
    """Check both mappings returned by ``get_deps`` for a small graph.

    >>> dsk = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')}
    >>> dependencies, dependents = get_deps(dsk)
    >>> dependencies
    {'a': set(), 'b': {'a'}, 'c': {'b'}}
    >>> dependents
    {'a': {'b'}, 'b': {'c'}, 'c': set()}
    """
    graph = {
        'a': [1, 2, 3],
        'b': 'a',
        'c': [1, (inc, 1)],
        'd': [(sum, 'c')],
        # 'zzz' is not a key of the graph and 'b' appears twice.
        'e': ['b', 'zzz', 'b'],
        'f': [['a', 'b'], 2, 3],
    }
    expected_dependencies = {
        'a': set(),
        'b': {'a'},
        'c': set(),
        'd': {'c'},
        'e': {'b'},
        'f': {'a', 'b'},
    }
    expected_dependents = {
        'a': {'b', 'f'},
        'b': {'e', 'f'},
        'c': {'d'},
        'd': set(),
        'e': set(),
        'f': set(),
    }

    dependencies, dependents = get_deps(graph)
    assert dependencies == expected_dependencies
    assert dependents == expected_dependents


def test_flatten():
    """``flatten`` yields nothing for an empty tuple and a string as one item."""
    cases = [
        ((), []),
        ('foo', ['foo']),
    ]
    for value, expected in cases:
        assert list(flatten(value)) == expected


def test_subs():
    """``subs`` replaces the key inside list arguments, at any nesting depth."""
    cases = [
        ((sum, [1, 'x']), (sum, [1, 2])),
        ((sum, [1, ['x']]), (sum, [1, [2]])),
    ]
    for task, expected in cases:
        assert subs(task, 'x', 2) == expected


class MutateOnEq(object):
    """Probe object that counts how often it takes part in an ``==`` comparison.

    Every comparison reports "not equal"; the count is kept in ``hit_eq``.
    """
    # Class-level default of zero; the first comparison shadows it with a
    # per-instance counter.
    hit_eq = 0

    def __eq__(self, other):
        self.hit_eq = self.hit_eq + 1
        return False


def test_subs_no_key_data_eq():
    """``subs`` must not invoke ``==`` between the key and unrelated data."""
    # Numpy throws a deprecation warning on bool(array == scalar), which
    # pollutes the terminal. This test checks that `subs` never tries to
    # compare keys (scalars) with values (which could be arrays).
    probe = MutateOnEq()

    # Probe passed directly as the task.
    subs(probe, 'x', 1)
    assert probe.hit_eq == 0

    # Probe passed as an argument inside a task tuple.
    subs((add, probe, 'x'), 'x', 1)
    assert probe.hit_eq == 0


def test_subs_with_unfriendly_eq():
    """``subs`` copes with values whose ``__eq__`` is awkward or raises.

    The first check uses a numpy array as a task argument and only runs when
    numpy is installed. The second check uses an object whose ``__eq__``
    raises and runs unconditionally, because it does not need numpy.
    """
    try:
        import numpy as np
    except ImportError:
        # numpy is optional; skip only the array check, not the whole test.
        np = None

    if np is not None:
        task = (np.sum, np.array([1, 2]))
        # The key (4, 5) does not occur in the task, so the substituted
        # result must compare equal to the original as a plain bool.
        assert (subs(task, (4, 5), 1) == task) is True

    class MyException(Exception):
        pass

    class F():
        def __eq__(self, other):
            raise MyException()

    # An object whose comparison raises must be returned untouched.
    task = F()
    assert subs(task, 1, 2) is task


def test_subs_with_surprisingly_friendly_eq():
    """``subs`` returns a pandas DataFrame unchanged when it is not the key.

    The test is reported as skipped, rather than silently passing, when
    pandas is not installed.
    """
    pd = pytest.importorskip("pandas")

    df = pd.DataFrame()
    # Identity check: the very same object must come back.
    assert subs(df, 'x', 1) is df


def test_subs_unexpected_hashable_key():
    """``subs`` substitutes a key of a custom hashable type by equality."""
    class UnexpectedButHashable:
        # All instances hash alike and compare equal to one another.
        def __init__(self):
            self.name = "a"

        def __hash__(self):
            return hash(self.name)

        def __eq__(self, other):
            return isinstance(other, UnexpectedButHashable)

    # The task argument and the key are distinct but equal instances.
    task = (id, UnexpectedButHashable())
    key = UnexpectedButHashable()

    assert subs(task, key, 1) == (id, 1)


def test_quote():
    """A quoted value comes back from ``core.get`` equal to the original."""
    samples = [
        [1, 2, 3],
        (add, 1, 2),
        [1, [2, 3]],
        (add, 1, (add, 2, 3)),
    ]

    for sample in samples:
        graph = {'x': quote(sample)}
        assert core.get(graph, 'x') == sample


def test_literal_serializable():
    """A ``literal`` keeps its ``data`` across a pickle round trip."""
    payload = (add, 1, 2)
    wrapped = literal(payload)

    restored = pickle.loads(pickle.dumps(wrapped))
    assert restored.data == (add, 1, 2)
