try:
    import psyco
    psyco.full()
except:
    print "psyco not running"

import sys
import optparse
from commoncode import getMergedRegions, getFeaturesByChromDict, getConfigParser, getConfigOption, getConfigBoolOption
import ReadDataset
from cistematic.genomes import Genome
from commoncode import getGeneInfoDict
from cistematic.core import chooseDB, cacheGeneDB, uncacheGeneDB

# Version banner; this is module-level code, so it is printed on import as
# well as when the file is run as a script.
print "geneMrnaCountsWeighted: version 4.3"


def main(argv=None):
    if not argv:
        argv = sys.argv

    usage = "usage: python %s genome rdsfile uniqcountfile outfile [options]"

    parser = makeParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 4:
        print usage
        sys.exit(1)

    genome = args[0]
    hitfile =  args[1]
    countfile = args[2]
    outfilename = args[3]

    geneMrnaCountsWeighted(genome, hitfile, countfile, outfilename, options.ignoreSense,
                           options.withUniqs, options.withMulti,
                           options.acceptfile, options.cachePages, options.doVerbose,
                           options.extendGenome, options.replaceModels)


def makeParser(usage=""):
    """Return an optparse parser for the geneMrnaCountsWeighted options.

    Option defaults are read from the "geneMrnaCountsWeighted" section of the
    parser returned by getConfigParser(); the fallbacks listed below are
    passed to the config helpers as their default argument.
    """
    parser = optparse.OptionParser(usage=usage)

    # (flag, add_option keyword arguments), kept in declaration order.
    # Note that --stranded stores False into ignoreSense.
    optionTable = [
        ("--stranded", {"action": "store_false", "dest": "ignoreSense"}),
        ("--uniq", {"action": "store_true", "dest": "withUniqs"}),
        ("--multi", {"action": "store_true", "dest": "withMulti"}),
        ("--accept", {"dest": "acceptfile"}),
        ("--cache", {"type": "int", "dest": "cachePages"}),
        ("--verbose", {"action": "store_true", "dest": "doVerbose"}),
        ("--models", {"dest": "extendGenome"}),
        ("--replacemodels", {"action": "store_true", "dest": "replaceModels"}),
    ]
    for flag, settings in optionTable:
        parser.add_option(flag, **settings)

    configParser = getConfigParser()
    section = "geneMrnaCountsWeighted"

    # (option name, config reader, fallback value)
    defaultTable = [
        ("ignoreSense", getConfigBoolOption, True),
        ("withUniqs", getConfigBoolOption, False),
        ("withMulti", getConfigBoolOption, False),
        ("acceptfile", getConfigOption, None),
        ("cachePages", getConfigOption, None),
        ("doVerbose", getConfigBoolOption, False),
        ("extendGenome", getConfigOption, ""),
        ("replaceModels", getConfigBoolOption, False),
    ]
    defaults = {}
    for optionName, readOption, fallback in defaultTable:
        defaults[optionName] = readOption(configParser, section, optionName, fallback)

    parser.set_defaults(**defaults)

    return parser


def geneMrnaCountsWeighted(genome, hitfile, countfile, outfilename, ignoreSense=True,
                           withUniqs=False, withMulti=False, acceptfile=None,
                           cachePages=None, doVerbose=False, extendGenome="", replaceModels=False):

    if (not withUniqs and not withMulti) or (withUniqs and withMulti):
        print "must have either one of -uniq or -multi set. Exiting"
        sys.exit(1)

    if cachePages is not None:
        cacheGeneDB(genome)
        hg = Genome(genome, dbFile=chooseDB(genome), inRAM=True)
        print "%s cached" % genome
        doCache = True
    else:
        doCache = False
        cachePages = 0
        hg = Genome(genome, inRAM=True)

    if extendGenome:
        if replaceModels:
            print "will replace gene models with %s" % extendGenome
        else:
            print "will extend gene models with %s" % extendGenome

        hg.extendFeatures(extendGenome, replace=replaceModels)

    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=doVerbose, cache=doCache)
    if cachePages > hitRDS.getDefaultCacheSize():
        hitRDS.setDBcache(cachePages)

    allGIDs = set(hg.allGIDs())
    if acceptfile is not None:
        regionDict = getMergedRegions(acceptfile, maxDist=0, keepLabel=True, verbose=doVerbose)
        for chrom in regionDict:
            for region in regionDict[chrom]:
                allGIDs.add(region.label)
    else:
        regionDict = {}

    featuresByChromDict = getFeaturesByChromDict(hg, regionDict)

    gidReadDict = {}
    read2GidDict = {}
    for gid in allGIDs:
        gidReadDict[gid] = []

    index = 0
    if withMulti and not withUniqs:
        chromList = hitRDS.getChromosomes(table="multi", fullChrom=False)
    else:
        chromList = hitRDS.getChromosomes(fullChrom=False)

    readlen = hitRDS.getReadSize()
    for chromosome in chromList:
        if doNotProcessChromosome(chromosome, featuresByChromDict.keys()):
            continue

        print "\n%s " % chromosome,
        fullchrom = "chr%s" % chromosome
        hitDict = hitRDS.getReadsDict(noSense=ignoreSense, fullChrom=True, chrom=fullchrom, withID=True, doUniqs=withUniqs, doMulti=withMulti)
        featureList = featuresByChromDict[chromosome]

        readGidList, totalProcessedReads = getReadGIDs(hitDict, fullchrom, featureList, readlen, index)
        index = totalProcessedReads
        for (tagReadID, gid) in readGidList:
            try:
                gidReadDict[gid].append(tagReadID)
                if tagReadID in read2GidDict:
                    read2GidDict[tagReadID].add(gid)
                else:
                    read2GidDict[tagReadID] = set([gid])
            except KeyError:
                print "gid %s not in gidReadDict" % gid

    writeCountsToFile(outfilename, countfile, allGIDs, hg, gidReadDict, read2GidDict, doVerbose, doCache)
    if doCache:
        uncacheGeneDB(genome)


