'''
Created on Sep 10, 2010

@author: sau
'''
import unittest
import os
from erange.rnapath import RNAPATH

# Expected single-character complement for every code fed to RNAPATH.compNT and
# RNAPATH.complement in the tests below. The i-th character of the first string
# maps to the i-th character of the second: uppercase bases and IUPAC ambiguity
# codes first, then the lowercase codes that are exercised.
compDict = dict(zip("ATGCSWRYMKHDBVN" "atgcnz",
                    "TACGSWYRKMDHVBN" "tacgnz"))


class TestRNAPATH(unittest.TestCase):
    """Unit tests for the erange RNAPATH module.

    setUp creates empty contig and distal-pair input files; tearDown removes
    those inputs and any output files RNAPATH wrote.
    """
    incontigfilename = "contigIn.txt"
    distalPairsfile = "distalPair.txt"
    outpathfilename = "rnapathOut.txt"
    outcontigfilename = "contigOut.txt"

    def setUp(self):
        # Create (or truncate) empty input files so each test starts clean.
        for filename in (self.incontigfilename, self.distalPairsfile):
            open(filename, "w").close()


    def tearDown(self):
        # Best-effort cleanup: a file may legitimately be absent, e.g. when a
        # test failed before RNAPATH wrote its output.
        for filename in (self.incontigfilename, self.distalPairsfile,
                         self.outpathfilename, self.outcontigfilename):
            try:
                os.remove(filename)
            except OSError:
                pass


    def _checkEmptyRunOutput(self):
        """Assert the output files RNAPATH produces for empty input files.

        The path file must start with a "#settings:" line and contain nothing
        else; the contig file must be empty. Files are opened with ``with`` so
        they are closed even when an assertion fails.
        """
        with open(self.outpathfilename) as outfile:
            self.assertTrue("#settings:" in outfile.readline())
            self.assertEqual("", outfile.readline())

        with open(self.outcontigfilename) as outcontig:
            self.assertEqual(0, len(outcontig.readlines()))


    def testCompNT(self):
        for nt, expected in compDict.items():
            self.assertEqual(expected, RNAPATH.compNT(nt))

        # Unrecognized input is expected to complement to "N".
        self.assertEqual("N", RNAPATH.compNT("5"))
        self.assertEqual("N", RNAPATH.compNT("anything"))


    def testComplement(self):
        self.assertEqual("", RNAPATH.complement(""))
        for nt, expected in compDict.items():
            self.assertEqual(expected, RNAPATH.complement(nt))

        # Reverse complement of the whole sequence; a length equal to the
        # sequence length, or a negative length, gives the same result.
        self.assertEqual("TGTAATC", RNAPATH.complement("GATTACA"))
        self.assertEqual("TGTAATC", RNAPATH.complement("GATTACA", 7))
        self.assertEqual("TGTAATC", RNAPATH.complement("GATTACA", -75632))
        self.assertEqual("TGTA", RNAPATH.complement("GATTACA", 4))

        # TODO: decide what complement() should return when length exceeds
        # the sequence length. It currently repeats the reverse complement and
        # only fills with "N" after going more than seqlength positions in the
        # negative direction, which seems wrong; these assertions pin the
        # current behavior.
        self.assertEqual("TGTAATCTG", RNAPATH.complement("GATTACA", 9))
        self.assertEqual("TGTAATCTGTAATCNNNNN", RNAPATH.complement("GATTACA", 19))

    # TODO: add a non-empty contig input case (draft below).
    def testRnaPath(self):
        RNAPATH.rnaPath(self.incontigfilename, self.distalPairsfile,
                        self.outpathfilename, self.outcontigfilename)
        self._checkEmptyRunOutput()

        # Draft of a non-empty input case; expected outputs not yet determined.
        #infile = open(self.incontigfilename, "w")
        #infile.write(">chr1 stuff\n")
        #infile.write("GATTACA\n")
        #infile.close()
        #RNAPATH.rnaPath(self.incontigfilename, self.distalPairsfile, self.outpathfilename, self.outcontigfilename)
        #outfile = open(self.outpathfilename)
        #self.assertTrue("#settings:" in outfile.readline())
        #self.assertEqual("", outfile.readline())
        #outfile.close()


    #TODO: write test
    def testGetPath(self):
        pass


    # TODO: add cases for longer and branching graphs.
    def testTraverseGraph(self):
        leafList = []
        edgeMatrix = RNAPATH.EdgeMatrix(0)
        pathList, visitedDict = RNAPATH.traverseGraph(leafList, edgeMatrix)
        self.assertEqual([], pathList)
        self.assertEqual({}, visitedDict)

        leafList = [1]
        edgeMatrix = RNAPATH.EdgeMatrix(3)
        edgeMatrix.edgeArray[2][1] = 3
        edgeMatrix.edgeArray[1][2] = 3
        pathList, visitedDict = RNAPATH.traverseGraph(leafList, edgeMatrix)
        self.assertEqual([[1, 2]], pathList)
        self.assertEqual({1: "", 2: ""}, visitedDict)

        leafList = [1, 2]
        edgeMatrix = RNAPATH.EdgeMatrix(3)
        edgeMatrix.edgeArray[2][1] = 3
        edgeMatrix.edgeArray[1][2] = 3
        pathList, visitedDict = RNAPATH.traverseGraph(leafList, edgeMatrix)
        self.assertEqual([[1, 2]], pathList)
        self.assertEqual({1: "", 2: ""}, visitedDict)


    # TODO: add cases for non-empty contig files.
    def testGetContigsFromFile(self):
        contigNum, nameList, contigDict, origSize = RNAPATH.getContigsFromFile(self.incontigfilename)
        self.assertEqual(0, contigNum)
        self.assertEqual([], nameList)
        self.assertEqual({}, contigDict)
        self.assertEqual([], origSize)


    #TODO: check for boundary condition and special cases
    def testEdgeMatrix(self):
        edgeMatrix = RNAPATH.EdgeMatrix(0)
        self.assertEqual("[]", str(edgeMatrix.edgeArray))

        edgeMatrix = RNAPATH.EdgeMatrix(3)
        self.assertEqual("[[0 0 0]\n [0 0 0]\n [0 0 0]]", str(edgeMatrix.edgeArray))
        self.assertEqual([], edgeMatrix.visitLink(0))

        edgeMatrix.edgeArray[0][1] = 1
        self.assertEqual([], edgeMatrix.visitLink(0))

        edgeMatrix.edgeArray[0][1] = 2
        self.assertEqual([0], edgeMatrix.visitLink(0))

        # NOTE(review): visitLink appears to modify edgeArray (visitLink(0)
        # returns [0] above and [] below), so the entry is deliberately reset
        # before each call rather than set once.
        edgeMatrix.edgeArray[2][1] = 2
        self.assertEqual([], edgeMatrix.visitLink(0))
        edgeMatrix.edgeArray[2][1] = 2
        self.assertEqual([], edgeMatrix.visitLink(1))
        edgeMatrix.edgeArray[2][1] = 2
        self.assertEqual([2], edgeMatrix.visitLink(2))

        edgeMatrix.edgeArray[2][1] = 3
        edgeMatrix.edgeArray[1][2] = 3
        self.assertEqual([1, 2], edgeMatrix.visitLink(1))


    def testMain(self):
        argv = ["RNAPATH", self.incontigfilename, self.distalPairsfile,
                self.outpathfilename, self.outcontigfilename]
        RNAPATH.main(argv)
        self._checkEmptyRunOutput()


def suite():
    """Return a TestSuite containing every test method of TestRNAPATH.

    Uses TestLoader.loadTestsFromTestCase because unittest.makeSuite was
    deprecated in Python 3.11 and removed in Python 3.13.
    """
    testSuite = unittest.TestSuite()
    testSuite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestRNAPATH))

    return testSuite


if __name__ == "__main__":
    # Executed directly as a script: discover and run every test in this module.
    unittest.main(module="__main__")