from __future__ import absolute_import, division, print_function

import operator
from functools import partial, wraps
from itertools import product, repeat
from math import factorial, log, ceil

import numpy as np
from numbers import Integral

from toolz import compose, partition_all, get, accumulate, pluck

from . import chunk
from .core import _concatenate2, Array, handle_out
from .blockwise import blockwise
from ..blockwise import lol_tuples
from .creation import arange, diagonal
from .ufunc import sqrt
from .utils import validate_axis
from .wrap import zeros, ones
from .numpy_compat import ma_divide, divide as np_divide
from ..compatibility import getargspec, builtins
from ..base import tokenize
from ..highlevelgraph import HighLevelGraph
from ..utils import ignoring, funcname, Dispatch, deepmap
from .. import config

# Generic functions to support chunks of different types.
# Each ``Dispatch`` picks an implementation from the type of the chunk it is
# given: plain objects and ndarrays use the NumPy functions, while
# ``np.ma.masked_array`` chunks use the masked-array variants
# (``np.ma.empty`` / ``ma_divide``).
empty_lookup = Dispatch('empty')
empty_lookup.register((object, np.ndarray), np.empty)
empty_lookup.register(np.ma.masked_array, np.ma.empty)
divide_lookup = Dispatch('divide')
divide_lookup.register((object, np.ndarray), np_divide)
divide_lookup.register(np.ma.masked_array, ma_divide)


def divide(a, b, dtype=None):
    """Divide ``a`` by ``b`` with the implementation registered for the
    operand that has the higher ``__array_priority__``.

    Operands without ``__array_priority__`` rank lowest, and ties go to
    ``a``. The implementation is looked up in ``divide_lookup`` and called
    as ``impl(a, b, dtype=dtype)``.
    """
    def _priority(obj):
        return getattr(obj, '__array_priority__', float('-inf'))

    dominant = b if _priority(b) > _priority(a) else a
    impl = divide_lookup.dispatch(type(dominant))
    return impl(a, b, dtype=dtype)


def reduction(x, chunk, aggregate, axis=None, keepdims=False, dtype=None,
              split_every=None, combine=None, name=None, out=None,
              concatenate=True, output_size=1):
    """ General version of reductions

    Parameters
    ----------
    x: Array
        Data being reduced along one or more axes
    chunk: callable(x_chunk, axis, keepdims)
        First function to be executed when resolving the dask graph.
        This function is applied in parallel to all original chunks of x.
        See below for function parameters.
    combine: callable(x_chunk, axis, keepdims), optional
        Function used for intermediate recursive aggregation (see
        split_every below). If omitted, it defaults to aggregate.
        If the reduction can be performed in fewer than 3 steps, it will not
        be invoked at all.
    aggregate: callable(x_chunk, axis, keepdims)
        Last function to be executed when resolving the dask graph,
        producing the final output. It is always invoked, even when the reduced
        Array has a single chunk along the reduced axes.
    axis: int or sequence of ints, optional
        Axis or axes to aggregate upon. If omitted, aggregate along all axes.
    keepdims: boolean, optional
        Whether the reduction function should preserve the reduced axes,
        leaving them at size ``output_size``, or remove them.
    dtype: np.dtype
        Output dtype. Although it has a default of None, it is required: a
        ValueError is raised when it is omitted. If ``chunk`` or
        ``aggregate`` accept a ``dtype`` argument, it is forwarded to them.
    split_every: int >= 2 or dict(axis: int), optional
        Determines the depth of the recursive aggregation. If greater than or
        equal to the number of input chunks, the aggregation will be performed
        in two steps, one ``chunk`` function per input chunk and a single
        ``aggregate`` function at the end. If set to less than that, an
        intermediate ``combine`` function will be used, so that any one
        ``combine`` or ``aggregate`` function has no more than ``split_every``
        inputs. The depth of the aggregation graph will be
        :math:`log_{split_every}(input chunks along reduced axes)`. Setting to
        a low value can reduce cache size and network transfers, at the cost of
        more CPU and a larger dask graph.

        Omit to let dask heuristically decide a good default. A default can
        also be set globally with the ``split_every`` key in
        :mod:`dask.config`.
    name: str, optional
        Prefix of the keys of the intermediate and output nodes. If omitted it
        defaults to the function names.
    out: Array, optional
        Another dask array whose contents will be replaced. Omit to create a
        new one. Note that, unlike in numpy, this setting gives no performance
        benefits whatsoever, but can still be useful if one needs to preserve
        the references to a previously existing Array.
    concatenate: bool, optional
        If True (the default), the outputs of the ``chunk``/``combine``
        functions are concatenated into a single np.array before being passed
        to the ``combine``/``aggregate`` functions. If False, the input of
        ``combine`` and ``aggregate`` will be either a list of the raw outputs
        of the previous step or a single output, and the function will have to
        concatenate it itself. It can be useful to set this to False if the
        chunk and/or combine steps do not produce np.arrays.
    output_size: int >= 1, optional
        Size of the output of the ``aggregate`` function along the reduced
        axes. Ignored if keepdims is False.

    Returns
    -------
    dask array

    Raises
    ------
    ValueError
        If ``dtype`` is None.

    **Function Parameters**

    x_chunk: numpy.ndarray
        Individual input chunk. For ``chunk`` functions, it is one of the
        original chunks of x. For ``combine`` and ``aggregate`` functions, it's
        the concatenation of the outputs produced by the previous ``chunk`` or
        ``combine`` functions. If concatenate=False, it's a list of the raw
        outputs from the previous functions.
    axis: tuple
        Normalized list of axes to reduce upon, e.g. ``(0, )``
        Scalar, negative, and None axes have been normalized away.
        Note that some numpy reduction functions cannot reduce along multiple
        axes at once and strictly require an int in input. Such functions have
        to be wrapped to cope.
    keepdims: bool
        Whether the reduction function should preserve the reduced axes or
        remove them.
    """
    # Normalize ``axis`` into a validated tuple of ints.
    if axis is None:
        axis = tuple(range(x.ndim))
    if isinstance(axis, Integral):
        axis = (axis,)
    axis = validate_axis(axis, x.ndim)

    if dtype is None:
        raise ValueError("Must specify dtype")
    # Forward the output dtype only to functions that declare a ``dtype``
    # argument.
    if 'dtype' in getargspec(chunk).args:
        chunk = partial(chunk, dtype=dtype)
    if 'dtype' in getargspec(aggregate).args:
        aggregate = partial(aggregate, dtype=dtype)

    # Map chunk across all blocks
    inds = tuple(range(x.ndim))
    # The dtype of `tmp` doesn't actually matter, and may be incorrect.
    tmp = blockwise(chunk, inds, x, inds, axis=axis, keepdims=True, dtype=x.dtype)
    # The chunk step runs with keepdims=True, so every block now has length
    # ``output_size`` along each reduced axis; record that in the metadata.
    tmp._chunks = tuple((output_size, ) * len(c) if i in axis else c
                        for i, c in enumerate(tmp.chunks))
    result = _tree_reduce(tmp, aggregate, axis, keepdims, dtype, split_every,
                          combine, name=name, concatenate=concatenate)
    # partial_reduce records chunks of length 1 along the reduced axes; fix
    # this up when the aggregate produces ``output_size`` elements instead.
    if keepdims and output_size != 1:
        result._chunks = tuple((output_size, ) if i in axis else c
                               for i, c in enumerate(tmp.chunks))
    return handle_out(out, result)


