from __future__ import absolute_import, division, print_function

import numpy as np
import pandas as pd
from toolz import partial

from ..utils import derived_from
from .utils import is_categorical_dtype, PANDAS_VERSION


def maybe_wrap_pandas(obj, x):
    """ Re-wrap a raw NumPy result in a pandas container matching ``obj``.

    Parameters
    ----------
    obj : object
        The object the result was computed from.  When it is a
        ``pd.Series`` its index is reused for the wrapped result.
    x : object
        The computed result.

    Returns
    -------
    ``x`` unchanged when it is not a ``np.ndarray``; otherwise a
    ``pd.Series`` (sharing ``obj.index`` and keeping ``x.dtype``) when
    ``obj`` is a ``pd.Series``, or a ``pd.Index`` built from ``x``.
    """
    # Anything that is not a bare ndarray is passed straight through.
    if not isinstance(x, np.ndarray):
        return x
    if not isinstance(obj, pd.Series):
        return pd.Index(x)
    return pd.Series(x, index=obj.index, dtype=x.dtype)


class Accessor(object):
    """
    Shared machinery for the pandas-style accessor objects cat, dt, and str.

    Attributes that are not found through normal lookup are resolved against
    the pandas accessor stored in ``_accessor`` and forwarded to the wrapped
    series through ``self._series.map_partitions``.

    Notes
    -----
    Subclasses should define the following attributes:

    * _accessor
    * _accessor_name
    """
    # Names present on ``_accessor`` that are excluded from delegation.
    _not_implemented = set()

    def __init__(self, series):
        """ Wrap ``series`` after checking its type and running ``_validate``.

        Raises ``ValueError`` when ``series`` is not a ``Series`` from
        ``.core``.
        """
        # Imported here rather than at module level.
        from .core import Series
        if isinstance(series, Series):
            self._validate(series)
            self._series = series
        else:
            raise ValueError('Accessor cannot be initialized')

    def _validate(self, series):
        """ Hook called from ``__init__``; the base implementation accepts
        every series. """
        pass

    @staticmethod
    def _delegate_property(obj, accessor, attr):
        """ Read ``attr`` from ``obj.<accessor>`` and re-wrap the result. """
        # Falls back to ``obj`` itself when it has no attribute ``accessor``.
        target = getattr(obj, accessor, obj)
        return maybe_wrap_pandas(obj, getattr(target, attr))

    @staticmethod
    def _delegate_method(obj, accessor, attr, args, kwargs):
        """ Call ``obj.<accessor>.<attr>(*args, **kwargs)`` and re-wrap the
        result. """
        # Falls back to ``obj`` itself when it has no attribute ``accessor``.
        target = getattr(obj, accessor, obj)
        method = getattr(target, attr)
        return maybe_wrap_pandas(obj, method(*args, **kwargs))

    def _property_map(self, attr):
        """ Map the accessor property ``attr`` over the wrapped series. """
        name = self._accessor_name
        # The output metadata comes from applying the property to ``_meta``.
        meta = self._delegate_property(self._series._meta, name, attr)
        return self._series.map_partitions(
            self._delegate_property, name, attr,
            token='%s-%s' % (name, attr), meta=meta)

    def _function_map(self, attr, *args, **kwargs):
        """ Map the accessor method ``attr`` over the wrapped series.

        A ``meta`` keyword, when given, is removed from ``kwargs`` and used
        as the output metadata; otherwise the metadata is obtained by calling
        the method on ``_meta_nonempty``.
        """
        name = self._accessor_name
        if 'meta' not in kwargs:
            meta = self._delegate_method(self._series._meta_nonempty,
                                         name, attr, args, kwargs)
        else:
            meta = kwargs.pop('meta')
        return self._series.map_partitions(
            self._delegate_method, name, attr, args, kwargs,
            meta=meta, token='%s-%s' % (name, attr))

    @property
    def _delegates(self):
        """ Set of ``_accessor`` attribute names that may be delegated. """
        excluded = self._not_implemented
        return {name for name in dir(self._accessor) if name not in excluded}

    def __dir__(self):
        """ List delegated names together with instance and class
        attributes. """
        names = self._delegates.union(self.__dict__, dir(type(self)))
        return list(names)

    def __getattr__(self, key):
        """ Resolve ``key`` against the delegated pandas accessor.

        Properties are evaluated immediately through ``_property_map``;
        anything else yields a callable bound to ``_function_map``.  Names
        that are not delegated raise ``AttributeError``.
        """
        if key not in self._delegates:
            raise AttributeError(key)
        if isinstance(getattr(self._accessor, key), property):
            return self._property_map(key)
        return partial(self._function_map, key)


