#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
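Test suite for the Graph and UGraph classes

Builds small directed (Graph) and undirected (UGraph) graphs from adjacency
lists and adjacency matrices, then checks traversal orders (BFS/DFS), cycle
detection, girth, bipartiteness, diameter, shortest paths, and node/edge
bookkeeping for correctness.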
"""

import unittest
import os
import random
import string

from compClust.mlx.Graph import *
from compClust.mlx.UGraph import *
from compClust.mlx.Node import *

class GraphTestCases(unittest.TestCase):
    """Unit tests for the Graph (directed, weighted) and UGraph classes.

    All checks go through the TestCase assertion methods (assertEqual,
    failUnless, failIf) instead of bare ``assert`` statements: bare asserts
    are removed when Python runs with -O, which would make every test pass
    vacuously, and they report nothing about the values that disagreed.
    """

    def testGraphOps(self):
        """Smoke test: build a UGraph, run BFS from node 0, convert to a matrix.

        Only one property is checked: the graph below is connected, so a BFS
        from node 0 must list every node.
        """
        adjList = [[1, 2],
                   [0, 2, 3, 4],
                   [0, 1, 4, 8],
                   [1, 5, 6],
                   [1, 2, 7, 8],
                   [3],
                   [3],
                   [4, 8],
                   [2, 4, 7]]

        g = UGraph()
        for i in range(len(adjList)):
            g.addNode(SimpleNode(i))

        for i in range(len(adjList)):
            for j in adjList[i]:
                g.addEdge(i, j)

        count, order = g.BFS(0)
        self.assertEqual(len(order), len(adjList))

        # Only exercised for exceptions; the matrix layout is not checked.
        g.convertToAdjMatrix()

    def makeWeightedGraphFromList(self, adjlist, weights):
        """Build a directed, weighted Graph from parallel adjacency/weight lists.

        Nodes are SimpleNode(0) .. SimpleNode(n - 1) with n = len(adjlist).
        For every node i, adjlist[i][k] is the target of an edge leaving i
        and weights[i][k] is that edge's weight, so each pair of rows must
        have the same length.

        Returns the populated Graph.
        """
        n = len(adjlist)
        graph = Graph()

        for i in range(n):
            graph.addNode(SimpleNode(i))

        for i in range(n):
            targets = adjlist[i]
            rowWeights = weights[i]
            # zip() would silently drop the tail of the longer row and hide a
            # typo in the fixture, so require the rows to line up.
            self.assertEqual(len(targets), len(rowWeights),
                             "row %d: %d targets but %d weights"
                             % (i, len(targets), len(rowWeights)))
            for target, weight in zip(targets, rowWeights):
                graph.addEdge(i, target, weight)

        return graph

    def makeGraphFromList(self, adjlist):
        """Build a directed Graph from an adjacency list, every edge weight 1."""
        unitWeights = [[1] * len(row) for row in adjlist]
        return self.makeWeightedGraphFromList(adjlist, unitWeights)

    def makeGraphFromMatrix(self, matrix):
        """Build a directed Graph from a square weight matrix.

        matrix[i][j] != 0 adds an edge i -> j whose weight is that entry;
        zero entries mean "no edge".
        """
        n = len(matrix)
        graph = Graph()

        for i in range(n):
            graph.addNode(SimpleNode(i))

        for i in range(n):
            row = matrix[i]
            for j in range(n):
                if row[j] != 0:
                    graph.addEdge(i, j, row[j])

        return graph

    def testCyclic(self):
        """Check isAcyclic(), girth() and isBipartite() on three graphs."""

        # Symmetric adjacency of a 9-node tree (every edge listed both ways).
        matrix1 = [[2],
                   [3],
                   [0, 3],
                   [1, 2, 4, 5],
                   [3],
                   [3, 6, 7, 8],
                   [5],
                   [5],
                   [5]]

        # Every edge points from a lower to a higher index except 5 -> 3,
        # and there is no path from 3 back to 5, so there is no directed cycle.
        matrix2 = [[1, 2],
                   [4, 5],
                   [3],
                   [6],
                   [6],
                   [3, 6],
                   []]

        # Symmetric adjacency containing 4-cycles such as 0-1-5-2-0.
        matrix3 = [[1, 2],
                   [0, 4, 5],
                   [0, 3, 5],
                   [2, 6],
                   [1, 6, 7],
                   [1, 2, 7],
                   [3, 4, 8],
                   [4, 5, 8],
                   [6, 7]]

        graph = self.makeGraphFromList(matrix1)
        # NOTE(review): the paired edges (e.g. 0->2, 2->0) apparently count as
        # a directed cycle for isAcyclic(), while girth() sees no cycle and
        # presumably returns n + 1 == 10 as its "no cycle" value -- confirm
        # against Graph.girth.
        self.failIf(graph.isAcyclic())
        self.assertEqual(graph.girth(), 10)
        self.failUnless(graph.isBipartite())

        graph = self.makeGraphFromList(matrix2)
        self.failUnless(graph.isAcyclic())
        self.assertEqual(graph.girth(), 3)
        self.failIf(graph.isBipartite())

        graph = self.makeGraphFromList(matrix3)
        self.failIf(graph.isAcyclic())
        self.assertEqual(graph.girth(), 4)
        self.failUnless(graph.isBipartite())

    def testDFS(self):
        """Check BFS order and DFS pre/post orders from node 0 on two graphs."""

        matrix1 = [[1, 2],
                   [0, 2, 3, 4],
                   [0, 1, 4, 8],
                   [1, 5, 6],
                   [1, 2, 7, 8],
                   [3],
                   [3],
                   [4, 8],
                   [2, 4, 7]]

        matrix2 = [[1, 2, 3],
                   [6],
                   [0, 4],
                   [1, 5, 6],
                   [5],
                   [0, 2, 4, 6],
                   []]

        graph = self.makeGraphFromList(matrix1)

        cnt, order = graph.BFS(0)
        self.assertEqual(order, [0, 1, 2, 3, 4, 8, 5, 6, 7])

        cnt, pre, post = graph.DFS(0)
        self.assertEqual(pre, [0, 1, 2, 4, 7, 8, 3, 5, 6])
        self.assertEqual(post, [8, 7, 4, 2, 5, 6, 3, 1, 0])

        graph = self.makeGraphFromList(matrix2)

        cnt, order = graph.BFS(0)
        self.assertEqual(order, [0, 1, 2, 3, 6, 4, 5])

        cnt, pre, post = graph.DFS(0)
        self.assertEqual(pre, [0, 1, 6, 2, 4, 5, 3])
        self.assertEqual(post, [6, 1, 5, 4, 2, 3, 0])

    def testDiameter(self):
        """Check diameter() on an 8-node graph in which every node has degree 3."""

        matrix = [[0, 1, 1, 0, 0, 0, 1, 0],
                  [1, 0, 0, 1, 0, 0, 0, 1],
                  [1, 0, 0, 1, 1, 0, 0, 0],
                  [0, 1, 1, 0, 0, 1, 0, 0],
                  [0, 0, 1, 0, 0, 1, 1, 0],
                  [0, 0, 0, 1, 1, 0, 0, 1],
                  [1, 0, 0, 0, 1, 0, 0, 1],
                  [0, 1, 0, 0, 0, 1, 1, 0]]

        graph = self.makeGraphFromMatrix(matrix)

        self.assertEqual(graph.diameter(), 3)

    def _traversalKeys(self, graph, traversal):
        """Drain graph.iterator(type=traversal) and return the visited keys.

        The iterator signals exhaustion by returning None from next().
        """
        keys = []
        walker = graph.iterator(type=traversal)
        node = walker.next()
        while node is not None:
            keys.append(node.key())
            node = walker.next()
        return keys

    def testGraph(self):
        """Randomised consistency test for UGraph.

        Builds a random symmetric 0/1 adjacency matrix, loads it into a
        UGraph, then checks that neighbors() and edgeExists() agree with the
        matrix. It also checks that the 'DFO' and 'BFO' iterators each visit
        every node exactly once, and that removing every node leaves an empty
        iteration.
        """
        numNodes = 20
        matrix = []
        graph = UGraph()

        for i in range(numNodes):
            graph.addNode(SimpleNode(i))
            matrix.append([random.randrange(2) for k in range(numNodes)])

        # Mirror the upper triangle into the lower one so the matrix is
        # symmetric, as required for an undirected graph.
        for i in range(numNodes):
            for j in range(i):
                matrix[i][j] = matrix[j][i]

        # Give every node at least one edge to a *different* node. This runs
        # after the mirroring step so the planted edge cannot be overwritten,
        # and the target skips i itself so it is never a self-loop.
        for i in range(numNodes):
            others = [matrix[i][j] for j in range(numNodes) if j != i]
            if 1 not in others:
                j = random.randrange(numNodes - 1)
                if j >= i:
                    j = j + 1
                matrix[i][j] = 1
                matrix[j][i] = 1

        for i in range(numNodes):
            for j in range(numNodes):
                if matrix[i][j] == 1:
                    graph.addEdge(i, j)

        for i in range(numNodes):
            for j in graph.neighbors(i):
                self.assertEqual(matrix[i][j], 1,
                                 "node %d reported as neighbor of %d" % (j, i))

            for j in range(numNodes):
                if matrix[i][j] == 1:
                    self.failUnless(graph.edgeExists(i, j),
                                    "missing edge (%d, %d)" % (i, j))
                else:
                    self.failIf(graph.edgeExists(i, j),
                                "spurious edge (%d, %d)" % (i, j))

        for traversal in ('DFO', 'BFO'):
            keys = self._traversalKeys(graph, traversal)
            self.assertEqual(len(keys), numNodes)
            keys.sort()
            self.assertEqual(keys, range(numNodes))

        # Remove every node; an iteration over the emptied graph must yield
        # nothing.
        for i in range(numNodes):
            graph.removeNode(i)

        self.failUnless(graph.iterator(type='BFO').next() is None)

    def testDijkstra(self):
        """Check shortestPath() on a small directed, weighted graph.

        As encoded by the expectations below, shortestPath(src, dst) returns
        the nodes after src up to and including dst, and [] when dst is
        unreachable (node 1 has no outgoing edges).
        """
        matrix1 = [[1, 2],
                   [],
                   [1, 3],
                   [0, 1]]

        weight1 = [[1, 4],
                   [],
                   [2, 5],
                   [2, 2]]

        graph = self.makeWeightedGraphFromList(matrix1, weight1)

        expected = [(0, 1, [1]),
                    (0, 2, [2]),
                    (0, 3, [2, 3]),
                    (1, 0, []),
                    (1, 2, []),
                    (1, 3, []),
                    (2, 0, [3, 0]),
                    (2, 1, [1]),
                    (2, 3, [3]),
                    (3, 0, [0]),
                    (3, 1, [1]),
                    (3, 2, [0, 2])]

        for src, dst, path in expected:
            actual = graph.shortestPath(src, dst)
            self.assertEqual(actual, path,
                             "shortestPath(%d, %d) returned %r, expected %r"
                             % (src, dst, actual, path))

def suite(**kw):
    """Return a unittest.TestSuite holding the selected GraphTestCases tests.

    Tests are added in a fixed order; keyword arguments are accepted and
    ignored.
    """
    # testGraphOps is deliberately not listed here.
    testNames = ["testGraph",
                 "testDiameter",
                 "testDFS",
                 "testCyclic",
                 "testDijkstra"]

    collected = unittest.TestSuite()
    for testName in testNames:
        collected.addTest(GraphTestCases(testName))
    return collected

if __name__ == "__main__":
    # When run as a script, execute the tests returned by suite().
    unittest.main(defaultTest="suite")
  
