from __future__ import print_function, division

from collections import defaultdict

from sympy.core import (Add, Mul, Pow, Integer, Number, NumberSymbol,)
from sympy.core.compatibility import MutableMapping
from sympy.core.numbers import ImaginaryUnit
from sympy.core.sympify import _sympify
from sympy.core.rules import Transform
from sympy.core.logic import fuzzy_or, fuzzy_and
from sympy.matrices.expressions import MatMul

from sympy.functions.elementary.complexes import Abs

from sympy.assumptions.ask import Q
from sympy.assumptions.assume import Predicate, AppliedPredicate
from sympy.logic.boolalg import (Equivalent, Implies, And, Or,
    BooleanFunction, Not)

# APIs here may be subject to change

# XXX: Better name?
class UnevaluatedOnFree(BooleanFunction):
    """
    Represents a Boolean function that remains unevaluated on free predicates

    This is intended to be a superclass of other classes, which define the
    behavior on singly applied predicates.

    A free predicate is a predicate that is not applied, or a combination
    thereof. For example, Q.zero or Or(Q.positive, Q.negative).

    A singly applied predicate is a free predicate applied everywhere to a
    single expression. For instance, Q.zero(x) and Or(Q.positive(x*y),
    Q.negative(x*y)) are singly applied, but Or(Q.positive(x), Q.negative(y))
    and Or(Q.positive, Q.negative(y)) are not.

    The boolean literals True and False are considered to be both free and
    singly applied.

    This class raises ValueError unless the input is a free predicate or a
    singly applied predicate.

    On a free predicate, this class remains unevaluated. On a singly applied
    predicate, the method apply() is called and returned, or the original
    expression returned if apply() returns None. When apply() is called,
    self.expr is set to the unique expression that the predicates are applied
    at. self.pred is set to the free form of the predicate.

    The typical usage is to create this class with free predicates and
    evaluate it using .rcall().

    """
    def __new__(cls, arg):
        # Mostly type checking here
        arg = _sympify(arg)
        predicates = arg.atoms(Predicate)
        applied_predicates = arg.atoms(AppliedPredicate)
        if predicates and applied_predicates:
            raise ValueError("arg must be either completely free or singly applied")
        if not applied_predicates:
            obj = BooleanFunction.__new__(cls, arg)
            obj.pred = arg
            obj.expr = None
            return obj
        predicate_args = {pred.args[0] for pred in applied_predicates}
        if len(predicate_args) > 1:
            raise ValueError("The AppliedPredicates in arg must be applied to a single expression.")
        obj = BooleanFunction.__new__(cls, arg)
        obj.expr = predicate_args.pop()
        obj.pred = arg.xreplace(Transform(lambda e: e.func, lambda e:
            isinstance(e, AppliedPredicate)))
        applied = obj.apply()
        if applied is None:
            return obj
        return applied

    def apply(self):
        return


class AllArgs(UnevaluatedOnFree):
    """
    Class representing vectorizing a predicate over all the .args of an
    expression

    See the docstring of UnevaluatedOnFree for more information on this
    class.

    The typical usage is to evaluate predicates with expressions using .rcall().

    Example
    =======

    >>> from sympy.assumptions.sathandlers import AllArgs
    >>> from sympy import symbols, Q
    >>> x, y = symbols('x y')
    >>> a = AllArgs(Q.positive | Q.negative)
    >>> a
    AllArgs(Q.negative | Q.positive)
    >>> a.rcall(x*y)
    (Q.negative(x) | Q.positive(x)) & (Q.negative(y) | Q.positive(y))
    """

    def apply(self):
        return And(*[self.pred.rcall(arg) for arg in self.expr.args])


class AnyArgs(UnevaluatedOnFree):
    """
    Class representing vectorizing a predicate over any of the .args of an
    expression.

    See the docstring of UnevaluatedOnFree for more information on this
    class.

    The typical usage is to evaluate predicates with expressions using .rcall().

    Example
    =======

    >>> from sympy.assumptions.sathandlers import AnyArgs
    >>> from sympy import symbols, Q
    >>> x, y = symbols('x y')
    >>> a = AnyArgs(Q.positive & Q.negative)
    >>> a
    AnyArgs(Q.negative & Q.positive)
    >>> a.rcall(x*y)
    (Q.negative(x) & Q.positive(x)) | (Q.negative(y) & Q.positive(y))
    """

    def apply(self):
        return Or(*[self.pred.rcall(arg) for arg in self.expr.args])