class DatetimeAccessor(Accessor):
    """ Accessor object for the datetimelike properties of the Series values.

    Attribute access is delegated to ``pd.Series.dt`` under the accessor
    name ``'dt'``.

    Examples
    --------

    >>> s.dt.microsecond  # doctest: +SKIP
    """
    _accessor_name = 'dt'
    _accessor = pd.Series.dt


class StringAccessor(Accessor):
    """ Accessor object for string properties of the Series values.

    Delegates to ``pd.Series.str`` under the accessor name ``'str'``;
    ``get_dummies`` is excluded from delegation.

    Examples
    --------

    >>> s.str.lower()  # doctest: +SKIP
    """
    _accessor = pd.Series.str
    _accessor_name = 'str'
    _not_implemented = {'get_dummies'}

    def _validate(self, series):
        """ Raise ``AttributeError`` unless ``series`` has object dtype or is
        categorical with object-dtype categories. """
        if not (series.dtype == 'object' or (
                is_categorical_dtype(series) and
                series.cat.categories.dtype == 'object')):
            raise AttributeError("Can only use .str accessor with object dtype")

    @derived_from(pd.core.strings.StringMethods)
    def split(self, pat=None, n=-1, expand=False):
        # Explicit metadata is only supplied for ``expand=True``.  Otherwise
        # ``meta`` is left out so that ``_function_map`` infers it, exactly as
        # it does for every other delegated method.  (Previously ``meta`` was
        # referenced unconditionally and the default ``expand=False`` call
        # failed with UnboundLocalError.)
        kwargs = {'pat': pat, 'n': n, 'expand': expand}
        if expand:
            if n == -1:
                raise NotImplementedError(
                    "To use the expand parameter you must specify the number of "
                    "expected output columns with the n= parameter"
                )
            # Split a sample string of 2 * n space-separated tokens to obtain
            # the shape of the expanded output.
            # NOTE(review): the sample is joined with spaces, so the inferred
            # columns are presumably only right when ``pat`` splits on spaces
            # (or is None) -- confirm for other separators.
            sample = type(self._series._meta)([' '.join(['a'] * 2 * n)])
            kwargs['meta'] = sample.str.split(n=n, expand=expand, pat=pat)
        return self._function_map('split', **kwargs)

    @derived_from(pd.core.strings.StringMethods)
    def cat(self, others=None, sep=None, na_rep=None):
        from .core import Series, Index
        # Concatenating a series with itself down to a single string is not
        # supported here; ``others`` is required.
        if others is None:
            raise NotImplementedError("x.str.cat() with `others == None`")

        valid_types = (Series, Index, pd.Series, pd.Index)
        if isinstance(others, valid_types):
            # A single Series/Index is normalised to a one-element list.
            others = [others]
        elif not all(isinstance(a, valid_types) for a in others):
            raise TypeError("others must be Series/Index")

        # The output metadata is the metadata of the wrapped series.
        return self._series.map_partitions(str_cat, *others, sep=sep,
                                           na_rep=na_rep, meta=self._series._meta)

    @derived_from(pd.core.strings.StringMethods)
    def extractall(self, pat, flags=0):
        # TODO: metadata inference here won't be necessary for pandas >= 0.23.0
        meta = self._series._meta.str.extractall(pat, flags=flags)
        if PANDAS_VERSION < '0.23.0':
            # Replace the inferred index with an empty two-level MultiIndex
            # named after the series index and 'match'.
            index_name = self._series.index.name
            meta.index = pd.MultiIndex(levels=[[], []],
                                       labels=[[], []],
                                       names=[index_name, 'match'])
        return self._series.map_partitions(str_extractall, pat, flags,
                                           meta=meta, token='str-extractall')

    def __getitem__(self, index):
        """ Implements ``series.str[index]`` by applying ``str_get`` through
        ``map_partitions``; the output metadata is that of the wrapped
        series. """
        return self._series.map_partitions(str_get, index,
                                           meta=self._series._meta)


def str_extractall(series, pat, flags):
    """ Implements series.str.extractall(pat, flags=flags) """
    string_methods = series.str
    return string_methods.extractall(pat, flags=flags)


def str_get(series, index):
    """ Implements series.str[index]

    ``index`` is passed unchanged to the ``.str`` accessor's item lookup.
    """
    string_methods = series.str
    return string_methods[index]


def str_cat(self, *others, **kwargs):
    """ Implements self.str.cat(others=others, **kwargs)

    The positional ``others`` are handed to ``cat`` as one tuple; keyword
    arguments are forwarded unchanged.
    """
    string_methods = self.str
    return string_methods.cat(others=others, **kwargs)
