# psyco is an optional JIT accelerator for (32-bit) Python 2; the script runs
# without it, so any failure to import or enable it is reported and ignored.
try:
    import psyco
    psyco.full()
except Exception:
    # Narrower than a bare ``except:`` so Ctrl-C (KeyboardInterrupt) and
    # SystemExit during startup are not silently swallowed.
    print('psyco not running')

import sys
from cistematic.genomes import Genome


class GeneSymbolAndCount:
    """Mutable per-gene record holding a symbol and two read counts.

    Attributes are plain public fields, set directly from the constructor
    arguments and overwritten freely by callers.
    """

    def __init__(self, symbol="", uniqueCount=0, spliceCount=0):
        # symbol      -- display name of the gene
        # uniqueCount -- number of unique reads counted for the gene
        # spliceCount -- number of splice reads counted for the gene
        self.symbol, self.uniqueCount, self.spliceCount = symbol, uniqueCount, spliceCount


def main(argv=None):
    """Command-line entry point.

    argv -- full argument vector including the program name; sys.argv is
            used when argv is empty or None. Expected arguments:
            genome maxBorder uniquecountfile splicecountfile outfile
            where maxBorder is an integer: the number of nucleotides at the
            end of each exon that is affected by splicing.

    Prints a usage line and exits with status 1 when fewer than five
    arguments follow the program name.
    """
    args = argv or sys.argv

    print("predictSpliceCount: version 1.2")

    if len(args) < 6:
        print('usage: python %s genome maxBorder uniquecountfile splicecountfile outfile' % args[0])
        sys.exit(1)

    genome, border, uniqueCountFile, spliceCountFile, outName = args[1:6]
    predictSpliceCount(genome, int(border), uniqueCountFile, spliceCountFile, outName)


def predictSpliceCount(genome, splicelead, uniquefilecount, splicefilecount, outfilename):
    """Predict the expected splice-read count for each gene and report it.

    genome          -- genome name passed to cistematic's Genome and used in
                       the (genome, gid) key for feature lookups.
    splicelead      -- nucleotides at the end of each exon affected by splicing.
    uniquefilecount -- path of the per-gene unique-read count file.
    splicefilecount -- path of the per-gene splice-read count file.
    outfilename     -- path of the tab-separated report to (over)write.

    For every gene id, in sorted order, one row is written to outfilename
    (tab-separated) and echoed to stdout (space-separated):
        gene id, symbol, p-value, expected splice count, observed splice count
    """
    hg = Genome(genome)
    gidDict = getGeneData(uniquefilecount, splicefilecount)

    outfile = open(outfilename, "w")
    try:
        for gid in sorted(gidDict):
            featureList = hg.getGeneFeatures((genome, gid))
            featuresizesum, splicearea = getStats(featureList, splicelead)

            # Fraction of the gene's total (feature + splice-affected) area that
            # unique reads can cover; featuresizesum >= 1, so this is > 0.
            fractionCoverage = featuresizesum / float(splicearea + featuresizesum)
            geneData = gidDict[gid]
            uniqueCount = geneData.uniqueCount
            # Scale the unique count up to the full area, then subtract the
            # unique reads to leave the predicted splice-read share.
            expectedSpliceCount = int(round(uniqueCount / fractionCoverage)) - uniqueCount

            # this p-value is based on the observed unique count, not the expected total count
            # nor the multi-read adjusted count.
            # It is 1 - (1 - p)**uniqueCount with p = splicelead / featuresizesum,
            # presumably the chance that at least one read lands in a
            # splicelead-sized window. p is clamped to [0, 1]: for genes shorter
            # than splicelead the raw base goes negative and the unclamped
            # result can fall outside [0, 1].
            missProbability = 1.0 - float(splicelead) / featuresizesum
            missProbability = min(1.0, max(0.0, missProbability))
            pvalue = 1 - pow(missProbability, uniqueCount)

            row = (gid, geneData.symbol, pvalue, expectedSpliceCount, geneData.spliceCount)
            print('%s %s %f %d %d' % row)
            outfile.write('%s\t%s\t%f\t%d\t%d\n' % row)
    finally:
        # Close the report even if a feature lookup or a write fails mid-way.
        outfile.close()


def _readFields(filename):
    """Return the whitespace-split fields of every non-blank line of filename."""
    infile = open(filename)
    try:
        return [line.split() for line in infile if line.strip()]
    finally:
        # Always release the handle, even when reading fails.
        infile.close()


def getGeneData(uniquefilecount, splicefilecount):
    """Load per-gene counts from a unique-count file and a splice-count file.

    Each non-blank line is split on whitespace. fields[0] is the gene id and
    fields[2] is an integer count. In the unique-count file, fields[1] is
    stored as the gene symbol. Blank lines are skipped.

    Returns a dict mapping gene id -> GeneSymbolAndCount. Genes that appear
    only in the splice file get a record with uniqueCount 0.

    Raises IOError if a file cannot be opened. Raises IndexError/ValueError
    on a line with fewer than three fields or a non-integer count.
    """
    gidDict = {}
    for fields in _readFields(uniquefilecount):
        gidDict[fields[0]] = GeneSymbolAndCount(symbol=fields[1], uniqueCount=int(fields[2]))

    for fields in _readFields(splicefilecount):
        gid = fields[0]
        if gid not in gidDict:
            # Previously a KeyError. Assumes the splice file uses the same
            # "gid symbol count" layout as the unique file -- TODO confirm.
            gidDict[gid] = GeneSymbolAndCount(symbol=fields[1])
        gidDict[gid].spliceCount = int(fields[2])

    return gidDict


def getStats(featureList, splicelead):
    """Summarize a gene's features for splice-count prediction.

    featureList -- iterable of (ftype, chrom, start, stop, sense) tuples;
                   start/stop are treated as inclusive coordinates.
    splicelead  -- nucleotides at the end of each exon affected by splicing.

    Returns (featuresizesum, splicearea):
      featuresizesum -- total length of the distinct (start, stop) intervals,
                        floored at 1 so callers can divide by it safely.
      splicearea     -- splicelead * (distinct intervals - 1), i.e. one
                        splice-affected stretch per junction. It is forced to 0
                        when smaller than splicelead (e.g. at most one interval).
    """
    # Deduplicate on coordinates only, so an interval repeated in featureList
    # (e.g. under different ftypes) is counted once. A set gives O(1) membership
    # tests instead of the O(n) scan of a list.
    # NOTE(review): overlapping-but-not-identical intervals are still summed twice.
    seenIntervals = set()
    featuresizesum = 0
    for (_ftype, _chrom, start, stop, _sense) in featureList:
        interval = (start, stop)
        if interval not in seenIntervals:
            seenIntervals.add(interval)
            featuresizesum += stop - start + 1

    # Avoid a zero denominator downstream for genes with no features.
    if featuresizesum < 1:
        featuresizesum = 1

    splicearea = (len(seenIntervals) - 1) * splicelead
    if splicearea < splicelead:
        splicearea = 0

    return featuresizesum, splicearea


if __name__ == "__main__":
    # Script entry: main() falls back to sys.argv when given no argument vector.
    main()