import warnings
from itertools import product

import numpy as np

from ._c99_config import _have_c99_complex
from ._extensions._dwt import idwt_single
from ._extensions._swt import swt_max_level, swt as _swt, swt_axis as _swt_axis
from ._extensions._pywt import Modes, _check_dtype
from ._multidim import idwt2, idwtn
from ._utils import _as_wavelet, _wavelets_per_axis


# Names exported by ``from <this module> import *``: the forward and inverse
# stationary wavelet transforms defined below, plus ``swt_max_level`` which is
# re-exported from the compiled ``_swt`` extension imported above.
__all__ = ["swt", "swt_max_level", 'iswt', 'swt2', 'iswt2', 'swtn', 'iswtn']


def swt(data, wavelet, level=None, start_level=0, axis=-1):
    """
    Multilevel 1D stationary wavelet transform.

    Parameters
    ----------
    data : array_like
        Input signal
    wavelet :
        Wavelet to use (Wavelet object or name)
    level : int, optional
        The number of decomposition steps to perform. If not given, the
        value returned by ``swt_max_level`` for the length of the transformed
        axis is used.
    start_level : int, optional
        The level at which the decomposition will begin (it allows one to
        skip a given number of transform steps and compute
        coefficients starting from start_level) (default: 0)
    axis: int, optional
        Axis over which to compute the SWT. If not given, the
        last axis is used.

    Returns
    -------
    coeffs : list
        List of approximation and details coefficients pairs in order
        similar to wavedec function::

            [(cAn, cDn), ..., (cA2, cD2), (cA1, cD1)]

        where n equals input parameter ``level``.

        If ``start_level = m`` is given, then the beginning m steps are
        skipped::

            [(cAm+n, cDm+n), ..., (cAm+1, cDm+1), (cAm, cDm)]

    Raises
    ------
    ValueError
        If ``axis`` does not refer to an existing dimension of ``data``.

    Notes
    -----
    The implementation here follows the "algorithm a-trous" and requires that
    the signal length along the transformed axis be a multiple of ``2**level``.
    If this is not the case, the user should pad up to an appropriate size
    using a function such as ``numpy.pad``.
    """
    if not _have_c99_complex and np.iscomplexobj(data):
        # Without C99 complex support, transform the real and imaginary parts
        # separately and recombine.  ``axis`` must be forwarded explicitly,
        # otherwise the recursive calls would silently fall back to the
        # default (last) axis.
        data = np.asarray(data)
        coeffs_real = swt(data.real, wavelet, level, start_level, axis)
        coeffs_imag = swt(data.imag, wavelet, level, start_level, axis)
        return [(cA_r + 1j*cA_i, cD_r + 1j*cD_i)
                for (cA_r, cD_r), (cA_i, cD_i) in zip(coeffs_real,
                                                      coeffs_imag)]

    # accept array_like input; make a copy to ensure a contiguous array
    dt = _check_dtype(data)
    data = np.array(data, dtype=dt)

    wavelet = _as_wavelet(wavelet)

    # normalize a negative axis, then validate it against the data rank
    if axis < 0:
        axis = axis + data.ndim
    if not 0 <= axis < data.ndim:
        raise ValueError("Axis greater than data dimensions")

    if level is None:
        level = swt_max_level(data.shape[axis])

    # 1D input uses the dedicated routine; anything else goes through the
    # axis-aware routine
    if data.ndim == 1:
        ret = _swt(data, wavelet, level, start_level)
    else:
        ret = _swt_axis(data, wavelet, level, start_level, axis)
    return [(np.asarray(cA), np.asarray(cD)) for cA, cD in ret]


