from __future__ import print_function, division

from sympy.core.basic import Basic
from sympy.core.expr import Expr
from sympy.core.symbol import Symbol
from sympy.core.numbers import Integer, Rational, Float
from sympy.core.compatibility import default_sort_key
from sympy.core.add import Add
from sympy.core.mul import Mul

__all__ = ['dotprint']

# Default (class, attribute-dict) pairs used as the ``styles`` argument of
# styleof(), dotnode() and dotprint().  styleof() merges every matching entry
# in order, so for an Expr the 'color' from the Basic entry is replaced by
# the one from the Expr entry.
default_styles = ((Basic, {'color': 'blue', 'shape': 'ellipse'}),
          (Expr,  {'color': 'black'}))


# Classes whose args purestr() sorts with default_sort_key before printing.
sort_classes = (Add, Mul)
# Classes that purestr() prints from the attributes named in ``__slots__``
# rather than from ``.args``.
slotClasses = (Symbol, Integer, Rational, Float)
# XXX: Why not just use srepr()?
def purestr(x):
    """ Render ``x`` as a string of the form ``ClassName(arg, arg, ...)``

    Objects that are not ``Basic`` instances are returned as ``str(x)``.
    For ``Basic`` instances the arguments are rendered recursively:

    - classes listed in ``slotClasses`` use the attributes named in their
      ``__slots__``;
    - classes listed in ``sort_classes`` use their args sorted with
      ``default_sort_key``;
    - every other class uses ``x.args`` as is.
    """
    if not isinstance(x, Basic):
        return str(x)

    cls = type(x)
    if cls in slotClasses:
        parts = [getattr(x, name) for name in x.__slots__]
    elif cls in sort_classes:
        parts = sorted(x.args, key=default_sort_key)
    else:
        parts = x.args

    rendered = [purestr(part) for part in parts]
    return "%s(%s)" % (cls.__name__, ', '.join(rendered))


def styleof(expr, styles=default_styles):
    """ Collect the style attributes that apply to ``expr``

    ``styles`` is an iterable of ``(class, dict)`` pairs.  The dicts of all
    pairs whose class ``expr`` is an instance of are merged in the order
    given, so a later pair overrides the keys of an earlier one.  A new dict
    is returned each time.

    >>> from sympy import Symbol, Basic, Expr
    >>> from sympy.printing.dot import styleof
    >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
    ...           (Expr,  {'color': 'black'})]

    >>> styleof(Basic(1), styles)
    {'color': 'blue', 'shape': 'ellipse'}

    >>> x = Symbol('x')
    >>> styleof(x + 1, styles)  # this is an Expr
    {'color': 'black', 'shape': 'ellipse'}
    """
    merged = {}
    for cls, attrs in styles:
        if not isinstance(expr, cls):
            continue
        merged.update(attrs)
    return merged

def attrprint(d, delimiter=', '):
    """ Format a dictionary of attributes as ``"key"="value"`` pairs

    The items are emitted in sorted order and joined with ``delimiter``.

    >>> from sympy.printing.dot import attrprint
    >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'}))
    "color"="blue", "shape"="ellipse"
    """
    pairs = []
    for key, value in sorted(d.items()):
        pairs.append('"%s"="%s"' % (key, value))
    return delimiter.join(pairs)

def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True):
    """ Build the DOT statement that declares the node for ``expr``

    The node attributes come from ``styleof(expr, styles)`` plus a
    ``label``: the class name for a ``Basic`` instance that is not an Atom,
    and ``labelfunc(expr)`` for anything else.  The node identifier is
    ``purestr(expr)``; with ``repeat=True`` the position tuple ``pos`` is
    appended to it.

    >>> from sympy.printing.dot import dotnode
    >>> from sympy.abc import x
    >>> print(dotnode(x))
    "Symbol(x)_()" ["color"="black", "label"="x", "shape"="ellipse"];
    """
    attrs = styleof(expr, styles)

    is_compound = isinstance(expr, Basic) and not expr.is_Atom
    if is_compound:
        attrs['label'] = str(expr.__class__.__name__)
    else:
        attrs['label'] = labelfunc(expr)

    node_id = purestr(expr)
    if repeat:
        node_id = '%s_%s' % (node_id, str(pos))
    return '"%s" [%s];' % (node_id, attrprint(attrs))


