#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
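Test suite for the MultiWayTree class.

Each test builds a fresh random tree, then walks it twice in the same visiting
order: once with a recursive reference traversal and once with the tree's
iterator.  The two walks are then compared for equality.  A separate test
checks that the root has depth 0.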
"""

import unittest
import os
import random
import string

from compClust.mlx.MultiWayTree import *
from compClust.mlx.Node import *

class MultiWayTreeTestCases(unittest.TestCase):
    def setUp(self):
        """ creates a random multiway tree with 100 nodes """

        numNodes  = 100
        tree      = MultiWayTree() 

        root      = Node(0, '0')   
        tree.insert(root)         
        
        nodeList  = [root]
    
        for i in range(1, numNodes):
            newNode = Node(i, str(i))
            tree.append(random.choice(nodeList).key(), newNode)
            nodeList.append(newNode)

        self.tree = tree
        self.iter = tree.iterator()
        
    def tearDown(self):
        pass
        
    def lcr(self, tree, key):
        l = []
        if key != None:
            n = tree.neighbors(key)
            c = len(n)
            j = range(c)
            for i in j[0:c/2]:
                l += self.lcr(tree, n[i])
            l += tree.find(key).value()
            for i in j[c/2:c]:
                l += self.lcr(tree, n[i])

        return l

    def rcl(self, tree, key):
        l = []
        if key != None:
            n = tree.neighbors(key)
            c = len(n)
            if (c == 1):
                l += self.rcl(tree, n[0])
                l += tree.find(key).value()
            else:
                j = range(c)
                j.reverse()
                for i in j[0:c/2]:
                    l += self.rcl(tree, n[i])
                l += tree.find(key).value()
                for i in j[c/2:c]:
                    l += self.rcl(tree, n[i])

        return l

            
    def clr(self, tree, key):
        l = []
        if key != None:
            l += tree.find(key).value()
            n  = tree.neighbors(key)
            for i in range(len(n)):
                l += self.clr(tree, n[i])
        return l

    def crl(self, tree, key):
        l = []
        if key != None:
            node = tree.find(key)
            l   += node.value()
            n    = tree.neighbors(key)
            for i in range(len(n)-1, -1, -1):
                l += self.crl(tree, n[i])
        return l

    def lrc(self, tree, key):
        l = []
        if key != None:
            node = tree.find(key)
            n    = tree.neighbors(key)
            for i in range(len(n)):
                l += self.lrc(tree, n[i])
            l += node.value()
        return l

    def rlc(self, tree, key):
        l = []
        if key != None:
            n = tree.neighbors(key)
            for i in range(len(n)-1, -1, -1):
                l += self.rlc(tree, n[i])
            l += tree.find(key).value()
        return l

    def iterate(self, tree, mode):
        l     = []
        iter  = tree.iterator(mode)
        tmp   = iter.next()
        while tmp != None:
            l   += tmp.value()
            tmp  = iter.next()
        return l


    def testCRL(self):
        recursive = string.join(self.crl(self.tree, self.tree.root()), '')    
        iterative = string.join(self.iterate(self.tree, 'CRL'),'')

        if iterative != recursive:
            self.fail("CRL iterator did not give correct results.")
            
    def testCLR(self):
        recursive = string.join(self.clr(self.tree, self.tree.root()), '')    
        iterative = string.join(self.iterate(self.tree, 'CLR'),'')

        if iterative != recursive:
            self.fail("CLR iterator did not give correct results.")

    def testRLC(self):
        recursive = string.join(self.rlc(self.tree, self.tree.root()), '')    
        iterative = string.join(self.iterate(self.tree, 'RLC'),'')

        if iterative != recursive:
            self.fail("RLC iterator did not give correct results.")

    def testLRC(self):
        recursive = string.join(self.lrc(self.tree, self.tree.root()), '')    
        iterative = string.join(self.iterate(self.tree, 'LRC'),'')

        if iterative != recursive:
            self.fail("LRC iterator did not give correct results.")

    def testRCL(self):
        recursive = string.join(self.rcl(self.tree, self.tree.root()), '')    
        iterative = string.join(self.iterate(self.tree, 'RCL'),'')

        if iterative != recursive:
            self.fail("RCL iterator did not give correct results.")

    def testLCR(self):
        recursive = string.join(self.lcr(self.tree, self.tree.root()), '')    
        iterative = string.join(self.iterate(self.tree, 'LCR'),'')

        self.tree.prettyPrint()

        if iterative != recursive:
            self.fail("LCR iterator did not give correct results.")

    def testDepth(self):
        root = self.tree.root()
        if self.tree.depth(root) != 0:
            self.fail("Depth of root not consistent.")

            
def suite(**kw):
    """Return a TestSuite of the MultiWayTree traversal and depth tests.

    Any keyword arguments are accepted and ignored.
    """
    # testLCR is intentionally not registered.  NOTE(review): presumably the
    # LCR iterator is known to disagree with the recursive walk -- confirm
    # before adding it to this list.
    testNames = ["testCRL", "testCLR", "testRLC", "testLRC", "testRCL",
                 "testDepth"]

    result = unittest.TestSuite()
    for testName in testNames:
        result.addTest(MultiWayTreeTestCases(testName))
    return result

if __name__ == "__main__":
    # As a script, run only the tests that suite() registers.
    unittest.main(defaultTest="suite")
  