def iswt(coeffs, wavelet):
    """
    Multilevel 1D inverse discrete stationary wavelet transform.

    Parameters
    ----------
    coeffs : array_like
        Coefficients list of tuples::

            [(cAn, cDn), ..., (cA2, cD2), (cA1, cD1)]

        where cA is approximation, cD is details.  Index 1 corresponds to
        ``start_level`` from ``pywt.swt``.  Only the approximation
        coefficients of the first entry are used; the remaining entries
        contribute their detail coefficients.
    wavelet : Wavelet object or name string
        Wavelet to use

    Returns
    -------
    1D array of reconstructed data.

    Examples
    --------
    >>> import pywt
    >>> coeffs = pywt.swt([1,2,3,4,5,6,7,8], 'db2', level=2)
    >>> pywt.iswt(coeffs, 'db2')
    array([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.])
    """
    # start from a private copy of the coarsest approximation so the
    # caller's arrays are never written to
    output = np.array(coeffs[0][0], dtype=_check_dtype(coeffs[0][0]),
                      copy=True)

    if np.iscomplexobj(output) and not _have_c99_complex:
        # reconstruct the real and imaginary parts independently
        real_parts = []
        imag_parts = []
        for approx, detail in coeffs:
            real_parts.append((approx.real, detail.real))
            imag_parts.append((approx.imag, detail.imag))
        return iswt(real_parts, wavelet) + 1j*iswt(imag_parts, wavelet)

    # the number of entries equals the decomposition level
    num_levels = len(coeffs)
    wavelet = _as_wavelet(wavelet)
    mode = Modes.from_object('periodization')

    # walk from the coarsest entry (largest stride) to the finest (stride 1)
    for pos in range(num_levels):
        step_size = 2 ** (num_levels - pos - 1)
        _, cD = coeffs[pos]
        cD = np.asarray(cD, dtype=_check_dtype(cD))
        if cD.dtype != output.dtype:
            # bring both operands to a shared double precision dtype
            any_complex = 'c' in (output.dtype.kind, cD.dtype.kind)
            dtype = np.complex128 if any_complex else np.float64
            output = np.asarray(output, dtype=dtype)
            cD = np.asarray(cD, dtype=dtype)

        n_coeffs = len(cD)
        for first in range(step_size):
            # samples handled in this pass, split into alternating halves
            indices = np.arange(first, n_coeffs, step_size)
            even_indices = indices[::2]
            odd_indices = indices[1::2]

            # Integer-array indexing yields the contiguous copies that
            # idwt_single needs; periodization gives periodic boundaries.
            x1 = idwt_single(output[even_indices], cD[even_indices],
                             wavelet, mode)
            # the odd-phase reconstruction is circularly shifted right by one
            x2 = np.roll(idwt_single(output[odd_indices], cD[odd_indices],
                                     wavelet, mode), 1)

            # the mean of both phases replaces the samples of this pass
            output[indices] = (x1 + x2)/2.

    return output


def swt2(data, wavelet, level, start_level=0, axes=(-2, -1)):
    """
    Multilevel 2D stationary wavelet transform.

    Parameters
    ----------
    data : array_like
        2D array with input data
    wavelet : Wavelet object or name string, or 2-tuple of wavelets
        Wavelet to use.  This can also be a tuple of wavelets to apply per
        axis in ``axes``.
    level : int
        The number of decomposition steps to perform.
    start_level : int, optional
        The level at which the decomposition will start (default: 0)
    axes : 2-tuple of ints, optional
        Axes over which to compute the SWT. Repeated elements are not allowed.

    Returns
    -------
    coeffs : list
        Approximation and details coefficients (for ``start_level = m``)::

            [
                (cA_m+level,
                    (cH_m+level, cV_m+level, cD_m+level)
                ),
                ...,
                (cA_m+1,
                    (cH_m+1, cV_m+1, cD_m+1)
                ),
                (cA_m,
                    (cH_m, cV_m, cD_m)
                )
            ]

        where cA is approximation, cH is horizontal details, cV is
        vertical details, cD is diagonal details and m is ``start_level``.

    Raises
    ------
    ValueError
        If ``axes`` does not hold exactly two distinct entries, or if
        ``data`` has fewer dimensions than the number of axes.

    Notes
    -----
    The implementation here follows the "algorithm a-trous" and requires that
    the signal length along the transformed axes be a multiple of ``2**level``.
    If this is not the case, the user should pad up to an appropriate size
    using a function such as ``numpy.pad``.
    """
    axes = tuple(axes)
    data = np.asarray(data)

    # validate the axes before delegating to the n-dimensional transform
    n_axes = len(axes)
    if n_axes != 2:
        raise ValueError("Expected 2 axes")
    if len(set(axes)) != n_axes:
        raise ValueError("The axes passed to swt2 must be unique.")
    if len(np.unique(axes)) > data.ndim:
        raise ValueError(
            "Input array has fewer dimensions than the specified axes")

    # repackage each per-level dict as (approx, (horiz, vert, diag))
    per_level = swtn(data, wavelet, level, start_level, axes)
    return [(bands['aa'], (bands['da'], bands['ad'], bands['dd']))
            for bands in per_level]