class ExactlyOneArg(UnevaluatedOnFree):
    """
    Class representing a predicate holding on exactly one of the .args of an
    expression.

    See the docstring of UnevaluatedOnFree for more information on this
    class.

    The typical usage is to evaluate predicate with expressions using
    .rcall().

    Example
    =======

    >>> from sympy.assumptions.sathandlers import ExactlyOneArg
    >>> from sympy import symbols, Q
    >>> x, y = symbols('x y')
    >>> a = ExactlyOneArg(Q.positive)
    >>> a
    ExactlyOneArg(Q.positive)
    >>> a.rcall(x*y)
    (Q.positive(x) & ~Q.positive(y)) | (Q.positive(y) & ~Q.positive(x))
    """
    def apply(self):
        expr = self.expr
        pred = self.pred
        pred_args = [pred.rcall(arg) for arg in expr.args]
        # Technically this is xor, but if one term in the disjunction is true,
        # it is not possible for the remainder to be true, so regular or is
        # fine in this case.
        return Or(*[And(pred_args[i], *map(Not, pred_args[:i] +
            pred_args[i+1:])) for i in range(len(pred_args))])
        # Note: this is the equivalent cnf form. The above is more efficient
        # as the first argument of an implication, since p >> q is the same as
        # q | ~p, so the the ~ will convert the Or to and, and one just needs
        # to distribute the q across it to get to cnf.

        # return And(*[Or(*map(Not, c)) for c in combinations(pred_args, 2)]) & Or(*pred_args)


def _old_assump_replacer(obj):
    # Things to be careful of:
    # - real means real or infinite in the old assumptions.
    # - nonzero does not imply real in the old assumptions.
    # - finite means finite and not zero in the old assumptions.
    if not isinstance(obj, AppliedPredicate):
        return obj

    e = obj.args[0]
    ret = None

    if obj.func == Q.positive:
        ret = fuzzy_and([e.is_finite, e.is_positive])
    if obj.func == Q.zero:
        ret = e.is_zero
    if obj.func == Q.negative:
        ret = fuzzy_and([e.is_finite, e.is_negative])
    if obj.func == Q.nonpositive:
        ret = fuzzy_and([e.is_finite, e.is_nonpositive])
    if obj.func == Q.nonzero:
        ret = fuzzy_and([e.is_nonzero, e.is_finite])
    if obj.func == Q.nonnegative:
        ret = fuzzy_and([fuzzy_or([e.is_zero, e.is_finite]),
        e.is_nonnegative])

    if obj.func == Q.rational:
        ret = e.is_rational
    if obj.func == Q.irrational:
        ret = e.is_irrational

    if obj.func == Q.even:
        ret = e.is_even
    if obj.func == Q.odd:
        ret = e.is_odd
    if obj.func == Q.integer:
        ret = e.is_integer
    if obj.func == Q.imaginary:
        ret = e.is_imaginary
    if obj.func == Q.commutative:
        ret = e.is_commutative

    if ret is None:
        return obj
    return ret


def evaluate_old_assump(pred):
    """
    Replace assumptions of expressions replaced with their values in the old
    assumptions (like Q.negative(-1) => True). Useful because some direct
    computations for numeric objects is defined most conveniently in the old
    assumptions.

    """
    return pred.xreplace(Transform(_old_assump_replacer))


class CheckOldAssump(UnevaluatedOnFree):
    def apply(self):
        return Equivalent(self.args[0], evaluate_old_assump(self.args[0]))


class CheckIsPrime(UnevaluatedOnFree):
    def apply(self):
        from sympy import isprime
        return Equivalent(self.args[0], isprime(self.expr))


class CustomLambda(object):
    """
    Interface to lambda with rcall

    Workaround until we get a better way to represent certain facts.
    """
    def __init__(self, lamda):
        self.lamda = lamda

    def rcall(self, *args):
        return self.lamda(*args)


class ClassFactRegistry(MutableMapping):
    """
    Register handlers against classes

    ``registry[C] = handler`` registers ``handler`` for class
    ``C``. ``registry[C]`` returns a set of handlers for class ``C``, or any
    of its superclasses.
    """
    def __init__(self, d=None):
        d = d or {}
        self.d = defaultdict(frozenset, d)
        super(ClassFactRegistry, self).__init__()

    def __setitem__(self, key, item):
        self.d[key] = frozenset(item)

    def __getitem__(self, key):
        ret = self.d[key]
        for k in self.d:
            if issubclass(key, k):
                ret |= self.d[k]
        return ret

    def __delitem__(self, key):
        del self.d[key]

    def __iter__(self):
        return self.d.__iter__()

    def __len__(self):
        return len(self.d)

    def __repr__(self):
        return repr(self.d)


fact_registry = ClassFactRegistry()


def register_fact(klass, fact, registry=fact_registry):
    registry[klass] |= {fact}


