'''
Created on Aug 19, 2010

@author: sau

The expected gene ID 728439 (asserted in testGeneMrnaCounts) was located with:
    from Erange.commoncode import getFeaturesByChromDict
    genome = Genome(self.genomeName)
    featuresByChromDict = getFeaturesByChromDict(genome)
    print featuresByChromDict["1"][:3]

'''
import unittest
import os
from erange import geneMrnaCounts
from cistematic.core.geneinfo import geneinfoDB
from cistematic.genomes import Genome
from erange import ReadDataset


class TestGeneMrnaCounts(unittest.TestCase):
    """Tests for the erange.geneMrnaCounts module.

    Each test starts from a freshly initialized, empty RNA read dataset
    stored in ``testDBName``. Output files written by geneMrnaCounts are
    deleted right after they are checked; tearDown also deletes any output
    file left behind when an assertion fails part-way through a test.
    """
    # Gene-info database handle shared by all tests; created once, when the
    # class body is executed.
    idb = geneinfoDB(cache=True)
    testDBName = "testRDS.rds"
    genomeName = "hsapiens"
    outfilename = "testGeneMrnaCounts.txt"

    def setUp(self):
        self.rds = ReadDataset.ReadDataset(self.testDBName, initialize=True, datasetType="RNA", verbose=False)


    def tearDown(self):
        # Normally already removed by the test itself; only present here if a
        # test failed before reaching its own cleanup.
        self._removeIfExists(self.outfilename)
        del(self.rds)
        os.remove(self.testDBName)


    @staticmethod
    def _removeIfExists(path):
        """Delete ``path`` if it exists; a missing file is not an error."""
        if os.path.exists(path):
            os.remove(path)


    def _assertOutputCounts(self, allowedCounts):
        """Check the count column of every output line, then delete the file.

        allowedCounts holds the permitted values of the third tab-separated
        field of each line, including its trailing newline (e.g. "0\\n").
        """
        with open(self.outfilename, "r") as outfile:
            for line in outfile:
                fields = line.split("\t")
                self.assertIn(fields[2], allowedCounts)

        os.remove(self.outfilename)


    def _readFirstOutputLine(self):
        """Return the first line of the output file and delete the file."""
        with open(self.outfilename, "r") as outfile:
            line = outfile.readline()

        os.remove(self.outfilename)
        return line


    def testGeneMrnaCounts(self):
        # An empty dataset must give every gene a count of zero.
        geneMrnaCounts.geneMrnaCounts(self.genomeName, self.testDBName, self.outfilename)
        self._assertOutputCounts(("0\n",))

        # After inserting a single unique read, each gene's count is 0 or 1.
        rdsEntryList = [("testRead", "chr1", 18700, 18800, "+", 1.0, "", "")]
        self.rds.insertUniqs(rdsEntryList)
        possibleCounts = ("0\n", "1\n")
        geneMrnaCounts.geneMrnaCounts(self.genomeName, self.testDBName, self.outfilename)
        self._assertOutputCounts(possibleCounts)

        # With markGID=True the read's flag is expected to be set to gene ID
        # 728439 (see the module docstring for how that ID was found).
        geneMrnaCounts.geneMrnaCounts(self.genomeName, self.testDBName, self.outfilename,
                                      markGID=True, trackStrand=True)
        self._assertOutputCounts(possibleCounts)
        reads = self.rds.getReadsDict(withFlag=True)
        self.assertEqual("728439", reads["1"][0]["flag"])

        geneMrnaCounts.geneMrnaCounts(self.genomeName, self.testDBName, self.outfilename,
                                      countFeats=True, markGID=True, cachePages=150000)
        self._assertOutputCounts(possibleCounts)
        reads = self.rds.getReadsDict(withFlag=True)
        self.assertEqual("728439", reads["1"][0]["flag"])


    def testCountFeatures(self):
        testDict = {}
        self.assertEqual(0, geneMrnaCounts.countFeatures(testDict))

        testDict = {"chr1": []}
        self.assertEqual(0, geneMrnaCounts.countFeatures(testDict))

        # TODO: This is likely not the result we want. A string value appears
        # to be counted by its length (len("not a list") == 10).
        testDict = {"chr1": "not a list"}
        self.assertEqual(10, geneMrnaCounts.countFeatures(testDict))

        testDict = {"chr1": 10}
        self.assertEqual(0, geneMrnaCounts.countFeatures(testDict))

        testDict = {"chr1": 10,
                    "chr2": ["f1"]}
        self.assertEqual(1, geneMrnaCounts.countFeatures(testDict))

        testDict = {"chr1": ["f1", "f2"]}
        self.assertEqual(2, geneMrnaCounts.countFeatures(testDict))

        testDict = {"chr1": ["f1", "f2"],
                    "chr2": []}
        self.assertEqual(2, geneMrnaCounts.countFeatures(testDict))

        testDict = {"chr1": ["f1", "f2"],
                    "chr2": ["f1"]}
        self.assertEqual(3, geneMrnaCounts.countFeatures(testDict))


    def testGetGeneSymbol(self):
        # Case: Null/None inputs
        gid = ""
        searchGID = False
        geneInfoDict = {}
        idb = None
        genomeName = ""
        geneAnnotDict = {}
        self.assertEqual("LOC", geneMrnaCounts.getGeneSymbol(gid, searchGID, geneInfoDict, idb, genomeName, geneAnnotDict))

        # Case: symbol is in geneInfoDict
        gid = "1"
        searchGID = False
        geneInfoDict = {"1": [["gene1", "wrong name"], ["wrong name 2"]]}
        idb = None
        genomeName = "test"
        geneAnnotDict = {("test", "1"): ["wrong name 3"]}
        self.assertEqual("gene1", geneMrnaCounts.getGeneSymbol(gid, searchGID, geneInfoDict, idb, genomeName, geneAnnotDict))

        # Case: symbol not in geneInfoDict, is in geneAnnotDict
        gid = "1"
        searchGID = False
        geneInfoDict = {"0": [["wrong name"], ["wrong name 2"]]}
        idb = None
        genomeName = "test"
        geneAnnotDict = {("test", "1"): ["gene1"]}
        self.assertEqual("gene1", geneMrnaCounts.getGeneSymbol(gid, searchGID, geneInfoDict, idb, genomeName, geneAnnotDict))

        # Case: symbol not in geneInfoDict or geneAnnotDict - non-null/None inputs
        gid = "1"
        searchGID = False
        geneInfoDict = {"0": [["wrong name"], ["wrong name 2"]]}
        idb = None
        genomeName = "test"
        geneAnnotDict = {("test", "0"): ["wrong name 3"]}
        self.assertEqual("LOC1", geneMrnaCounts.getGeneSymbol(gid, searchGID, geneInfoDict, idb, genomeName, geneAnnotDict))

        # Case: using search, gid not in idb
        gid = "almostCertainlyNotInTheIDB"
        searchGID = True
        geneInfoDict = {"0": [["wrong name"], ["wrong name 2"]]}
        idb = self.idb
        genomeName = "human"
        geneAnnotDict = {("human", "0"): ["wrong name 3"]}
        self.assertEqual("LOCalmostCertainlyNotInTheIDB", geneMrnaCounts.getGeneSymbol(gid, searchGID, geneInfoDict, idb, genomeName, geneAnnotDict))

        # Case: using search
        # sql to get gid: select gID from gene_info where genome="human" and locustag !="-" and locustag != symbol limit 5;
        gid = "RP11-177A2.3"
        searchGID = True
        geneInfoDict = {"27": [["correct"], ["wrong name 2"]]}
        idb = self.idb
        genomeName = "human"
        geneAnnotDict = {("human", "0"): ["wrong name 3"]}
        self.assertEqual("correct", geneMrnaCounts.getGeneSymbol(gid, searchGID, geneInfoDict, idb, genomeName, geneAnnotDict))


    def testWriteOutputFile(self):
        # The gid has a count in gidCount, so that count is written.
        genome = Genome(self.genomeName)
        gidList = ["RP11-177A2.3"]
        gidCount = {"RP11-177A2.3": 1}
        geneMrnaCounts.writeOutputFile(self.outfilename, genome, gidList, gidCount, searchGID=False)
        line = self._readFirstOutputLine()
        self.assertEqual("RP11-177A2.3\tLOCRP11-177A2.3\t1\n", line)

        # The gid is missing from gidCount, so a count of 0 is written.
        genome = Genome(self.genomeName)
        gidList = ["RP11-177A2.3"]
        gidCount = {"something else": 1}
        geneMrnaCounts.writeOutputFile(self.outfilename, genome, gidList, gidCount, searchGID=False)
        line = self._readFirstOutputLine()
        self.assertEqual("RP11-177A2.3\tLOCRP11-177A2.3\t0\n", line)

    def testMain(self):
        # Command-line entry point on an empty dataset: every count is zero.
        argv = ["geneMRNACounts", self.genomeName, self.testDBName, self.outfilename]
        geneMrnaCounts.main(argv)
        self._assertOutputCounts(("0\n",))


def suite():
    """Return a TestSuite containing every test method of TestGeneMrnaCounts."""
    # unittest.makeSuite is deprecated and was removed in Python 3.13;
    # TestLoader.loadTestsFromTestCase is the supported equivalent.
    testSuite = unittest.TestSuite()
    testSuite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestGeneMrnaCounts))

    return testSuite


if __name__ == "__main__":
    # Run this module's tests; name a test (e.g. TestGeneMrnaCounts.testMain)
    # on the command line to run only that one.
    unittest.main(module="__main__")