def iswt2(coeffs, wavelet):
    """
    Multilevel 2D inverse discrete stationary wavelet transform.

    Parameters
    ----------
    coeffs : list
        Approximation and details coefficients::

            [
                (cA_n,
                    (cH_n, cV_n, cD_n)
                ),
                ...,
                (cA_2,
                    (cH_2, cV_2, cD_2)
                ),
                (cA_1,
                    (cH_1, cV_1, cD_1)
                )
            ]

        where cA is approximation, cH is horizontal details, cV is
        vertical details, cD is diagonal details and n is the number of
        levels.  Index 1 corresponds to ``start_level`` from ``pywt.swt2``.
        Only the approximation coefficients of the first entry are used; the
        approximation entries of the remaining levels are ignored.
    wavelet : Wavelet object or name string, or 2-tuple of wavelets
        Wavelet to use.  This can also be a 2-tuple of wavelets to apply per
        axis.

    Returns
    -------
    2D array of reconstructed data.

    Raises
    ------
    ValueError
        If the coarsest approximation array is not 2D.
    RuntimeError
        If the three detail arrays of a level do not share the same shape.

    Examples
    --------
    >>> import pywt
    >>> coeffs = pywt.swt2([[1,2,3,4],[5,6,7,8],
    ...                     [9,10,11,12],[13,14,15,16]],
    ...                    'db1', level=2)
    >>> pywt.iswt2(coeffs, 'db1')
    array([[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]])

    """

    # copy to avoid modification of input data
    dt = _check_dtype(coeffs[0][0])
    output = np.array(coeffs[0][0], dtype=dt, copy=True)

    if output.ndim != 2:
        raise ValueError(
            "iswt2 only supports 2D arrays.  see iswtn for a general "
            "n-dimensionsal ISWT")
    # num_levels, equivalent to the decomposition level, n
    num_levels = len(coeffs)
    wavelets = _wavelets_per_axis(wavelet, axes=(0, 1))

    for j in range(num_levels):
        # stride between samples belonging to the same sub-grid at this level
        step_size = 2 ** (num_levels - j - 1)
        _, (cH, cV, cD) = coeffs[j]
        # We are going to assume cH, cV, and cD are of equal size
        if (cH.shape != cV.shape) or (cH.shape != cD.shape):
            raise RuntimeError(
                "Mismatch in shape of intermediate coefficient arrays")

        # make sure output shares the common dtype
        # (conversion of dtype for individual coeffs is handled within idwt2 )
        # The promotion starts from the *current* output dtype rather than
        # the dtype of the initial approximation: a previous level may
        # already have promoted output (e.g. to complex), and it must never
        # be cast back down, which would discard data.
        common_dtype = np.result_type(
            output.dtype, *[_check_dtype(c) for c in (cH, cV, cD)])
        if output.dtype != common_dtype:
            output = output.astype(common_dtype)

        for first_h in range(step_size):
            for first_w in range(step_size):
                # Getting the indices that we will transform
                indices_h = slice(first_h, cH.shape[0], step_size)
                indices_w = slice(first_w, cH.shape[1], step_size)

                even_idx_h = slice(first_h, cH.shape[0], 2*step_size)
                even_idx_w = slice(first_w, cH.shape[1], 2*step_size)
                odd_idx_h = slice(first_h + step_size, cH.shape[0], 2*step_size)
                odd_idx_w = slice(first_w + step_size, cH.shape[1], 2*step_size)

                # perform the inverse dwt on the selected indices,
                # making sure to use periodic boundary conditions.
                # All four reconstructions are computed before output is
                # written, since the inputs are views into output.
                x1 = idwt2((output[even_idx_h, even_idx_w],
                           (cH[even_idx_h, even_idx_w],
                            cV[even_idx_h, even_idx_w],
                            cD[even_idx_h, even_idx_w])),
                           wavelets, 'periodization')
                x2 = idwt2((output[even_idx_h, odd_idx_w],
                           (cH[even_idx_h, odd_idx_w],
                            cV[even_idx_h, odd_idx_w],
                            cD[even_idx_h, odd_idx_w])),
                           wavelets, 'periodization')
                x3 = idwt2((output[odd_idx_h, even_idx_w],
                           (cH[odd_idx_h, even_idx_w],
                            cV[odd_idx_h, even_idx_w],
                            cD[odd_idx_h, even_idx_w])),
                           wavelets, 'periodization')
                x4 = idwt2((output[odd_idx_h, odd_idx_w],
                           (cH[odd_idx_h, odd_idx_w],
                            cV[odd_idx_h, odd_idx_w],
                            cD[odd_idx_h, odd_idx_w])),
                           wavelets, 'periodization')

                # circularly shift each reconstruction by one sample along
                # every axis on which it used the odd sub-grid
                x2 = np.roll(x2, 1, axis=1)
                x3 = np.roll(x3, 1, axis=0)
                x4 = np.roll(x4, 1, axis=0)
                x4 = np.roll(x4, 1, axis=1)
                # average the four reconstructions
                output[indices_h, indices_w] = (x1 + x2 + x3 + x4) / 4

    return output


