try:
    import psyco
    psyco.full()
except:
    pass

import sys
import optparse
import ReadDataset
from cistematic.genomes import Genome
from cistematic.core import chooseDB, cacheGeneDB, uncacheGeneDB
from commoncode import getConfigParser, getConfigOption, getConfigIntOption, getConfigBoolOption, getConfigFloatOption

print "normalizeExpandedExonic: version 5.7"


def main(argv=None):
    if not argv:
        argv = sys.argv

    usage = "usage: python %s genome rdsfile uniqcountfile splicecountfile outfile [candidatefile acceptfile] [--gidField fieldID] [--maxLength kblength] [--cache]"

    parser = makeParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(sys.argv) < 6:
        print usage
        print "\twhere splicecountfile can be set to 'none' to not count splices\n"
        sys.exit(1)

    genome = args[0]
    hitfile = args[1]
    uniquecountfile = args[2]
    splicecountfile = args[3]
    outfile = args[4]

    candidateLines = []
    acceptedfilename = ""
    if len(args) > 5:
        try:
            candidatefile = open(args[5])
            candidateLines = candidatefile.readlines()
            candidatefile.close()
            acceptedfilename = args[6]
        except IndexError:
            pass

    normalizeExpandedExonic(genome, hitfile, uniquecountfile, splicecountfile, outfile,
                            candidateLines, acceptedfilename, options.fieldID,
                            options.maxLength, options.doCache, options.extendGenome,
                            options.replaceModels)


def makeParser(usage=""):
    """Build the option parser for this script, seeding each option's
    default from the [normalizeExpandedExonic] configuration section."""
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--gidField", type="int", dest="fieldID")
    parser.add_option("--maxLength", type="float", dest="maxLength")
    parser.add_option("--cache", action="store_true", dest="doCache")
    parser.add_option("--models", dest="extendGenome")
    parser.add_option("--replacemodels", action="store_true", dest="replaceModels")

    # Defaults come from the shared configuration file; the literal values
    # here are only used when the config has no entry.
    config = getConfigParser()
    section = "normalizeExpandedExonic"
    defaults = {
        "fieldID": getConfigIntOption(config, section, "fieldID", 0),
        "maxLength": getConfigFloatOption(config, section, "maxLength", 1000000000.),
        "doCache": getConfigBoolOption(config, section, "doCache", False),
        "extendGenome": getConfigOption(config, section, "extendGenome", ""),
        "replaceModels": getConfigBoolOption(config, section, "replaceModels", False),
    }
    parser.set_defaults(**defaults)

    return parser


def normalizeExpandedExonic(genome, hitfile, uniquecountfilename, splicecountfilename,
                            outfilename, candidateLines=[], acceptedfilename="",
                            fieldID=0, maxLength=1000000000., doCache=False,
                            extendGenome="", replaceModels=False):

    uniquecountfile = open(uniquecountfilename)

    if acceptedfilename:
        acceptedfile = open(acceptedfilename, "w")

    dosplicecount = False
    if splicecountfilename != "none":
        dosplicecount = True
        splicecountfile = open(splicecountfilename)

    if extendGenome:
        if replaceModels:
            print "will replace gene models with %s" % extendGenome
        else:
            print "will extend gene models with %s" % extendGenome

    if doCache:
        cacheGeneDB(genome)
        hg = Genome(genome, dbFile=chooseDB(genome), inRAM=True)
        print "%s cached" % genome
    else:
        hg = Genome(genome, inRAM=True)

    if extendGenome != "":
        hg.extendFeatures(extendGenome, replace=replaceModels)

    RDS = ReadDataset.ReadDataset(hitfile, verbose = True, cache=doCache, reportCount=False)    
    uniqcount = RDS.getUniqsCount()
    print "%d unique reads" % uniqcount

    splicecount = 0
    countDict = {}
    gidList = []
    farList = []
    candidateDict = {}

    gidToGeneDict = {}

    featuresDict = hg.getallGeneFeatures()
    print "got featuresDict"

    outfile = open(outfilename, "w")

    for line in uniquecountfile:
        fields = line.strip().split()
        gid = fields[fieldID]
        gene = fields[1]
        countDict[gid] = float(fields[-1])
        gidList.append(gid)
        gidToGeneDict[gid] = gene

    uniquecountfile.close()

    if dosplicecount:
        for line in splicecountfile:
            fields = line.strip().split()
            gid = fields[fieldID]
            try:
                countDict[gid] += float(fields[-1])
            except:
                print fields
                continue

            splicecount += float(fields[-1])

        splicecountfile.close()

    for line in candidateLines:
        if "#" in line:
            continue

        fields = line.strip().split()
        gid = fields[1]
        gene = fields[0]
        if gid not in gidList:
            if gid not in farList:
                farList.append(gid)
                gidToGeneDict[gid] = gene

            if gid not in countDict:
                countDict[gid] = 0

            countDict[gid] += float(fields[6])

        if gid not in candidateDict:
            candidateDict[gid] = []

        candidateDict[gid].append((float(fields[6]), abs(int(fields[5]) - int(fields[4])), fields[3], fields[4], fields[5]))

    totalCount = (uniqcount + splicecount) / 1000000.
    uniqScale = uniqcount / 1000000.
    for gid in gidList:
        gene = gidToGeneDict[gid]
        featureList = []
        try:
            featureList = featuresDict[gid]
        except:
            try:
                featureList = featuresDict[gene]
            except:
                print gene, gid

        newfeatureList = []
        geneLength = 0.
        for (ftype, chrom, start, stop, sense) in featureList:
            if (start, stop) not in newfeatureList:
                newfeatureList.append((start, stop))
                geneLength += (abs(start - stop) + 1.) / 1000.

        if geneLength < 0.1:
            geneLength = 0.1
        elif geneLength > maxLength:
            geneLength = maxLength

        rpm = countDict[gid] / totalCount
        rpkm = rpm / geneLength
        if gid in candidateDict:
            for (cCount, cLength, chrom, cStart, cStop) in candidateDict[gid]:
                cratio = cCount / (cLength / 1000.)
                cratio = (uniqScale * cratio) / totalCount
                if 10. * cratio < rpkm:
                    continue

                countDict[gid] += cCount
                geneLength += cLength / 1000.
                acceptedfile.write("%s\t%s\t%s\t%s\t%.2f\t%d\t%s\n" % (gid, chrom, cStart, cStop, cratio, cLength, gene))

        rpm = countDict[gid] / totalCount
        rpkm = rpm / geneLength
        outfile.write("%s\t%s\t%.4f\t%.2f\n" %  (gid, gene, geneLength, rpkm))

    for gid in farList:
        gene = gidToGeneDict[gid]
        geneLength = 0
        for (cCount, cLength, chrom, cStart, cStop) in candidateDict[gid]:
            geneLength += cLength / 1000.

        if geneLength < 0.1:
            continue

        for (cCount, cLength, chrom, cStart, cStop) in candidateDict[gid]:
            cratio = cCount / (cLength / 1000.)
            cratio = cratio / totalCount
            acceptedfile.write("%s\t%s\t%s\t%s\t%.2f\t%d\t%s\n" % (gene, chrom, cStart, cStop, cratio, cLength, gene))

        rpm = countDict[gid] / totalCount
        rpkm = rpm / geneLength
        outfile.write('%s\t%s\t%.4f\t%.2f\n' %  (gene, gene, geneLength, rpkm))

    outfile.close()
    try:
        acceptedfile.close()
    except:
        pass

    if doCache:
        uncacheGeneDB(genome)


# Allow the file to be run as a script or imported as a module.
if __name__ == "__main__":
    main(sys.argv)