for klass, fact in [
    (Mul, Equivalent(Q.zero, AnyArgs(Q.zero))),
    (MatMul, Implies(AllArgs(Q.square), Equivalent(Q.invertible, AllArgs(Q.invertible)))),
    (Add, Implies(AllArgs(Q.positive), Q.positive)),
    (Add, Implies(AllArgs(Q.negative), Q.negative)),
    (Mul, Implies(AllArgs(Q.positive), Q.positive)),
    (Mul, Implies(AllArgs(Q.commutative), Q.commutative)),
    (Mul, Implies(AllArgs(Q.real), Q.commutative)),

    (Pow, CustomLambda(lambda power: Implies(Q.real(power.base) &
    Q.even(power.exp) & Q.nonnegative(power.exp), Q.nonnegative(power)))),
    (Pow, CustomLambda(lambda power: Implies(Q.nonnegative(power.base) & Q.odd(power.exp) & Q.nonnegative(power.exp), Q.nonnegative(power)))),
    (Pow, CustomLambda(lambda power: Implies(Q.nonpositive(power.base) & Q.odd(power.exp) & Q.nonnegative(power.exp), Q.nonpositive(power)))),

    # This one can still be made easier to read. I think we need basic pattern
    # matching, so that we can just write Equivalent(Q.zero(x**y), Q.zero(x) & Q.positive(y))
    (Pow, CustomLambda(lambda power: Equivalent(Q.zero(power), Q.zero(power.base) & Q.positive(power.exp)))),
    (Integer, CheckIsPrime(Q.prime)),
    # Implicitly assumes Mul has more than one arg
    # Would be AllArgs(Q.prime | Q.composite) except 1 is composite
    (Mul, Implies(AllArgs(Q.prime), ~Q.prime)),
    # More advanced prime assumptions will require inequalities, as 1 provides
    # a corner case.
    (Mul, Implies(AllArgs(Q.imaginary | Q.real), Implies(ExactlyOneArg(Q.imaginary), Q.imaginary))),
    (Mul, Implies(AllArgs(Q.real), Q.real)),
    (Add, Implies(AllArgs(Q.real), Q.real)),
    # General Case: Odd number of imaginary args implies mul is imaginary(To be implemented)
    (Mul, Implies(AllArgs(Q.real), Implies(ExactlyOneArg(Q.irrational),
        Q.irrational))),
    (Add, Implies(AllArgs(Q.real), Implies(ExactlyOneArg(Q.irrational),
        Q.irrational))),
    (Mul, Implies(AllArgs(Q.rational), Q.rational)),
    (Add, Implies(AllArgs(Q.rational), Q.rational)),

    (Abs, Q.nonnegative),
    (Abs, Equivalent(AllArgs(~Q.zero), ~Q.zero)),

    # Including the integer qualification means we don't need to add any facts
    # for odd, since the assumptions already know that every integer is
    # exactly one of even or odd.
    (Mul, Implies(AllArgs(Q.integer), Equivalent(AnyArgs(Q.even), Q.even))),

    (Abs, Implies(AllArgs(Q.even), Q.even)),
    (Abs, Implies(AllArgs(Q.odd), Q.odd)),

    (Add, Implies(AllArgs(Q.integer), Q.integer)),
    (Add, Implies(ExactlyOneArg(~Q.integer), ~Q.integer)),
    (Mul, Implies(AllArgs(Q.integer), Q.integer)),
    (Mul, Implies(ExactlyOneArg(~Q.rational), ~Q.integer)),
    (Abs, Implies(AllArgs(Q.integer), Q.integer)),

    (Number, CheckOldAssump(Q.negative)),
    (Number, CheckOldAssump(Q.zero)),
    (Number, CheckOldAssump(Q.positive)),
    (Number, CheckOldAssump(Q.nonnegative)),
    (Number, CheckOldAssump(Q.nonzero)),
    (Number, CheckOldAssump(Q.nonpositive)),
    (Number, CheckOldAssump(Q.rational)),
    (Number, CheckOldAssump(Q.irrational)),
    (Number, CheckOldAssump(Q.even)),
    (Number, CheckOldAssump(Q.odd)),
    (Number, CheckOldAssump(Q.integer)),
    (Number, CheckOldAssump(Q.imaginary)),
    # For some reason NumberSymbol does not subclass Number
    (NumberSymbol, CheckOldAssump(Q.negative)),
    (NumberSymbol, CheckOldAssump(Q.zero)),
    (NumberSymbol, CheckOldAssump(Q.positive)),
    (NumberSymbol, CheckOldAssump(Q.nonnegative)),
    (NumberSymbol, CheckOldAssump(Q.nonzero)),
    (NumberSymbol, CheckOldAssump(Q.nonpositive)),
    (NumberSymbol, CheckOldAssump(Q.rational)),
    (NumberSymbol, CheckOldAssump(Q.irrational)),
    (NumberSymbol, CheckOldAssump(Q.imaginary)),
    (ImaginaryUnit, CheckOldAssump(Q.negative)),
    (ImaginaryUnit, CheckOldAssump(Q.zero)),
    (ImaginaryUnit, CheckOldAssump(Q.positive)),
    (ImaginaryUnit, CheckOldAssump(Q.nonnegative)),
    (ImaginaryUnit, CheckOldAssump(Q.nonzero)),
    (ImaginaryUnit, CheckOldAssump(Q.nonpositive)),
    (ImaginaryUnit, CheckOldAssump(Q.rational)),
    (ImaginaryUnit, CheckOldAssump(Q.irrational)),
    (ImaginaryUnit, CheckOldAssump(Q.imaginary))
    ]:

    register_fact(klass, fact)