def swtn(data, wavelet, level, start_level=0, axes=None):
    """
    n-dimensional stationary wavelet transform.

    Parameters
    ----------
    data : array_like
        n-dimensional array with input data.
    wavelet : Wavelet object or name string, or tuple of wavelets
        Wavelet to use.  This can also be a tuple of wavelets to apply per
        axis in ``axes``.
    level : int
        The number of decomposition steps to perform.
    start_level : int, optional
        The level at which the decomposition will start (default: 0)
    axes : sequence of ints, optional
        Axes over which to compute the SWT. A value of ``None`` (the
        default) selects all axes. Axes may not be repeated.

    Returns
    -------
    [{coeffs_level_n}, ..., {coeffs_level_1}]: list of dict
        Results for each level are arranged in a dictionary, where the key
        specifies the transform type on each dimension and value is a
        n-dimensional coefficients array.

        For example, for a 2D case the result at a given level will look
        something like this::

            {'aa': <coeffs>  # A(LL) - approx. on 1st dim, approx. on 2nd dim
             'ad': <coeffs>  # V(LH) - approx. on 1st dim, det. on 2nd dim
             'da': <coeffs>  # H(HL) - det. on 1st dim, approx. on 2nd dim
             'dd': <coeffs>  # D(HH) - det. on 1st dim, det. on 2nd dim
            }

        For user-specified ``axes``, the order of the characters in the
        dictionary keys map to the specified ``axes``.

    Raises
    ------
    TypeError
        If ``data`` converts to an object-dtype array.
    ValueError
        If ``data`` is 0-dimensional or ``axes`` contains repeated entries.

    Notes
    -----
    The implementation here follows the "algorithm a-trous" and requires that
    the signal length along the transformed axes be a multiple of ``2**level``.
    If this is not the case, the user should pad up to an appropriate size
    using a function such as ``numpy.pad``.
    """
    data = np.asarray(data)
    if np.iscomplexobj(data) and not _have_c99_complex:
        # transform real and imaginary parts separately, then merge the
        # per-level dictionaries key by key
        real_levels = swtn(data.real, wavelet, level, start_level, axes)
        imag_levels = swtn(data.imag, wavelet, level, start_level, axes)
        return [{key: re_lvl[key] + 1j * im_lvl[key] for key in re_lvl}
                for re_lvl, im_lvl in zip(real_levels, imag_levels)]

    if data.dtype == np.dtype('object'):
        raise TypeError("Input must be a numeric array-like")
    if data.ndim < 1:
        raise ValueError("Input data must be at least 1D")

    if axes is None:
        axes = range(data.ndim)
    # map negative axes onto their non-negative equivalents
    axes = [ax + data.ndim if ax < 0 else ax for ax in axes]
    if len(set(axes)) != len(axes):
        raise ValueError("The axes passed to swtn must be unique.")
    num_axes = len(axes)

    wavelets = _wavelets_per_axis(wavelet, axes)

    all_levels = []
    for lvl in range(start_level, start_level + level):
        # Each pass over an axis splits every existing subband in two,
        # appending 'a' (approximation) or 'd' (detail) to its key.
        subbands = [('', data)]
        for ax, wav in zip(axes, wavelets):
            split = []
            for name, arr in subbands:
                cA, cD = _swt_axis(arr, wav, level=1, start_level=lvl,
                                   axis=ax)[0]
                split.append((name + 'a', cA))
                split.append((name + 'd', cD))
            subbands = split

        level_coeffs = dict(subbands)
        all_levels.append(level_coeffs)

        # the all-approximation subband is the input to the next level
        data = level_coeffs['a' * num_axes]

    # the last computed level goes first in the result
    return all_levels[::-1]