def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True):
    """ Build the DOT edge statements from ``expr`` to each of its args

    An empty list is returned when ``atom(expr)`` is true.  Node identifiers
    are built the same way as in ``dotnode``: ``purestr`` of the object, with
    the position tuple appended when ``repeat`` is true.

    See the docstring of dotprint for explanations of the options.

    >>> from sympy.printing.dot import dotedges
    >>> from sympy.abc import x
    >>> for e in dotedges(x+2):
    ...     print(e)
    "Add(Integer(2), Symbol(x))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol(x))_()" -> "Symbol(x)_(1,)";
    """
    if atom(expr):
        return []

    # TODO: This is quadratic in complexity (purestr(expr) already
    # contains [purestr(arg) for arg in expr.args]).
    parent = purestr(expr)
    children = [purestr(arg) for arg in expr.args]
    if repeat:
        parent = '%s_%s' % (parent, str(pos))
        children = ['%s_%s' % (child, str(pos + (idx,)))
                    for idx, child in enumerate(children)]
    return ['"%s" -> "%s";' % (parent, child) for child in children]

# %-format skeleton of the DOT document returned by dotprint(), which fills
# in the ``graphstyle``, ``nodes`` and ``edges`` keys.  The '#' lines inside
# the literal are part of the emitted text, not Python comments.
template = \
"""digraph{

# Graph style
%(graphstyle)s

#########
# Nodes #
#########

%(nodes)s

#########
# Edges #
#########

%(edges)s
}"""

# Default graph-level attributes.  dotprint() works on a copy of this dict
# and overlays its extra keyword arguments on that copy.
# NOTE(review): the rankdir values I know Graphviz to document are TB, LR, BT
# and RL; 'TD' is not among them -- confirm it is interpreted as intended.
_graphstyle = {'rankdir': 'TD', 'ordering': 'out'}

def dotprint(expr, styles=default_styles, atom=lambda x: not isinstance(x,
    Basic), maxdepth=None, repeat=True, labelfunc=str, **kwargs):
    """
    DOT description of a SymPy expression tree

    Options are

    ``styles``: Styles for different classes.  The default is::

        ((Basic, {'color': 'blue', 'shape': 'ellipse'}),
         (Expr, {'color': 'black'}))

    ``atom``: Function used to determine if an arg is an atom.  The default is
          ``lambda x: not isinstance(x, Basic)``.  Another good choice is
          ``lambda x: not x.args``.  An object for which ``atom`` returns
          True is never expanded: no edges leave it and its args are not
          visited.  This also holds for ``expr`` itself, which then yields a
          graph with a single node.

    ``maxdepth``: The maximum depth.  The default is None, meaning no limit.
          A value of ``0`` is treated the same as ``None``.

    ``repeat``: Whether to use different nodes for separate common
          subexpressions.  The default is True.  For example, for ``x + x*y``
          with ``repeat=True``, it will have two nodes for ``x`` and with
          ``repeat=False``, it will have one (warning: even if it appears
          twice in the same object, like Pow(x, x), it will still appear
          only once.  Hence, with repeat=False, the number of arrows out of an
          object might not equal the number of args it has).

    ``labelfunc``: How to label leaf nodes.  The default is ``str``.  Another
          good option is ``srepr``. For example with ``str``, the leaf nodes
          of ``x + 1`` are labeled, ``x`` and ``1``.  With ``srepr``, they
          are labeled ``Symbol('x')`` and ``Integer(1)``.

    Additional keyword arguments are included as styles for the graph.

    Examples
    ========

    >>> from sympy.printing.dot import dotprint
    >>> from sympy.abc import x
    >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE
    digraph{
    <BLANKLINE>
    # Graph style
    "ordering"="out"
    "rankdir"="TD"
    <BLANKLINE>
    #########
    # Nodes #
    #########
    <BLANKLINE>
    "Add(Integer(2), Symbol(x))_()" ["color"="black", "label"="Add", "shape"="ellipse"];
    "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"];
    "Symbol(x)_(1,)" ["color"="black", "label"="x", "shape"="ellipse"];
    <BLANKLINE>
    #########
    # Edges #
    #########
    <BLANKLINE>
    "Add(Integer(2), Symbol(x))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol(x))_()" -> "Symbol(x)_(1,)";
    }

    """
    # repeat works by adding a signature tuple to the end of each node for its
    # position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the
    # Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0].
    graphstyle = _graphstyle.copy()
    graphstyle.update(kwargs)

    nodes = []
    edges = []

    def traverse(e, depth, pos=()):
        # Pre-order walk: declare the node, then its outgoing edges, then
        # descend into every arg that is not an atom.
        nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat))
        if maxdepth and depth >= maxdepth:
            return
        if atom(e):
            # Only the root can get here (args are filtered below).  An atom
            # is a leaf; in particular it need not have ``.args`` at all
            # (e.g. a plain int passed as ``expr``), so stop before touching
            # it.
            return
        edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat))
        for i, arg in enumerate(e.args):
            if not atom(arg):
                traverse(arg, depth + 1, pos + (i,))

    traverse(expr, 0)

    return template%{'graphstyle': attrprint(graphstyle, delimiter='\n'),
                     'nodes': '\n'.join(nodes),
                     'edges': '\n'.join(edges)}
