from .core import tensordot_lookup, concatenate_lookup, einsum_lookup


@tensordot_lookup.register_lazy('cupy')
@concatenate_lookup.register_lazy('cupy')
@einsum_lookup.register_lazy('cupy')
def register_cupy():
    """Register cupy implementations for ``cupy.ndarray``.

    Registers ``cupy.concatenate``, ``cupy.tensordot`` and a thin wrapper
    around ``cupy.einsum`` on the corresponding lookup objects.

    The function is hooked onto all three lookups via ``register_lazy`` so
    that the einsum handler is also installed when an einsum is the first
    operation dispatched for a cupy array (previously only the tensordot
    and concatenate lookups could trigger this registration).
    NOTE(review): assumes ``einsum_lookup`` exposes ``register_lazy`` like
    the other two lookups - confirm against ``.core``.
    """
    # Imported here rather than at module level so cupy is only needed
    # when this registration function actually runs.
    import cupy
    concatenate_lookup.register(cupy.ndarray, cupy.concatenate)
    tensordot_lookup.register(cupy.ndarray, cupy.tensordot)

    @einsum_lookup.register(cupy.ndarray)
    def _cupy_einsum(*args, **kwargs):
        """Call ``cupy.einsum`` after dropping unsupported keyword args."""
        # NB: cupy does not accept `order` or `casting` kwargs - ignore
        kwargs.pop('casting', None)
        kwargs.pop('order', None)
        return cupy.einsum(*args, **kwargs)


@tensordot_lookup.register_lazy('sparse')
@concatenate_lookup.register_lazy('sparse')
def register_sparse():
    """Register the ``sparse`` package's functions for ``sparse.COO``.

    ``sparse.concatenate`` is registered on ``concatenate_lookup`` and
    ``sparse.tensordot`` on ``tensordot_lookup``.
    """
    # Function-scope import: ``sparse`` is only needed once this runs.
    import sparse

    coo_type = sparse.COO
    concatenate_lookup.register(coo_type, sparse.concatenate)
    tensordot_lookup.register(coo_type, sparse.tensordot)


@concatenate_lookup.register_lazy('scipy')
def register_scipy_sparse():
    """Register a concatenate implementation for ``scipy.sparse.spmatrix``."""
    # Function-scope import: scipy is only needed once this runs.
    import scipy.sparse

    def _concatenate(L, axis=0):
        """Stack the sparse matrices in ``L`` along ``axis``.

        ``axis == 0`` uses ``scipy.sparse.vstack`` and ``axis == 1`` uses
        ``scipy.sparse.hstack``; any other value raises ``ValueError``.
        """
        if axis == 0:
            return scipy.sparse.vstack(L)
        if axis == 1:
            return scipy.sparse.hstack(L)
        # Neither branch matched: reject the unsupported axis.
        raise ValueError(
            "Can only concatenate scipy sparse matrices for axis in "
            "{0, 1}.  Got %s" % axis
        )

    concatenate_lookup.register(scipy.sparse.spmatrix, _concatenate)