def doNotProcessChromosome(chromosome, chromosomeList):
    """Return True when chromosome is absent from chromosomeList."""
    isListed = chromosome in chromosomeList
    return not isListed


def getReadGIDs(hitDict, fullchrom, featList, readlen, index):

    startFeature = 0
    readGidList = []
    ignoreSense = True
    for read in hitDict[fullchrom]:
        tagStart = read["start"]
        tagReadID = read["readID"]
        if read.has_key("sense"):
            tagSense = read["sense"]
            ignoreSense = False

        index += 1
        if index % 100000 == 0:
            print "read %d" % index,

        stopPoint = tagStart + readlen
        if startFeature < 0:
            startFeature = 0

        for (start, stop, gid, sense, ftype) in featList[startFeature:]:
            if tagStart > stop:
                startFeature += 1
                continue

            if start > stopPoint:
                startFeature -= 100
                break

            if not ignoreSense:
                if sense == "R":
                    sense = "-"
                else:
                    sense = "+"

            if start <= tagStart <= stop and (ignoreSense or tagSense == sense):
                readGidList.append((tagReadID, gid))
                stopPoint = stop

    return readGidList, index


def writeCountsToFile(outFilename, countFilename, allGIDs, genome, gidReadDict, read2GidDict, doVerbose=False, doCache=False):

    uniqueCountDict = {}
    uniquecounts = open(countFilename)
    for line in uniquecounts:
        fields = line.strip().split()
        # add a pseudo-count here to ease calculations below
        uniqueCountDict[fields[0]] = float(fields[-1]) + 1

    uniquecounts.close()

    genomeName = genome.genome
    geneinfoDict = getGeneInfoDict(genomeName, cache=doCache)
    geneannotDict = genome.allAnnotInfo()
    outfile = open(outFilename, "w")
    for gid in allGIDs:
        symbol = getGeneSymbol(gid, genomeName, geneinfoDict, geneannotDict)
        tagCount = getTagCount(uniqueCountDict, gid, gidReadDict, read2GidDict)
        if doVerbose:
            print "%s %s %f" % (gid, symbol, tagCount)

        outfile.write("%s\t%s\t%d\n" % (gid, symbol, tagCount))

    outfile.close()


def getGeneSymbol(gid, genomeName, geneinfoDict, geneannotDict):
    """Return the display symbol for gid.

    Ids containing "FAR" are returned unchanged.  Otherwise the symbol is
    taken from the first geneinfoDict[gid] record (field 1 for "celegans",
    field 0 for every other genome), then from the first entry of
    geneannotDict[(genomeName, gid)], and finally falls back to "LOC<gid>".
    """
    if "FAR" in gid:
        return gid

    if genomeName == "celegans":
        symbolField = 1
    else:
        symbolField = 0

    try:
        return geneinfoDict[gid][0][symbolField]
    except (KeyError, IndexError):
        pass

    try:
        return geneannotDict[(genomeName, gid)][0]
    except (KeyError, IndexError):
        return "LOC%s" % gid


def getTagCount(uniqueCountDict, gid, gidReadDict, read2GidDict):
    """Return the weighted read count for gid.

    Every read listed in gidReadDict[gid] contributes
    weight(gid) / sum(weight(g) for g in read2GidDict[readID]),
    where weight(x) is uniqueCountDict[x], or 1 when x has no entry.
    Reads whose denominator is zero contribute nothing.
    """
    geneWeight = uniqueCountDict.get(gid, 1)
    total = 0.
    for readID in gidReadDict[gid]:
        denominator = 0.
        for otherGID in read2GidDict[readID]:
            denominator += uniqueCountDict.get(otherGID, 1)

        if not denominator:
            # nothing to divide by; skip this read
            continue

        total += geneWeight / denominator

    return total


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main(sys.argv)
