########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Author: Lucas J. Scharenbroich
# 
# Original Implementation:  July 10 by Lucas Scharenbroich
#
##########################################################################

"""
A graph data structure.

A graph is a collection of vertices connected by edges.  This class is a
directed graph which can take a weight for the connections between nodes.
"""

import Numeric

from Collection import *
from compClust.util import NaN

class GraphIterator:
    """
    An iterator for a Graph class.

    This iterator can traverse in either Depth-first ('DFO') or
    Breadth-first ('BFO') order.  Each call to next() returns the node
    object produced by graph.find() for the next key, or None once the
    traversal is exhausted (or if the order type is not recognized).

    A key is marked as visited when it is first pushed onto the pending
    list, so every reachable key is returned exactly once.  Because of this
    push-time marking, 'DFO' is a stack-based traversal rather than a strict
    recursive depth-first preorder.
    """

    def __init__(self, graph, key, type='DFO'):
        # graph must provide neighbors(key) and find(key).
        self.type    = type
        self.graph   = graph
        self.marked  = {}    # keys that have already been pushed
        self.stack   = []    # pending keys (stack for DFO, queue for BFO)

        if key is not None:
            self.stack.append(key)
            self.marked[key] = 1

    def next(self):
        """Return the next node of the traversal, or None when finished."""
        if self.type == 'DFO':
            return self.__next_dfs()
        elif self.type == 'BFO':
            return self.__next_bfs()
        return None

    def __next_dfs(self):
        # treat the pending list like a stack (take from the end)
        return self.__advance(-1)

    def __next_bfs(self):
        # treat the pending list like a queue (take from the front)
        return self.__advance(0)

    def __advance(self, index):
        # Pop the pending key at `index`, push its unvisited neighbors and
        # return its node object; None when nothing is pending.
        if not self.stack:
            return None

        current = self.stack.pop(index)

        for key in self.graph.neighbors(current):
            if key not in self.marked:
                self.stack.append(key)
                self.marked[key] = 1

        return self.graph.find(current)
    