def _tree_reduce(x, aggregate, axis, keepdims, dtype, split_every=None,
                 combine=None, name=None, concatenate=True):
    """ Perform the tree reduction step of a reduction.

    Lower level, users should use ``reduction`` or ``arg_reduction`` directly.

    Parameters
    ----------
    x : Array
        Output of the ``chunk`` step (one block per input block).
    aggregate : callable
        Final reduction, called with ``axis`` and ``keepdims``.
    axis : tuple of int
        Normalized axes being reduced.
    keepdims : bool
        Whether the final output keeps the reduced axes.
    dtype : dtype
        dtype recorded on the intermediate and final arrays.
    split_every : int, dict or None
        Maximum number of blocks merged per step along each reduced axis.
        Falls back to the ``split_every`` config key (default 4).
    combine : callable, optional
        Intermediate reduction; defaults to ``aggregate``.
    name : str, optional
        Prefix for the keys of the intermediate and final layers.
    concatenate : bool
        Concatenate blocks with ``_concatenate2`` before calling
        ``combine``/``aggregate``.

    Raises
    ------
    ValueError
        If ``split_every`` is neither an int nor a dict, or if a per-axis
        value other than 1 is not greater than 1.
    """
    # Normalize split_every into a per-axis dict.
    split_every = split_every or config.get('split_every', 4)
    if isinstance(split_every, dict):
        split_every = dict((k, split_every.get(k, 2)) for k in axis)
    elif isinstance(split_every, Integral):
        # Spread the total fan-in evenly across the reduced axes.
        n = builtins.max(int(split_every ** (1 / (len(axis) or 1))), 2)
        split_every = dict.fromkeys(axis, n)
    else:
        raise ValueError("split_every must be a int or a dict")

    # Number of levels needed so that every reduced axis collapses to a single
    # block: ceil(log_s(numblocks)). It is computed exactly with integers,
    # because float ``log`` can overshoot (log(125, 5) == 3.0000000000000004)
    # and add a useless extra level. An axis whose split_every is 1 never
    # merges blocks, so it does not constrain the depth.
    depth = 1
    for i, n in enumerate(x.numblocks):
        if i in split_every and split_every[i] != 1:
            depth = builtins.max(depth, _ceil_log(n, split_every[i]))
    func = partial(combine or aggregate, axis=axis, keepdims=True)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=axis))
    # All levels but the last use ``combine`` and keep the reduced dims.
    for i in range(depth - 1):
        x = partial_reduce(func, x, split_every, True, dtype=dtype,
                           name=(name or funcname(combine or aggregate)) + '-partial')
    func = partial(aggregate, axis=axis, keepdims=keepdims)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=axis))
    return partial_reduce(func, x, split_every, keepdims=keepdims, dtype=dtype,
                          name=(name or funcname(aggregate)) + '-aggregate')


def _ceil_log(n, base):
    """Return the smallest integer ``d >= 0`` such that ``base ** d >= n``.

    Exact replacement for ``ceil(log(n, base))``. A value of ``n <= 1``
    gives 0. Raises ``ValueError`` when ``base <= 1``, because the result
    is then undefined and the loop would not terminate.
    """
    if base <= 1:
        raise ValueError("split_every values must be greater than 1, "
                         "got {0!r}".format(base))
    depth = 0
    reach = 1
    while reach < n:
        reach *= base
        depth += 1
    return depth


def partial_reduce(func, x, split_every, keepdims=False, dtype=None, name=None):
    """ Partial reduction across multiple axes.

    Parameters
    ----------
    func : function
        Called once per output block on the group of input blocks (as
        arranged by ``lol_tuples``) that feed that block.
    x : Array
    split_every : dict
        Maximum reduction block sizes in each dimension.
    keepdims : bool, optional
        If False, the axes listed in ``split_every`` are dropped from the
        output keys and chunks.
    dtype : dtype, optional
        dtype recorded on the returned Array.
    name : str, optional
        Key prefix; defaults to ``funcname(func)``. A token derived from the
        inputs is always appended.

    Examples
    --------
    Reduce across axis 0 and 2, merging a maximum of 1 block in the 0th
    dimension, and 3 blocks in the 2nd dimension:

    >>> partial_reduce(np.min, x, {0: 1, 2: 3})    # doctest: +SKIP
    """
    name = (name or funcname(func)) + '-' + tokenize(func, x, split_every,
                                                     keepdims, dtype)
    # For every dimension, group the block indices into runs of at most
    # split_every[i] blocks (runs of 1 for axes that are not being reduced).
    parts = [list(partition_all(split_every.get(i, 1), range(n))) for (i, n)
             in enumerate(x.numblocks)]
    # The output block index along each dimension is the index of its group.
    keys = product(*map(range, map(len, parts)))
    # Each group along a reduced axis becomes one output chunk of length 1.
    out_chunks = [tuple(1 for p in partition_all(split_every[i], c)) if i
                  in split_every else c for (i, c) in enumerate(x.chunks)]
    if not keepdims:
        # Drop the reduced axes from both the output keys and the chunks.
        out_axis = [i for i in range(x.ndim) if i not in split_every]
        getter = lambda k: get(out_axis, k)
        keys = map(getter, keys)
        out_chunks = list(getter(out_chunks))
    dsk = {}
    for k, p in zip(keys, product(*parts)):
        # Dimensions whose group holds a single block are fixed indices
        # ("decided"). The others are expanded by lol_tuples into nested lists
        # of input keys.
        decided = dict((i, j[0]) for (i, j) in enumerate(p) if len(j) == 1)
        dummy = dict(i for i in enumerate(p) if i[0] not in decided)
        g = lol_tuples((x.name,), range(x.ndim), decided, dummy)
        dsk[(name,) + k] = (func, g)
    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
    return Array(graph, name, out_chunks, dtype=dtype)