def iswtn(coeffs, wavelet, axes=None):
    """
    Multilevel nD inverse discrete stationary wavelet transform.

    Parameters
    ----------
    coeffs : list
        [{coeffs_level_n}, ..., {coeffs_level_1}]: list of dict.
        Only the approximation coefficients of the first dictionary are
        used; approximation entries of the remaining levels are ignored.
        The dictionaries are not modified.
    wavelet : Wavelet object or name string, or tuple of wavelets
        Wavelet to use.  This can also be a tuple of wavelets to apply per
        axis in ``axes``.
    axes : sequence of ints, optional
        Axes over which to compute the inverse SWT. Axes may not be repeated.
        The default is ``None``, which means transform all axes
        (``axes = range(data.ndim)``).

    Returns
    -------
    nD array of reconstructed data.

    Raises
    ------
    ValueError
        If ``axes`` contains repeated entries, or if the number of axes does
        not match the key length of the coefficient dictionaries.
    RuntimeError
        If the detail arrays of a level do not all share the same shape.

    Examples
    --------
    >>> import pywt
    >>> coeffs = pywt.swtn([[1,2,3,4],[5,6,7,8],
    ...                     [9,10,11,12],[13,14,15,16]],
    ...                    'db1', level=2)
    >>> pywt.iswtn(coeffs, 'db1')
    array([[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]])

    """

    # key length matches the number of axes transformed
    ndim_transform = max(len(key) for key in coeffs[0].keys())
    approx_key = 'a'*ndim_transform

    # copy to avoid modification of input data
    dt = _check_dtype(coeffs[0][approx_key])
    output = np.array(coeffs[0][approx_key], dtype=dt, copy=True)
    ndim = output.ndim

    if axes is None:
        axes = range(output.ndim)
    axes = [a + ndim if a < 0 else a for a in axes]
    if len(axes) != len(set(axes)):
        raise ValueError("The axes passed to iswtn must be unique.")
    if ndim_transform != len(axes):
        raise ValueError("The number of axes used in iswtn must match the "
                         "number of dimensions transformed in swtn.")

    # num_levels, equivalent to the decomposition level, n
    num_levels = len(coeffs)
    wavelets = _wavelets_per_axis(wavelet, axes)

    # initialize various slice objects used in the loops below
    # these will remain slice(None) only on axes that aren't transformed
    indices = [slice(None), ]*ndim
    even_indices = [slice(None), ]*ndim
    odd_indices = [slice(None), ]*ndim
    odd_even_slices = [slice(None), ]*ndim

    for j in range(num_levels):
        # stride between samples belonging to the same sub-grid at this level
        step_size = 2 ** (num_levels - j - 1)

        # Select the detail coefficients without touching the caller's
        # dictionary, so the input stays intact even if an exception is
        # raised part-way through the reconstruction.
        details = {k: v for k, v in coeffs[j].items() if k != approx_key}

        # Promote output to a dtype that can hold this level's details.  The
        # promotion starts from the *current* output dtype so that a dtype
        # widened at a previous level (e.g. to complex) is never cast back
        # down, which would discard data.
        common_dtype = np.result_type(
            output.dtype, *[v.dtype for v in details.values()])
        if output.dtype != common_dtype:
            output = output.astype(common_dtype)

        # We assume all coefficient arrays are of equal size
        shapes = [v.shape for v in details.values()]
        if len(set(shapes)) != 1:
            raise RuntimeError(
                "Mismatch in shape of intermediate coefficient arrays")

        # shape of a single coefficient array, excluding non-transformed axes
        coeff_trans_shape = tuple([shapes[0][ax] for ax in axes])

        # nested loop over all combinations of axis offsets at this level
        for firsts in product(range(step_size), repeat=ndim_transform):
            for first, sh, ax in zip(firsts, coeff_trans_shape, axes):
                indices[ax] = slice(first, sh, step_size)
                even_indices[ax] = slice(first, sh, 2*step_size)
                odd_indices[ax] = slice(first+step_size, sh, 2*step_size)

            # output is accumulated in place below, so the approximation
            # values are read from a snapshot taken before it is zeroed
            approx = output.copy()
            output[tuple(indices)] = 0
            ntransforms = 0
            # nested loop over all combinations of odd/even indices
            for odds in product((0, 1), repeat=ndim_transform):
                for o, ax in zip(odds, axes):
                    if o:
                        odd_even_slices[ax] = odd_indices[ax]
                    else:
                        odd_even_slices[ax] = even_indices[ax]
                selection = tuple(odd_even_slices)

                # extract the odd/even indices for all detail coefficients
                details_slice = {key: value[selection]
                                 for key, value in details.items()}
                details_slice[approx_key] = approx[selection]

                # perform the inverse dwt on the selected indices,
                # making sure to use periodic boundary conditions
                x = idwtn(details_slice, wavelets, 'periodization', axes=axes)
                for o, ax in zip(odds, axes):
                    # circular shift along any odd indexed axis
                    if o:
                        x = np.roll(x, 1, axis=ax)
                output[tuple(indices)] += x
                ntransforms += 1
            output[tuple(indices)] /= ntransforms  # normalize
    return output