class Graph(Collection):
    """
    A Graph ADT

    The connectivity of the graph is represented by a hash of key/list
    pairs. i.e.

    self.edges = {'A': ['B', 'C'],
                  'B': ['C', 'D'],
                  'C': ['D'],
                  'D': ['C'],
                  'E': ['F'],
                  'F': ['C']}

    If the graph is weighted, another hash is used which parallels the first

    self.weights = {'A': [0.5, 0.1],
                    'B': [1.0, 0.2],
                    'C': [2.0],
                    'D': [1.1],
                    'E': [2.3],
                    'F': [0.3]}
    """

    def __init__(self):
        Collection.__init__(self)
        self.edges = {}      # key -> list of successor keys
        self.weights = {}    # key -> list of weights, parallel to edges[key]

    def clear(self):
        """Remove every vertex, edge and weight from the graph."""
        self.edges.clear()
        self.weights.clear()
        Collection.clear(self)

    def _getRandKey(self):
        """
        Return an arbitrary vertex key (the first one produced by iterating
        self.edges), or None if the graph has no vertices.  Despite the name,
        the choice is not random.
        """
        for key in self.edges:
            return key
        return None

    ########################################################################
    #
    # basic ops.
    #
    # Basic operations pertaining to edges and vertices
    #
    ########################################################################

    def addEdge(self, key1, key2, weight = 1):
        """
        Add the directed edge key1 -> key2 with the given weight.

        If the edge already exists nothing happens (the old weight is kept).
        Raises KeyError if key1 is not a vertex; key2 is not validated.
        """
        if key2 not in self.edges[key1]:
            self.edges[key1].append(key2)
            self.weights[key1].append(weight)

    def addNode(self, node):
        """
        Add a node object to the collection and give it empty adjacency and
        weight lists (existing lists for the same key are kept).
        """
        Collection.addNode(self, node)

        key = node.key()
        if key not in self.edges:
            self.edges[key] = []

        if key not in self.weights:
            self.weights[key] = []

    def removeEdge(self, key1, key2):
        """
        Remove the edge key1 -> key2 and its weight.

        Raises KeyError if key1 is not a vertex and ValueError if the edge
        does not exist.
        """
        i = self.edges[key1].index(key2)
        self.edges[key1].pop(i)
        self.weights[key1].pop(i)

    def removeNode(self, key):
        """
        Remove a node with the given key from the data structure.
        """

        #
        # Remove the node itself from the collection
        #

        Collection.removeNode(self, key)

        #
        # Remove any connections to the node
        #

        for k in list(self.edges.keys()):
            try:
                self.removeEdge(k, key)
            except ValueError:
                # k simply has no edge to the removed node
                pass

        #
        # Remove the node from the list of edges and weights
        #

        del self.edges[key]
        del self.weights[key]

    def convertToAdjMatrix(self):
        """
        Return the graph as a dense adjacency matrix (a list of row lists).

        Vertices are ordered by sorted key.  Entry [i][j] holds the weight of
        the edge from the i-th to the j-th vertex, or 0 if there is none.
        """
        keys = list(self.edges.keys())
        keys.sort()
        n = len(keys)

        mapping = {}
        for i in range(n):
            mapping[keys[i]] = i

        matrix = [[0] * n for row in range(n)]
        for i in range(n):
            key = keys[i]
            for nbr, w in zip(self.edges[key], self.weights[key]):
                matrix[i][mapping[nbr]] = w

        return matrix

    ##########################################################################
    #
    # Basic query operations
    #
    ##########################################################################

    def BFS(self, key):
        """
        BFS: Breadth-First Search

        Returns a pair (count, order): the number of nodes visited and the
        list of their keys in the order in which they were visited.
        """
        it    = self.iterator(key, 'BFO')
        order = []

        node = it.next()
        while node is not None:
            order.append(node.key())
            node = it.next()

        return len(order), order

    def degree(self, key):
        """
        Return the number of edges pointing *into* vertex `key` (its
        in-degree), found by scanning every adjacency list.  This is
        O(|V| + |E|) per call.
        """
        total = 0
        for nbrs in self.edges.values():
            if key in nbrs:
                total += 1
        return total

    def _doDFS(self, key, preorder, postorder, mark):
        # Recursive DFS helper: appends to preorder/postorder in place and
        # records visited keys in `mark`.
        preorder.append(key)
        mark[key] = 1

        for nbr in self.neighbors(key):
            if nbr not in mark:
                self._doDFS(nbr, preorder, postorder, mark)

        postorder.append(key)

    def DFS(self, key):
        """
        DFS: Depth-First Search

        Returns the number of nodes visited and the pre- and post-order in
        which they were visited.  This is implemented recursively.
        """
        preorder  = []
        postorder = []

        self._doDFS(key, preorder, postorder, {})

        return len(preorder), preorder, postorder

    def edgeExists(self, key1, key2):
        """Return True if the edge key1 -> key2 exists."""
        return key2 in self.edges[key1]

    def weight(self, key1, key2):
        """Return the weight of edge key1 -> key2 (ValueError if absent)."""
        i = self.edges[key1].index(key2)
        return self.weights[key1][i]

    def iterator(self, key=None, type='DFO'):
        """
        Return a GraphIterator starting at `key` (or an arbitrary vertex when
        key is None) in 'DFO' or 'BFO' order.
        """
        if key is None:
            key = self._getRandKey()

        return GraphIterator(self, key, type)

    def maxDistance(self, key):
        """
        Return the eccentricity of vertex `key`: the largest number of edges
        on a shortest path from `key` to any other vertex (computed by BFS).

        If some vertex cannot be reached from `key` the distance is infinite;
        this is reported as self.order(), which exceeds any real distance.
        """
        n = self.order()

        depth = {key: 0}    # also serves as the visited set
        queue = [key]
        head  = 0
        far   = 0

        while head < len(queue):
            current = queue[head]
            head += 1
            d = depth[current]
            if d > far:
                far = d
            for nbr in self.neighbors(current):
                if nbr not in depth:
                    depth[nbr] = d + 1
                    queue.append(nbr)

        if len(queue) == n:
            return far
        return n

    def neighbors(self, key):
        """Return the list of successor keys of `key`."""
        return self.edges[key]

    def preorder(self, key):
        """Return the DFS preorder of keys reachable from `key`."""
        preorder = []
        self._doDFS(key, preorder, [], {})
        return preorder

    def postorder(self, key):
        """Return the DFS postorder of keys reachable from `key`."""
        postorder = []
        self._doDFS(key, [], postorder, {})
        return postorder

    #########################################################################
    #
    # graph properties
    #
    # Find characteristic measures of a graph as a whole.
    #
    #########################################################################

    def diameter(self):
        """
        The diameter of a graph is the largest distance between any two
        vertices.  If some vertex cannot reach another, the result is
        self.order() (see maxDistance).  An empty graph has diameter 0.
        """
        best = 0

        for key in self.edges.keys():
            d = self.maxDistance(key)
            if d > best:
                best = d

        return best

    def girth(self):
        """
        The girth of a graph is the length of the smallest cycle.

        Edge directions are ignored (the graph is converted to a UGraph).  A
        level-by-level BFS is run from each vertex and stopped as soon as it
        can no longer find a cycle shorter than the best one so far.  If the
        graph has no cycle, order() + 1 is returned.
        """
        from UGraph import UGraph

        G = UGraph(self)

        n    = G.order()
        best = n + 1
        keys = list(G.edges.keys())

        # In a simple undirected graph every cycle has at least three
        # vertices, so searching from all but two vertices meets every cycle.
        for i in range(n - 2):
            root     = keys[i]
            span     = {root: 1}
            distList = [root]     # vertices at distance depth-1 from root
            depth    = 1

            while distList and depth * 2 <= best and best > 3:

                nextList = []

                for e in distList:
                    for nbr in G.neighbors(e):
                        if nbr not in span:
                            span[nbr] = 1
                            nextList.append(nbr)
                        elif nbr in distList:
                            # edge inside one level: odd closed walk
                            best = min(best, depth * 2 - 1)
                            break
                        elif nbr in nextList:
                            # two vertices of this level share a neighbor:
                            # even closed walk; never overwrite a better odd one
                            best = min(best, depth * 2)

                distList = nextList
                depth += 1

        return best

    def order(self):
        """
        Returns the order of the graph which is defined as |V|, or the number
        of vertices (nodes) in the graph.
        """
        return len(self.nodes)

    def size(self):
        """
        Returns the size of the graph which is defined as |E|, or the number
        of edges (arcs) in the graph.
        """
        total = 0
        for nbrs in self.edges.values():
            total += len(nbrs)
        return total

    #########################################################################
    #
    # graph query operations
    #
    # Determine if a graph fulfills some specific properties.
    #
    #########################################################################

    def isAcyclic(self):
        """
        Return 1 if the directed graph contains no cycle, 0 otherwise.

        A DFS is started from every vertex not covered by an earlier search.
        An edge j -> u closes a cycle exactly when u finishes no earlier than
        j in that search's postorder: a back edge, or a self-loop (u == j).
        """
        span = {}

        for key in self.edges.keys():
            if key in span:
                continue

            marked    = {}
            postorder = []
            self._doDFS(key, [], postorder, marked)

            span.update(marked)

            position = {}
            for i in range(len(postorder)):
                position[postorder[i]] = i

            for j in postorder:
                for u in self.neighbors(j):
                    if position[u] >= position[j]:
                        return 0

        return 1

    def isBipartite(self):
        """
        A bipartite graph can be partitioned into two sets in which the
        vertices in each set have no edges between themselves.

        Edge directions are ignored.  Every connected component is 2-colored
        by BFS.  Returns 1 if that succeeds and 0 otherwise.
        """
        from UGraph import UGraph

        G = UGraph(self)

        color = {}

        for start in list(G.edges.keys()):
            if start in color:
                continue

            # Start a new component with color 1.
            color[start] = 1
            queue = [start]
            head  = 0

            while head < len(queue):
                current = queue[head]
                head += 1
                ncolor = color[current]

                for nbr in G.neighbors(current):
                    if nbr not in color:
                        color[nbr] = 3 - ncolor    # alternate between 1 and 2
                        queue.append(nbr)
                    elif color[nbr] == ncolor:
                        return 0

        return 1

    def isComplete(self):
        """
        A graph is complete if every vertex has an edge to every other
        vertex.  Self-loops are ignored.  Returns 1 or 0.
        """
        n = self.order()

        for key in self.edges.keys():
            others = [k for k in self.edges[key] if k != key]
            if len(others) != n - 1:
                return 0

        return 1

    def isConnected(self):
        """
        A graph is connected if there is a path to every node from a given
        node.  Edge directions are ignored here, so for a directed graph this
        tests weak connectivity; use isStronglyConnected() for the directed
        notion.
        """
        from UGraph import UGraph

        # NOTE(review): relies on UGraph overriding isConnected(); otherwise
        # this would recurse indefinitely -- confirm in UGraph.
        G = UGraph(self)
        return G.isConnected()

    def isRegular(self):
        """
        A graph is regular if every vertex has the same degree (as returned
        by degree()).  Returns 1 or 0; an empty graph is regular.
        """
        keys = list(self.edges.keys())

        if keys:
            deg = self.degree(keys[0])
            for k in keys[1:]:
                if self.degree(k) != deg:
                    return 0

        return 1

    def isStronglyConnected(self):
        """
        A graph is strongly connected if there is a path to every vertex from
        any other vertex.  For a directed graph this can be efficiently
        checked via two applications of a BFS: one on the graph and one on
        the graph with every edge reversed, both from the same vertex.
        """
        n     = self.order()
        start = self._getRandKey()

        count, preorder = self.BFS(start)

        if count != n:
            return 0

        # Build another graph with the edges reversed
        # (SimpleNode is presumably provided by the Collection module).

        graph = Graph()
        for key in self.edges.keys():
            graph.addNode(SimpleNode(key))

        for key in self.edges.keys():
            for k in self.neighbors(key):
                graph.addEdge(k, key)

        count, preorder = graph.BFS(start)

        return count == n

    #########################################################################
    #
    # Shortest Path and Flow algorithms
    #
    # Some of these algorithms require a graph with non-negative weights
    #
    #########################################################################

    def Dijkstra(self, key, maxcost=NaN.inf):
        """
        Single-source shortest paths from `key` (Dijkstra's algorithm).

        Edge weights are assumed to be non-negative.  Returns a pair
        (distance, previous): `distance` maps every reachable key to its
        minimum total weight from `key`; `previous` maps every reachable key
        other than `key` to its predecessor on one shortest path.

        Stale queue entries are skipped once their vertex is settled.

        NOTE(review): `maxcost` is accepted but currently unused.
        """
        from compClust.mlx.PQueue import PriorityQueue

        pq       = PriorityQueue()
        settled  = {}
        previous = {}
        distance = {key: 0}

        #
        # Add the first vertex
        #

        pq.append(key, 0)
        while len(pq):

            #
            # Get the vertex with the smallest weight
            # (assumes pop(0) returns the lowest-priority entry -- TODO confirm)
            #

            v = pq.pop(0)
            if v in settled:
                continue

            settled[v] = 1
            for nbr in self.neighbors(v):
                d = distance[v] + self.weight(v, nbr)
                if nbr not in distance or distance[nbr] > d:
                    distance[nbr] = d
                    previous[nbr] = v
                    pq.append(nbr, d)

        return distance, previous

    def shortestPath(self, key1, key2):
        """
        Find a single shortest path from the given start key to the given
        end key.  The output is the list of keys along that path in order,
        ending with key2.  The start key itself is NOT included, and an
        empty list is returned when key2 is unreachable or equals key1.
        """
        path = []
        d, p = self.Dijkstra(key1)

        while key2 in p:
            path.append(key2)
            key2 = p[key2]

        path.reverse()
        return path

    def BellmanFord(self):
        """Not implemented; returns None."""
        pass

    def FloydWarshall(self):
        """Not implemented; returns None."""
        pass

    def Johnson(self):
        """Not implemented; returns None."""
        pass

    def Kruskal(self):
        """Not implemented; returns None."""
        pass

    def Prim(self):
        """Not implemented; returns None."""
        pass
