import random

import numpy as np
import pytest

import dask.array as da
from dask.array.utils import assert_eq

sparse = pytest.importorskip('sparse')
if sparse:
    # Test failures on older versions of Numba.
    # Conda-Forge provides 0.35.0 on windows right now, causing failures like
    # searchsorted() got an unexpected keyword argument 'side'
    pytest.importorskip("numba", minversion="0.40.0")


functions = [
    lambda x: x,
    lambda x: da.expm1(x),
    lambda x: 2 * x,
    lambda x: x / 2,
    lambda x: x**2,
    lambda x: x + x,
    lambda x: x * x,
    lambda x: x[0],
    lambda x: x[:, 1],
    lambda x: x[:1, None, 1:3],
    lambda x: x.T,
    lambda x: da.transpose(x, (1, 2, 0)),
    lambda x: x.sum(),
    lambda x: x.mean(),
    lambda x: x.moment(order=0),
    pytest.param(lambda x: x.std(),
                 marks=pytest.mark.xfail(reason="fixed in https://github.com/pydata/sparse/pull/243")),
    pytest.param(lambda x: x.var(),
                 marks=pytest.mark.xfail(reason="fixed in https://github.com/pydata/sparse/pull/243")),
    lambda x: x.dot(np.arange(x.shape[-1])),
    lambda x: x.dot(np.eye(x.shape[-1])),
    lambda x: da.tensordot(x, np.ones(x.shape[:2]), axes=[(0, 1), (0, 1)]),
    lambda x: x.sum(axis=0),
    lambda x: x.max(axis=0),
    lambda x: x.sum(axis=(1, 2)),
    lambda x: x.astype(np.complex128),
    lambda x: x.map_blocks(lambda x: x * 2),
    lambda x: x.round(1),
    lambda x: x.reshape((x.shape[0] * x.shape[1], x.shape[2])),
    lambda x: abs(x),
    lambda x: x > 0.5,
    lambda x: x.rechunk((4, 4, 4)),
    lambda x: x.rechunk((2, 2, 1)),
]


@pytest.mark.parametrize('func', functions)
def test_basic(func):
    x = da.random.random((2, 3, 4), chunks=(1, 2, 2))
    x[x < 0.8] = 0

    y = x.map_blocks(sparse.COO.from_numpy)

    xx = func(x)
    yy = func(y)

    assert_eq(xx, yy)

    if yy.shape:
        zz = yy.compute()
        if not isinstance(zz, sparse.COO):
            assert (zz != 1).sum() > np.prod(zz.shape) / 2  # mostly dense


def test_tensordot():
    x = da.random.random((2, 3, 4), chunks=(1, 2, 2))
    x[x < 0.8] = 0
    y = da.random.random((4, 3, 2), chunks=(2, 2, 1))
    y[y < 0.8] = 0

    xx = x.map_blocks(sparse.COO.from_numpy)
    yy = y.map_blocks(sparse.COO.from_numpy)

    assert_eq(da.tensordot(x, y, axes=(2, 0)),
              da.tensordot(xx, yy, axes=(2, 0)))
    assert_eq(da.tensordot(x, y, axes=(1, 1)),
              da.tensordot(xx, yy, axes=(1, 1)))
    assert_eq(da.tensordot(x, y, axes=((1, 2), (1, 0))),
              da.tensordot(xx, yy, axes=((1, 2), (1, 0))))


@pytest.mark.xfail(reason="upstream change", strict=False)
@pytest.mark.parametrize('func', functions)
def test_mixed_concatenate(func):
    x = da.random.random((2, 3, 4), chunks=(1, 2, 2))

    y = da.random.random((2, 3, 4), chunks=(1, 2, 2))
    y[y < 0.8] = 0
    yy = y.map_blocks(sparse.COO.from_numpy)

    d = da.concatenate([x, y], axis=0)
    s = da.concatenate([x, yy], axis=0)

    dd = func(d)
    ss = func(s)

    assert_eq(dd, ss)


@pytest.mark.xfail(reason="upstream change", strict=False)
@pytest.mark.parametrize('func', functions)
def test_mixed_random(func):
    d = da.random.random((4, 3, 4), chunks=(1, 2, 2))
    d[d < 0.7] = 0

    fn = lambda x: sparse.COO.from_numpy(x) if random.random() < 0.5 else x
    s = d.map_blocks(fn)

    dd = func(d)
    ss = func(s)

    assert_eq(dd, ss)


@pytest.mark.xfail(reason="upstream change", strict=False)
def test_mixed_output_type():
    y = da.random.random((10, 10), chunks=(5, 5))
    y[y < 0.8] = 0
    y = y.map_blocks(sparse.COO.from_numpy)

    x = da.zeros((10, 1), chunks=(5, 1))

    z = da.concatenate([x, y], axis=1)

    assert z.shape == (10, 11)

    zz = z.compute()
    assert isinstance(zz, sparse.COO)
    assert zz.nnz == y.compute().nnz
