# Optional speed-up: enable the psyco JIT when it can be imported and started.
# Any failure here only prints a notice; the script continues without it.
# NOTE(review): the bare except also swallows KeyboardInterrupt/SystemExit
# raised during import; acceptable for a one-shot startup probe.
try:
    import psyco
    psyco.full()
except:
    print "psyco not running"

import sys
import optparse
from commoncode import getFeaturesByChromDict, getConfigParser, getConfigOption, getConfigBoolOption
import ReadDataset
from cistematic.genomes import Genome
from cistematic.core.geneinfo import geneinfoDB

# Version banner, printed whenever this module is imported or run.
print "geneMrnaCounts: version 5.2"


def main(argv=None):
    """Command-line entry point: parse options and run geneMrnaCounts.

    argv -- full argument vector including the program name; sys.argv is
            used when argv is None or empty.

    Prints the usage line and exits with status 1 when fewer than three
    positional arguments (genome, rdsfile, outfilename) are supplied.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog genome rdsfile outfilename [options]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        # Let optparse render the usage so "%prog" is expanded to the program
        # name instead of being printed literally.
        parser.print_usage()
        sys.exit(1)

    genomeName, hitfile, outfilename = args[:3]

    geneMrnaCounts(genomeName, hitfile, outfilename, options.trackStrand, options.doSplices,
                   options.doUniqs, options.doMulti, options.extendGenome, options.replaceModels,
                   options.searchGID, options.countFeats, options.cachePages, options.markGID)


def getParser(usage):
    """Build the optparse parser for geneMrnaCounts.

    Option defaults are read from the "geneMrnaCounts" section of the
    project configuration (via getConfigParser); each entry falls back to
    the hard-coded value listed below when the config does not set it.
    """
    parser = optparse.OptionParser(usage=usage)

    # (command-line flag, destination attribute, extra add_option keywords)
    optionTable = (
        ("--stranded", "trackStrand", {"action": "store_true"}),
        ("--splices", "doSplices", {"action": "store_true"}),
        ("--noUniqs", "doUniqs", {"action": "store_false"}),
        ("--multi", "doMulti", {"action": "store_true"}),
        ("--models", "extendGenome", {}),
        ("--replacemodels", "replaceModels", {"action": "store_true"}),
        ("--searchGID", "searchGID", {"action": "store_true"}),
        ("--countfeatures", "countFeats", {"action": "store_true"}),
        ("--cache", "cachePages", {"type": "int"}),
        ("--markGID", "markGID", {"action": "store_true"}),
    )
    for (flag, destination, extra) in optionTable:
        parser.add_option(flag, dest=destination, **extra)

    configParser = getConfigParser()
    section = "geneMrnaCounts"

    # (option name, config reader, fallback) -- read in the original order.
    configDefaults = (
        ("trackStrand", getConfigBoolOption, False),
        ("doSplices", getConfigBoolOption, False),
        ("doUniqs", getConfigBoolOption, True),
        ("doMulti", getConfigBoolOption, False),
        ("extendGenome", getConfigOption, ""),
        ("replaceModels", getConfigBoolOption, False),
        ("searchGID", getConfigBoolOption, False),
        ("countFeats", getConfigBoolOption, False),
        ("cachePages", getConfigOption, None),
        ("markGID", getConfigBoolOption, False),
    )
    defaults = dict((name, reader(configParser, section, name, fallback))
                    for (name, reader, fallback) in configDefaults)

    parser.set_defaults(**defaults)

    return parser

def geneMrnaCounts(genomeName, hitfile, outfilename, trackStrand=False, doSplices=False,
                   doUniqs=True, doMulti=False, extendGenome="", replaceModels=False,
                   searchGID=False, countFeats=False, cachePages=None, markGID=False):
    """Count reads from an RDS file over each gene feature and write per-gene totals.

    genomeName    -- genome name passed to cistematic's Genome.
    hitfile       -- path of the read dataset (ReadDataset).
    outfilename   -- output path; one "gid<TAB>symbol<TAB>count" line per gene.
    trackStrand   -- count only reads on the feature's strand ("+" unless the
                     feature sense is "R", then "-").
    doSplices, doUniqs, doMulti -- which read classes getCounts/flagReads use.
    extendGenome  -- optional gene-model source used to extend (or, with
                     replaceModels, replace) the genome's models.
    searchGID     -- forwarded to writeOutputFile for symbol lookup.
    countFeats    -- print the number of distinct features that were counted.
    cachePages    -- when given, enable the RDS cache and raise its size to
                     this many pages if larger than the default.
    markGID       -- flag all reads "NM", then flag reads in each counted
                     region via flagReads; saves the cache DB when caching.
    """
    if trackStrand:
        print("will track strandedness")
        doStranded = "track"
    else:
        doStranded = "both"

    if extendGenome:
        if replaceModels:
            print("will replace gene models with %s" % extendGenome)
        else:
            print("will extend gene models with %s" % extendGenome)
    else:
        replaceModels = False

    if cachePages is not None:
        # A value taken from the config file may arrive as a string; in
        # Python 2 a str always compares greater than an int, so normalize.
        cachePages = int(cachePages)
        doCache = True
    else:
        cachePages = 100000
        doCache = False

    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
    if cachePages > hitRDS.getDefaultCacheSize():
        hitRDS.setDBcache(cachePages)

    genome = Genome(genomeName, inRAM=True)
    if extendGenome:
        genome.extendFeatures(extendGenome, replace=replaceModels)

    print("getting gene features....")
    featuresByChromDict = getFeaturesByChromDict(genome)

    # chrom -> set of (start, stop, gid, sense); a set keeps de-duplication O(1).
    seenFeaturesByChromDict = {}
    print("getting geneIDs....")
    gidList = genome.allGIDs()
    gidList.sort()
    gidCount = dict.fromkeys(gidList, 0)

    chromList = hitRDS.getChromosomes(fullChrom=False)
    if not chromList and doSplices:
        chromList = hitRDS.getChromosomes(table="splices", fullChrom=False)

    if markGID:
        print("Flagging all reads as NM")
        hitRDS.setFlags("NM", uniqs=doUniqs, multi=doMulti, splices=doSplices)

    for chrom in chromList:
        if chrom not in featuresByChromDict:
            continue

        if countFeats:
            seenFeaturesByChromDict[chrom] = set()

        print("\nchr%s" % chrom)
        fullchrom = "chr%s" % chrom
        regionList = []
        print("counting GIDs")
        for (start, stop, gid, featureSense, _featureType) in featuresByChromDict[chrom]:
            try:
                if doStranded == "track":
                    checkSense = "+"
                    if featureSense == "R":
                        checkSense = "-"

                    regionList.append((gid, fullchrom, start, stop, checkSense))
                    count = hitRDS.getCounts(fullchrom, start, stop, uniqs=doUniqs, multi=doMulti, splices=doSplices, sense=checkSense)
                else:
                    regionList.append((gid, fullchrom, start, stop))
                    count = hitRDS.getCounts(fullchrom, start, stop, uniqs=doUniqs, multi=doMulti, splices=doSplices)

                # Raises KeyError for a gid absent from genome.allGIDs();
                # such features are reported and skipped below.
                gidCount[gid] += count
                if countFeats:
                    seenFeaturesByChromDict[chrom].add((start, stop, gid, featureSense))
            except Exception:
                # Best-effort per feature, but let KeyboardInterrupt/SystemExit through.
                print("problem with %s - skipping" % gid)

        if markGID:
            print("marking GIDs")
            hitRDS.flagReads(regionList, uniqs=doUniqs, multi=doMulti, splices=doSplices, sense=doStranded)
            print("finished marking")

    print(" ")
    if countFeats:
        numFeatures = countFeatures(seenFeaturesByChromDict)
        print("saw %d features" % numFeatures)

    writeOutputFile(outfilename, genome, gidList, gidCount, searchGID)
    if markGID and doCache:
        hitRDS.saveCacheDB(hitfile)


def countFeatures(seenFeaturesByChromDict):
    """Return the total number of entries across all per-chromosome collections.

    Values that do not support len() (a TypeError) are silently ignored.
    """
    total = 0
    for features in seenFeaturesByChromDict.values():
        try:
            size = len(features)
        except TypeError:
            continue

        total += size

    return total


def writeOutputFile(outfilename, genome, gidList, gidCount, searchGID):
    """Write one "gid<TAB>symbol<TAB>count" line per gene in gidList.

    Genes missing from gidCount are written with a count of 0. Symbols are
    resolved by getGeneSymbol using the cistematic gene-info database and
    the genome's annotation info.
    """
    geneAnnotDict = genome.allAnnotInfo()
    genomeName = genome.genome
    idb = geneinfoDB(cache=True)
    geneInfoDict = idb.getallGeneInfo(genomeName)

    # Open only after the lookups that may fail, and always close the file.
    outfile = open(outfilename, "w")
    try:
        for gid in gidList:
            symbol = getGeneSymbol(gid, searchGID, geneInfoDict, idb, genomeName, geneAnnotDict)
            outfile.write("%s\t%s\t%d\n" % (gid, symbol, gidCount.get(gid, 0)))
    finally:
        outfile.close()


def getGeneSymbol(gid, searchGID, geneInfoDict, idb, genomeName, geneAnnotDict):
    """Return a display symbol for gid.

    Lookup order:
      1. geneInfoDict[lookupGID][0][0], where lookupGID is gid, or (when
         searchGID is set and gid is not in geneInfoDict) element [1] of
         idb.getGeneID(genomeName, gid) if that result has one;
      2. geneAnnotDict[(genomeName, gid)][0];
      3. the fallback "LOC<gid>".
    """
    lookupGID = gid
    if searchGID and gid not in geneInfoDict:
        actualGeneID = idb.getGeneID(genomeName, gid)
        # Index 1 is read below, so require at least two elements; an empty,
        # None, or 1-element result previously raised outside the try.
        if actualGeneID and len(actualGeneID) > 1:
            lookupGID = actualGeneID[1]

    try:
        geneinfo = geneInfoDict[lookupGID]
        symbol = geneinfo[0][0]
    except (KeyError, IndexError):
        try:
            symbol = geneAnnotDict[(genomeName, gid)][0]
        except (KeyError, IndexError):
            symbol = "LOC%s" % gid

    return symbol


if __name__ == "__main__":
    main(sys.argv)