@wraps(chunk.sum)
def sum(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # If no dtype is given, ask NumPy which dtype a sum of a one-element
    # sample of ``a.dtype`` produces. Results without ``.dtype`` map to object.
    if dtype is None:
        sample = np.empty((1,), dtype=a.dtype)
        result_dtype = getattr(sample.sum(), 'dtype', object)
    else:
        result_dtype = dtype
    return reduction(a, chunk.sum, chunk.sum, axis=axis, keepdims=keepdims,
                     dtype=result_dtype, split_every=split_every, out=out)


@wraps(chunk.prod)
def prod(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # If no dtype is given, infer it from NumPy's product of a one-element
    # sample. Results without ``.dtype`` map to object.
    if dtype is None:
        sample = np.empty((1,), dtype=a.dtype)
        result_dtype = getattr(sample.prod(), 'dtype', object)
    else:
        result_dtype = dtype
    return reduction(a, chunk.prod, chunk.prod, axis=axis, keepdims=keepdims,
                     dtype=result_dtype, split_every=split_every, out=out)


@wraps(chunk.min)
def min(a, axis=None, keepdims=False, split_every=None, out=None):
    # The minimum of partial minima is the overall minimum, so one function
    # serves as both the chunk and the aggregate step. The dtype is unchanged.
    op = chunk.min
    return reduction(a, op, op, axis=axis, keepdims=keepdims, dtype=a.dtype,
                     split_every=split_every, out=out)


@wraps(chunk.max)
def max(a, axis=None, keepdims=False, split_every=None, out=None):
    # The maximum of partial maxima is the overall maximum, so one function
    # serves as both the chunk and the aggregate step. The dtype is unchanged.
    op = chunk.max
    return reduction(a, op, op, axis=axis, keepdims=keepdims, dtype=a.dtype,
                     split_every=split_every, out=out)


@wraps(chunk.any)
def any(a, axis=None, keepdims=False, split_every=None, out=None):
    # "any" of partial "any" results is the overall answer; output is boolean.
    op = chunk.any
    return reduction(a, op, op, axis=axis, keepdims=keepdims, dtype='bool',
                     split_every=split_every, out=out)


@wraps(chunk.all)
def all(a, axis=None, keepdims=False, split_every=None, out=None):
    # "all" of partial "all" results is the overall answer; output is boolean.
    op = chunk.all
    return reduction(a, op, op, axis=axis, keepdims=keepdims, dtype='bool',
                     split_every=split_every, out=out)


@wraps(chunk.nansum)
def nansum(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # NaNs are dropped in the chunk step (nansum). The partial sums are then
    # combined with a plain sum.
    if dtype is None:
        sample = np.empty((1,), dtype=a.dtype)
        result_dtype = getattr(chunk.nansum(sample), 'dtype', object)
    else:
        result_dtype = dtype
    return reduction(a, chunk.nansum, chunk.sum, axis=axis, keepdims=keepdims,
                     dtype=result_dtype, split_every=split_every, out=out)


# These wrappers are only defined when ``chunk`` provides the corresponding
# functions; the AttributeError raised by ``wraps`` otherwise is suppressed.
with ignoring(AttributeError):
    @wraps(chunk.nanprod)
    def nanprod(a, axis=None, dtype=None, keepdims=False, split_every=None,
                out=None):
        if dtype is not None:
            dt = dtype
        else:
            # Infer the result dtype from the same function used in the chunk
            # step (nanprod), not nansum.
            dt = getattr(chunk.nanprod(np.empty((1,), dtype=a.dtype)), 'dtype', object)
        # NaNs are dropped in the chunk step; partial products combine with prod.
        return reduction(a, chunk.nanprod, chunk.prod, axis=axis,
                         keepdims=keepdims, dtype=dt, split_every=split_every,
                         out=out)

    @wraps(chunk.nancumsum)
    def nancumsum(x, axis=None, dtype=None, out=None):
        # ``axis=None`` (flatten first) matches ``cumsum``; cumreduction
        # handles it. Existing callers that pass ``axis`` are unaffected.
        return cumreduction(chunk.nancumsum, operator.add, 0, x, axis, dtype,
                            out=out)

    @wraps(chunk.nancumprod)
    def nancumprod(x, axis=None, dtype=None, out=None):
        # ``axis=None`` (flatten first) matches ``cumprod``.
        return cumreduction(chunk.nancumprod, operator.mul, 1, x, axis, dtype,
                            out=out)


@wraps(chunk.nanmin)
def nanmin(a, axis=None, keepdims=False, split_every=None, out=None):
    # NaN-ignoring minimum; the same function serves as chunk and aggregate.
    op = chunk.nanmin
    return reduction(a, op, op, axis=axis, keepdims=keepdims, dtype=a.dtype,
                     split_every=split_every, out=out)


@wraps(chunk.nanmax)
def nanmax(a, axis=None, keepdims=False, split_every=None, out=None):
    # NaN-ignoring maximum; the same function serves as chunk and aggregate.
    op = chunk.nanmax
    return reduction(a, op, op, axis=axis, keepdims=keepdims, dtype=a.dtype,
                     split_every=split_every, out=out)


def numel(x, **kwargs):
    """ A reduction to count the number of elements

    Parameters
    ----------
    x : array-like
        Chunk to count. Objects with a ``mask`` attribute are delegated to
        ``chunk.sum(np.ones_like(x), **kwargs)``.
    axis : None, int, or tuple/list of int, optional
        Axes to count along; negative values count from the end.
        None counts every element.
    keepdims : bool, optional
        Keep the counted axes with length 1.
    dtype : dtype, optional
        dtype of the result (default ``np.float64``).

    Returns
    -------
    The element count. This is a scalar when ``axis`` is None and
    ``keepdims`` is false; otherwise it is an array filled with the count.
    """

    if hasattr(x, 'mask'):
        return chunk.sum(np.ones_like(x), **kwargs)

    shape = x.shape
    ndim = len(shape)
    keepdims = kwargs.get('keepdims', False)
    axis = kwargs.get('axis', None)
    dtype = kwargs.get('dtype', np.float64)

    if axis is None:
        prod = np.prod(shape, dtype=dtype)
        return np.full((1,) * ndim, prod, dtype=dtype) if keepdims else prod

    # ``tuple or list`` would evaluate to just ``tuple``; check both types.
    if not isinstance(axis, (tuple, list)):
        axis = [axis]
    # Normalize negative axes so the membership tests below match them.
    axis = [ax + ndim if ax < 0 else ax for ax in axis]

    prod = np.prod([shape[dim] for dim in axis])
    if keepdims:
        new_shape = tuple(1 if dim in axis else size
                          for dim, size in enumerate(shape))
    else:
        new_shape = tuple(size for dim, size in enumerate(shape)
                          if dim not in axis)
    return np.full(new_shape, prod, dtype=dtype)


def nannumel(x, **kwargs):
    """ Count the non-NaN elements of ``x``.

    ``kwargs`` (e.g. ``axis``, ``keepdims``, ``dtype``) are forwarded to
    ``chunk.sum``, which sums the boolean "is not NaN" mask.
    """
    present = ~np.isnan(x)
    return chunk.sum(present, **kwargs)


def mean_chunk(x, sum=chunk.sum, numel=numel, dtype='f8', **kwargs):
    """Chunk step of ``mean``: return the element count and total of ``x``.

    ``sum`` and ``numel`` can be swapped for NaN-aware versions; both
    receive ``dtype`` plus any remaining ``kwargs``. Returns
    ``{'n': count, 'total': total}``.
    """
    count = numel(x, dtype=dtype, **kwargs)
    subtotal = sum(x, dtype=dtype, **kwargs)
    return dict(n=count, total=subtotal)


def mean_combine(pairs, sum=chunk.sum, numel=numel, dtype='f8', axis=None, **kwargs):
    """Combine step of ``mean``: merge ``{'n', 'total'}`` dicts along ``axis``.

    ``pairs`` is a (possibly nested) list of chunk results, or a single
    result. The counts and totals are each concatenated along ``axis`` and
    summed with ``kwargs``. ``sum``, ``numel`` and ``dtype`` are accepted
    for interface compatibility but not used.
    """
    parts = pairs if isinstance(pairs, list) else [pairs]
    counts = _concatenate2(deepmap(lambda pair: pair['n'], parts), axes=axis)
    totals = _concatenate2(deepmap(lambda pair: pair['total'], parts), axes=axis)
    return {'n': counts.sum(axis=axis, **kwargs),
            'total': totals.sum(axis=axis, **kwargs)}


def mean_agg(pairs, dtype='f8', axis=None, **kwargs):
    """Aggregate step of ``mean``: total count and total sum, then divide.

    Both fields are concatenated along ``axis``, summed with ``dtype`` and
    ``kwargs``, and divided using the module-level ``divide`` dispatcher.
    """
    def _reduced(field):
        pieces = deepmap(lambda pair: pair[field], pairs)
        return _concatenate2(pieces, axes=axis).sum(axis=axis, dtype=dtype, **kwargs)

    n = _reduced('n')
    total = _reduced('total')
    return divide(total, n, dtype=dtype)


@wraps(chunk.mean)
def mean(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # If no dtype is given, infer it from np.mean of a one-element sample.
    result_dtype = dtype
    if result_dtype is None:
        sample = np.empty(shape=(1,), dtype=a.dtype)
        result_dtype = getattr(np.mean(sample), 'dtype', object)
    # The intermediates are dicts, so concatenation is left to the
    # combine/aggregate functions.
    return reduction(a, mean_chunk, mean_agg, axis=axis, keepdims=keepdims,
                     dtype=result_dtype, split_every=split_every,
                     combine=mean_combine, out=out, concatenate=False)


def nanmean(a, axis=None, dtype=None, keepdims=False, split_every=None,
            out=None):
    # Same pipeline as ``mean``, with NaN-aware sum and count helpers.
    result_dtype = dtype
    if result_dtype is None:
        sample = np.empty(shape=(1,), dtype=a.dtype)
        result_dtype = getattr(np.mean(sample), 'dtype', object)
    chunk_step = partial(mean_chunk, sum=chunk.nansum, numel=nannumel)
    combine_step = partial(mean_combine, sum=chunk.nansum, numel=nannumel)
    return reduction(a, chunk_step, mean_agg, axis=axis, keepdims=keepdims,
                     dtype=result_dtype, split_every=split_every, out=out,
                     concatenate=False, combine=combine_step)


# Copy ``chunk.nanmean``'s metadata (name, docstring, ...) onto ``nanmean``
# when it exists. If it does not, the AttributeError is suppressed and
# ``nanmean`` stays usable, just without that metadata.
with ignoring(AttributeError):
    nanmean = wraps(chunk.nanmean)(nanmean)


def moment_chunk(A, order=2, sum=chunk.sum, numel=numel, dtype='f8', **kwargs):
    """Chunk step of the central-moment reductions.

    Returns a dict with the chunk's ``'total'``, its element count ``'n'``
    (as int64), and ``'M'``. ``'M'`` stacks (on a new last axis) the sums of
    ``(A - mean) ** p`` for ``p = 2 .. order``.
    """
    total = sum(A, dtype=dtype, **kwargs)
    n = numel(A, **kwargs).astype(np.int64)
    chunk_mean = total / n
    central_sums = [sum((A - chunk_mean) ** power, dtype=dtype, **kwargs)
                    for power in range(2, order + 1)]
    return {'total': total, 'n': n, 'M': np.stack(central_sums, axis=-1)}


def _moment_helper(Ms, ns, inner_term, order, sum, axis, kwargs):
    # Merge per-block central-moment sums into the ``order``-th central
    # moment sum of all blocks together along ``axis``:
    #
    #   M_p = sum_i M_{p,i} + sum_i n_i * d_i**p
    #         + sum_{k=1}^{p-2} C(p, k) * sum_i M_{p-k,i} * d_i**k
    #
    # where
    #   * Ms[..., j] holds each block's sum of (x - block_mean)**(j + 2),
    #     as built by ``moment_chunk``/``moment_combine``;
    #   * ns holds the block counts;
    #   * inner_term is d_i = block_mean - combined_mean, computed by the
    #     callers.
    # The k = p-1 term is absent because first central moments are zero.
    # ``sum`` may be NaN-aware (np.nansum); ``kwargs`` carries
    # dtype/keepdims.
    M = Ms[..., order - 2].sum(axis=axis, **kwargs) + sum(ns * inner_term ** order, axis=axis, **kwargs)
    for k in range(1, order - 1):
        # Binomial coefficient C(order, k), computed as a float.
        coeff = factorial(order) / (factorial(k) * factorial(order - k))
        M += coeff * sum(Ms[..., order - k - 2] * inner_term**k, axis=axis, **kwargs)
    return M


def moment_combine(pairs, order=2, ddof=0, dtype='f8', sum=np.sum, axis=None, **kwargs):
    """Combine step of the central-moment reductions.

    Merges ``moment_chunk``-style dicts along ``axis`` into a single dict of
    the same form (``'total'``, ``'n'``, ``'M'``), keeping the reduced dims.
    ``ddof`` is accepted for interface compatibility but unused here.
    """
    if not isinstance(pairs, list):
        pairs = [pairs]

    def _gather(field):
        return _concatenate2(deepmap(lambda pair: pair[field], pairs), axes=axis)

    totals = _gather('total')
    ns = _gather('n')
    Ms = _gather('M')

    kwargs.update(dtype=dtype, keepdims=True)

    total = totals.sum(axis=axis, **kwargs)
    n = ns.sum(axis=axis, **kwargs)
    combined_mean = divide(total, n, dtype=dtype)
    # Offset of each block's mean from the combined mean.
    offsets = divide(totals, ns, dtype=dtype) - combined_mean

    merged = np.stack([_moment_helper(Ms, ns, offsets, power, sum, axis, kwargs)
                       for power in range(2, order + 1)], axis=-1)
    return {'total': total, 'n': n, 'M': merged}


def moment_agg(pairs, order=2, ddof=0, dtype='f8', sum=np.sum, axis=None, **kwargs):
    """Aggregate step of the central-moment reductions.

    Merges the block dicts and returns the ``order``-th central moment sum
    divided by ``N - ddof``, where ``N`` is the total element count.
    """
    if not isinstance(pairs, list):
        pairs = [pairs]

    def _gather(field):
        return _concatenate2(deepmap(lambda pair: pair[field], pairs), axes=axis)

    totals = _gather('total')
    ns = _gather('n')
    Ms = _gather('M')

    kwargs['dtype'] = dtype
    # The intermediate sums keep the reduced dims so they broadcast against
    # the per-block arrays. The final result follows the caller's keepdims.
    keepdim_kw = dict(kwargs, keepdims=True)

    n = ns.sum(axis=axis, **keepdim_kw)
    combined_mean = divide(totals.sum(axis=axis, **keepdim_kw), n, dtype=dtype)
    offsets = divide(totals, ns, dtype=dtype) - combined_mean

    M = _moment_helper(Ms, ns, offsets, order, sum, axis, kwargs)
    denominator = n.sum(axis=axis, **kwargs) - ddof
    return divide(M, denominator, dtype=dtype)


def moment(a, order, axis=None, dtype=None, keepdims=False, ddof=0,
           split_every=None, out=None):
    """ Compute the ``order``-th central moment of ``a``.

    Parameters
    ----------
    a : Array
    order : int >= 0
    axis, dtype, keepdims, split_every, out
        As for ``var``.
    ddof : int
        Subtracted from the element count in the final division.

    Raises
    ------
    ValueError
        If ``order`` is not a non-negative integer.
    """
    if not isinstance(order, Integral) or order < 0:
        raise ValueError("Order must be an integer >= 0")

    if order < 2:
        # The 0th central moment is 1 and the 1st is 0 by definition, so only
        # the reduced shape and chunks are needed. Honour keepdims, dtype and
        # out so the result has the layout the order >= 2 path would produce.
        reduced = a.sum(axis=axis, keepdims=keepdims)
        fill = ones if order == 0 else zeros
        result = fill(reduced.shape, chunks=reduced.chunks,
                      dtype='f8' if dtype is None else dtype)
        return handle_out(out, result)

    if dtype is not None:
        dt = dtype
    else:
        dt = getattr(np.var(np.ones(shape=(1,), dtype=a.dtype)), 'dtype', object)
    return reduction(a, partial(moment_chunk, order=order),
                     partial(moment_agg, order=order, ddof=ddof),
                     axis=axis, keepdims=keepdims,
                     dtype=dt, split_every=split_every, out=out,
                     concatenate=False,
                     combine=partial(moment_combine, order=order))


@wraps(chunk.var)
def var(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None,
        out=None):
    # Variance is the second central moment: moment_chunk/moment_combine
    # build the pieces and moment_agg divides by N - ddof.
    result_dtype = (dtype if dtype is not None else
                    getattr(np.var(np.ones(shape=(1,), dtype=a.dtype)), 'dtype', object))
    aggregate_step = partial(moment_agg, ddof=ddof)
    return reduction(a, moment_chunk, aggregate_step, axis=axis,
                     keepdims=keepdims, dtype=result_dtype,
                     split_every=split_every, combine=moment_combine,
                     name='var', out=out, concatenate=False)


def nanvar(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None,
           out=None):
    # Same pipeline as ``var``, with NaN-aware sums and counts in every step.
    result_dtype = (dtype if dtype is not None else
                    getattr(np.var(np.ones(shape=(1,), dtype=a.dtype)), 'dtype', object))
    chunk_step = partial(moment_chunk, sum=chunk.nansum, numel=nannumel)
    aggregate_step = partial(moment_agg, sum=np.nansum, ddof=ddof)
    combine_step = partial(moment_combine, sum=np.nansum)
    return reduction(a, chunk_step, aggregate_step, axis=axis,
                     keepdims=keepdims, dtype=result_dtype,
                     split_every=split_every, combine=combine_step, out=out,
                     concatenate=False)


# Copy ``chunk.nanvar``'s metadata onto ``nanvar`` when available. A missing
# ``chunk.nanvar`` raises AttributeError, which is suppressed.
with ignoring(AttributeError):
    nanvar = wraps(chunk.nanvar)(nanvar)


@wraps(chunk.std)
def std(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None,
        out=None):
    # Standard deviation = sqrt(variance). ``out`` must receive the final
    # standard deviation, not the intermediate variance, so it is applied
    # only to the final result.
    result = sqrt(a.var(axis=axis, dtype=dtype, keepdims=keepdims, ddof=ddof,
                        split_every=split_every))
    if dtype and dtype != result.dtype:
        result = result.astype(dtype)
    return handle_out(out, result)


def nanstd(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None,
           out=None):
    # NaN-ignoring standard deviation = sqrt(nanvar). As in ``std``, ``out``
    # is applied to the final result rather than to the intermediate
    # variance.
    result = sqrt(nanvar(a, axis=axis, dtype=dtype, keepdims=keepdims,
                         ddof=ddof, split_every=split_every))
    if dtype and dtype != result.dtype:
        result = result.astype(dtype)
    return handle_out(out, result)


# Copy ``chunk.nanstd``'s metadata onto ``nanstd`` when available. A missing
# ``chunk.nanstd`` raises AttributeError, which is suppressed.
with ignoring(AttributeError):
    nanstd = wraps(chunk.nanstd)(nanstd)


def _arg_combine(data, axis, argfunc, keepdims=False):
    """ Merge intermediate results from ``arg_*`` functions"""
    axis = None if len(axis) == data.ndim or data.ndim == 1 else axis[0]
    vals = data['vals']
    arg = data['arg']
    if axis is None:
        local_args = argfunc(vals, axis=axis, keepdims=keepdims)
        vals = vals.ravel()[local_args]
        arg = arg.ravel()[local_args]
    else:
        local_args = argfunc(vals, axis=axis)
        inds = np.ogrid[tuple(map(slice, local_args.shape))]
        inds.insert(axis, local_args)
        inds = tuple(inds)
        vals = vals[inds]
        arg = arg[inds]
        if keepdims:
            vals = np.expand_dims(vals, axis)
            arg = np.expand_dims(arg, axis)
    return arg, vals


def arg_chunk(func, argfunc, x, axis, offset_info):
    """Per-block step of an arg-reduction.

    Parameters
    ----------
    func : callable
        Value reduction (e.g. ``chunk.min``), called with ``axis`` and
        ``keepdims=True``.
    argfunc : callable
        Matching index reduction (e.g. ``chunk.argmin``).
    x : ndarray
        One block of the input.
    axis : tuple of int
        Reduced axes. If they cover every dimension (or ``x`` is 1-d), the
        whole block is reduced; otherwise the reduction is along ``axis[0]``.
    offset_info : tuple or int
        For whole-array reductions, ``(block_offset, total_shape)``, used to
        convert local indices into flat indices of the full array. Otherwise,
        the block's offset along the reduced axis (see ``arg_reduction``).

    Returns
    -------
    ndarray
        Structured array with fields ``'vals'`` (reduced values) and
        ``'arg'`` (global indices).
    """
    arg_axis = None if len(axis) == x.ndim or x.ndim == 1 else axis[0]
    vals = func(x, axis=arg_axis, keepdims=True)
    arg = argfunc(x, axis=arg_axis, keepdims=True)
    if arg_axis is None:
        # Flat local index -> N-d local index -> N-d global index -> flat
        # global index into an array of ``total_shape``.
        offset, total_shape = offset_info
        ind = np.unravel_index(arg.ravel()[0], x.shape)
        total_ind = tuple(o + i for (o, i) in zip(offset, ind))
        arg[:] = np.ravel_multi_index(total_ind, total_shape)
    else:
        # Shift indices along the reduced axis by the block's start offset.
        arg += offset_info

    if isinstance(vals, np.ma.masked_array):
        # Fill masked results with np.ma's extreme fill value for the
        # reduction direction, so masked entries should not win later
        # comparisons in the combine/aggregate steps.
        if 'min' in argfunc.__name__:
            fill_value = np.ma.minimum_fill_value(vals)
        else:
            fill_value = np.ma.maximum_fill_value(vals)
        vals = np.ma.filled(vals, fill_value)

    result = np.empty(shape=vals.shape, dtype=[('vals', vals.dtype),
                                               ('arg', arg.dtype)])
    result['vals'] = vals
    result['arg'] = arg
    return result


def arg_combine(func, argfunc, data, axis=None, **kwargs):
    """Intermediate step of an arg-reduction.

    Reduces ``data`` (a ``'vals'``/``'arg'`` structured array) with
    ``_arg_combine``, keeping the reduced dims. The result is repacked into
    the same structured layout for the next step. ``func`` and ``kwargs``
    are unused.
    """
    arg, vals = _arg_combine(data, axis, argfunc, keepdims=True)
    record = np.dtype([('vals', vals.dtype), ('arg', arg.dtype)])
    packed = np.empty(vals.shape, dtype=record)
    packed['vals'] = vals
    packed['arg'] = arg
    return packed


def arg_agg(func, argfunc, data, axis=None, **kwargs):
    """Final step of an arg-reduction: return only the winning indices."""
    indices, _ = _arg_combine(data, axis, argfunc, keepdims=False)
    return indices


def nanarg_agg(func, argfunc, data, axis=None, **kwargs):
    """Final step of a NaN-aware arg-reduction.

    Returns the winning indices. Raises ``ValueError`` if any selected value
    is still NaN, i.e. the whole slice was NaN.
    """
    indices, winners = _arg_combine(data, axis, argfunc, keepdims=False)
    if not np.isnan(winners).any():
        return indices
    raise ValueError("All NaN slice encountered")


def arg_reduction(x, chunk, combine, agg, axis=None, split_every=None, out=None):
    """ Generic function for argreduction.

    Parameters
    ----------
    x : Array
    chunk : callable
        Partialed ``arg_chunk``.
    combine : callable
        Partialed ``arg_combine``.
    agg : callable
        Partialed ``arg_agg``.
    axis : int, optional
        If None, the reduction is over the flattened array.
    split_every : int or dict, optional
    out : Array, optional
        Passed to ``handle_out`` together with the result.

    Raises
    ------
    TypeError
        If ``axis`` is neither None nor an int.
    ValueError
        If a reduced axis has several chunks of unknown (NaN) size.
    """
    if axis is None:
        axis = tuple(range(x.ndim))
        ravel = True
    elif isinstance(axis, Integral):
        axis = validate_axis(axis, x.ndim)
        axis = (axis,)
        ravel = x.ndim == 1
    else:
        raise TypeError("axis must be either `None` or int, "
                        "got '{0}'".format(axis))

    # Block offsets are needed to turn local indices into global ones, so
    # chunk sizes along the reduced axes must be known.
    for ax in axis:
        chunks = x.chunks[ax]
        if len(chunks) > 1 and np.isnan(chunks).any():
            raise ValueError(
                "Arg-reductions do not work with arrays that have "
                "unknown chunksizes.  At some point in your computation "
                "this array lost chunking information"
            )

    # Map chunk across all blocks
    name = 'arg-reduce-{0}'.format(tokenize(axis, x, chunk,
                                            combine, split_every))
    old = x.name
    keys = list(product(*map(range, x.numblocks)))
    # Start offset of every block in every dimension (cumulative chunk sizes).
    offsets = list(product(*(accumulate(operator.add, bd[:-1], 0)
                             for bd in x.chunks)))
    # Flattened reductions need the full N-d offset plus the total shape;
    # single-axis reductions need only the offset along that axis.
    if ravel:
        offset_info = zip(offsets, repeat(x.shape))
    else:
        offset_info = pluck(axis[0], offsets)

    chunks = tuple((1, ) * len(c) if i in axis else c for (i, c)
                   in enumerate(x.chunks))
    dsk = dict(((name,) + k, (chunk, (old,) + k, axis, off)) for (k, off)
               in zip(keys, offset_info))
    # The dtype of `tmp` doesn't actually matter, just need to provide something
    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
    tmp = Array(graph, name, chunks, dtype=x.dtype)
    # Output dtype: whatever integer dtype NumPy uses for argmin results.
    dtype = np.argmin([1]).dtype
    result = _tree_reduce(tmp, agg, axis, False, dtype, split_every, combine)
    return handle_out(out, result)


def make_arg_reduction(func, argfunc, is_nan_func=False):
    """ Create an argreduction callable

    Parameters
    ----------
    func : callable
        The reduction (e.g. ``min``)
    argfunc : callable
        The argreduction (e.g. ``argmin``)
    is_nan_func : bool, optional
        Use ``nanarg_agg`` (which raises on all-NaN slices) instead of
        ``arg_agg`` for the final step.

    Returns
    -------
    callable
        ``f(x, axis=None, split_every=None, out=None)`` with ``argfunc``'s
        metadata, delegating to ``arg_reduction``.
    """
    final_impl = nanarg_agg if is_nan_func else arg_agg
    chunk_step = partial(arg_chunk, func, argfunc)
    combine_step = partial(arg_combine, func, argfunc)
    final_step = partial(final_impl, func, argfunc)

    @wraps(argfunc)
    def arg_reducer(x, axis=None, split_every=None, out=None):
        return arg_reduction(x, chunk_step, combine_step, final_step, axis,
                             split_every=split_every, out=out)

    return arg_reducer


def _nanargmin(x, axis, **kwargs):
    """``chunk.nanargmin`` with a fallback for when it raises ``ValueError``.

    NumPy's ``nanargmin`` raises on all-NaN slices. In that case NaNs are
    replaced by ``+inf`` and the call is retried.
    """
    try:
        found = chunk.nanargmin(x, axis, **kwargs)
    except ValueError:
        nan_as_inf = np.where(np.isnan(x), np.inf, x)
        found = chunk.nanargmin(nan_as_inf, axis, **kwargs)
    return found


def _nanargmax(x, axis, **kwargs):
    """``chunk.nanargmax`` with a fallback for when it raises ``ValueError``.

    NumPy's ``nanargmax`` raises on all-NaN slices. In that case NaNs are
    replaced by ``-inf`` and the call is retried.
    """
    try:
        found = chunk.nanargmax(x, axis, **kwargs)
    except ValueError:
        nan_as_neg_inf = np.where(np.isnan(x), -np.inf, x)
        found = chunk.nanargmax(nan_as_neg_inf, axis, **kwargs)
    return found


# Public arg-reductions, each built from a (value reduction, index reduction)
# pair. The NaN variants use the ``_nanarg*`` fallbacks and
# ``nanarg_agg``, which raises on all-NaN slices.
argmin = make_arg_reduction(chunk.min, chunk.argmin)
argmax = make_arg_reduction(chunk.max, chunk.argmax)
nanargmin = make_arg_reduction(chunk.nanmin, _nanargmin, True)
nanargmax = make_arg_reduction(chunk.nanmax, _nanargmax, True)


def cumreduction(func, binop, ident, x, axis=None, dtype=None, out=None):
    """ Generic function for cumulative reduction

    Parameters
    ----------
    func: callable
        Cumulative function like np.cumsum or np.cumprod
    binop: callable
        Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
    ident: Number
        Associated identity like ``np.cumsum->0`` or ``np.cumprod->1``
    x: dask Array
    axis: int
        If None, ``x`` is flattened first and axis 0 is used.
    dtype: dtype
        If None, inferred by applying ``func`` to an empty array of ``x.dtype``.
    out: Array, optional
        Passed to ``handle_out`` together with the result.

    Returns
    -------
    dask array

    See also
    --------
    cumsum
    cumprod
    """
    if axis is None:
        x = x.flatten()
        axis = 0
    if dtype is None:
        dtype = getattr(func(np.empty((0,), dtype=x.dtype)), 'dtype', object)
    assert isinstance(axis, Integral)
    axis = validate_axis(axis, x.ndim)

    # Step 1: independent cumulative reduction inside every block.
    m = x.map_blocks(func, axis=axis, dtype=dtype)

    name = '{0}-{1}'.format(func.__name__, tokenize(func, axis, binop,
                                                    ident, x, dtype))
    n = x.numblocks[axis]
    full = slice(None, None, None)
    # Selects the last element along ``axis`` (keeping that dimension).
    slc = (full,) * axis + (slice(-1, None),) + (full,) * (x.ndim - axis - 1)

    # First block along ``axis``: the carry ("extra") is the identity, and
    # the output is just the block-local cumulative result.
    indices = list(product(*[range(nb) if i != axis else [0]
                             for i, nb in enumerate(x.numblocks)]))
    dsk = dict()
    for ind in indices:
        shape = tuple(x.chunks[i][ii] if i != axis else 1
                      for i, ii in enumerate(ind))
        dsk[(name, 'extra') + ind] = (np.full, shape, ident, m.dtype)
        dsk[(name,) + ind] = (m.name,) + ind

    # Block i: carry_i = binop(carry_{i-1}, last element of block i-1's local
    # result), so carry_i accumulates the totals of all previous blocks.
    # The output is binop(carry_i, local result of block i).
    for i in range(1, n):
        last_indices = indices
        indices = list(product(*[range(nb) if ii != axis else [i]
                                 for ii, nb in enumerate(x.numblocks)]))
        for old, ind in zip(last_indices, indices):
            this_slice = (name, 'extra') + ind
            dsk[this_slice] = (binop, (name, 'extra') + old,
                                      (operator.getitem, (m.name,) + old, slc))
            dsk[(name,) + ind] = (binop, this_slice, (m.name,) + ind)

    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[m])
    result = Array(graph, name, x.chunks, m.dtype)
    return handle_out(out, result)


def _cumsum_merge(a, b):
    if isinstance(a, np.ma.masked_array) or isinstance(b, np.ma.masked_array):
        values = np.ma.getdata(a) + np.ma.getdata(b)
        return np.ma.masked_array(values, mask=np.ma.getmaskarray(b))
    return a + b


def _cumprod_merge(a, b):
    if isinstance(a, np.ma.masked_array) or isinstance(b, np.ma.masked_array):
        values = np.ma.getdata(a) * np.ma.getdata(b)
        return np.ma.masked_array(values, mask=np.ma.getmaskarray(b))
    return a * b


@wraps(np.cumsum)
def cumsum(x, axis=None, dtype=None, out=None):
    # Blockwise np.cumsum stitched together with addition (identity 0).
    return cumreduction(func=np.cumsum, binop=_cumsum_merge, ident=0, x=x,
                        axis=axis, dtype=dtype, out=out)


@wraps(np.cumprod)
def cumprod(x, axis=None, dtype=None, out=None):
    # Blockwise np.cumprod stitched together with multiplication (identity 1).
    return cumreduction(func=np.cumprod, binop=_cumprod_merge, ident=1, x=x,
                        axis=axis, dtype=dtype, out=out)


def topk(a, k, axis=-1, split_every=None):
    """ Extract the k largest elements from a on the given axis,
    and return them sorted from largest to smallest.
    If k is negative, extract the -k smallest elements instead,
    and return them sorted from smallest to largest.

    This performs best when ``k`` is much smaller than the chunk size. All
    results will be returned in a single chunk along the given axis.

    Parameters
    ----------
    a: Array
        Data being sorted
    k: int
    axis: int, optional
    split_every: int >=2, optional
        See :func:`reduction`. This parameter becomes very important when k
        is on the same order of magnitude of the chunk size or more, as it
        prevents getting the whole or a significant portion of the input array
        in memory all at once, with a negative impact on network transfer
        too when running on distributed.

    Returns
    -------
    Selection of x with size abs(k) along the given axis.

    Examples
    --------
    >>> import dask.array as da
    >>> x = np.array([5, 1, 3, 6])
    >>> d = da.from_array(x, chunks=2)
    >>> d.topk(2).compute()
    array([6, 5])
    >>> d.topk(-2).compute()
    array([1, 3])
    """
    axis = validate_axis(axis, a.ndim)

    # The per-chunk and intermediate steps share one function that keeps the
    # top/bottom ``k`` elements of its input without sorting them.
    select_k = partial(chunk.topk, k=k)
    # The final step selects once more and sorts the survivors.
    select_and_sort = partial(chunk.topk_aggregate, k=k)

    return reduction(a, chunk=select_k, aggregate=select_and_sort,
                     combine=select_k, axis=axis, keepdims=True,
                     dtype=a.dtype, split_every=split_every,
                     output_size=abs(k))


def argtopk(a, k, axis=-1, split_every=None):
    """ Extract the indices of the k largest elements from a on the given axis,
    and return them sorted from largest to smallest. If k is negative, extract
    the indices of the -k smallest elements instead, and return them sorted
    from smallest to largest.

    This performs best when ``k`` is much smaller than the chunk size. All
    results will be returned in a single chunk along the given axis.

    Parameters
    ----------
    a: Array
        Data being sorted
    k: int
    axis: int, optional
    split_every: int >=2, optional
        See :func:`topk`. The performance considerations for topk also apply
        here.

    Returns
    -------
    Selection of np.intp indices of x with size abs(k) along the given axis.

    Examples
    --------
    >>> import dask.array as da
    >>> x = np.array([5, 1, 3, 6])
    >>> d = da.from_array(x, chunks=2)
    >>> d.argtopk(2).compute()
    array([3, 0])
    >>> d.argtopk(-2).compute()
    array([1, 2])
    """
    axis = validate_axis(axis, a.ndim)

    # Pair every block of ``a`` with the original positions of its elements
    # along ``axis``. The index vector is chunked like ``a`` along that axis
    # and broadcast over the other dims.
    positions = arange(a.shape[axis], chunks=(a.chunks[axis], ), dtype=np.intp)
    broadcaster = tuple(slice(None) if dim == axis else np.newaxis
                        for dim in range(a.ndim))
    positions = positions[broadcaster]
    values_with_positions = a.map_blocks(chunk.argtopk_preprocess, positions,
                                         dtype=object)

    # Chunk/combine step: keep the top ``k`` (value, position) pairs, unsorted.
    select_k = partial(chunk.argtopk, k=k)
    # Final step: select, sort, and return the positions only.
    select_and_sort = partial(chunk.argtopk_aggregate, k=k)

    return reduction(values_with_positions, chunk=select_k,
                     aggregate=select_and_sort, combine=select_k, axis=axis,
                     keepdims=True, dtype=np.intp, split_every=split_every,
                     concatenate=False, output_size=abs(k))


@wraps(np.trace)
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):
    # Sum the selected diagonal over the last axis, where ``diagonal`` (like
    # np.diagonal) places it.
    diag = diagonal(a, offset=offset, axis1=axis1, axis2=axis2)
    return diag.sum(-1, dtype=dtype)
