try:
    # psyco is an optional JIT accelerator: run without it if it cannot be
    # imported or enabled.  Catch Exception rather than using a bare except so
    # KeyboardInterrupt / SystemExit are not swallowed (Python 2.5+).
    import psyco
    psyco.full()
except Exception:
    print('psyco not running')

import sys
from cistematic.genomes import Genome


def main(argv=None):
    """Parse command-line arguments and run predictSpliceCount.

    Usage: python <script> genome maxBorder uniquecountfile splicecountfile outfile
    where maxBorder is the number of nucleotides at the end of each exon that
    is affected by splicing.  Falls back to sys.argv when *argv* is empty or
    None, and exits with status 1 after printing the usage line when fewer
    than five arguments follow the program name.
    """
    argv = argv or sys.argv

    print("predictSpliceCount: version 1.2")

    if len(argv) <= 5:
        print('usage: python %s genome maxBorder uniquecountfile splicecountfile outfile' % argv[0])
        sys.exit(1)

    genome, maxBorder, uniquefilecount, splicefilecount, outfilename = argv[1:6]
    # maxBorder: number of nucleotides at the end of each exon affected by splicing
    predictSpliceCount(genome, int(maxBorder), uniquefilecount, splicefilecount, outfilename)


def predictSpliceCount(genome, splicelead, uniquefilecount, splicefilecount, outfilename):
    """Compare expected and observed splice-read counts for each gene.

    For every gene id (gid) in *uniquefilecount*, the features returned by
    ``Genome.getGeneFeatures((genome, gid))`` are reduced to their distinct
    (start, stop) spans.  With ``k`` distinct spans of total inclusive length
    ``L`` (forced to at least 1) and ``n`` unique reads:

    * splicearea = (k - 1) * splicelead, or 0 when that is below splicelead;
    * expectedSpliceCount = round(n * (splicearea + L) / L) - n;
    * pvalue = 1 - (1 - p) ** n, with p = splicelead / L clamped to [0, 1].

    One line per gene is echoed to stdout (space separated) and written to
    *outfilename* (tab separated): gid, symbol, pvalue, expectedSpliceCount,
    observed splice count.  Each gid is emitted once, in sorted order.

    Args:
        genome: genome name handed to ``Genome`` and used in the gene key.
        splicelead: number of nucleotides at the end of each exon that is
            affected by splicing.
        uniquefilecount: path to a whitespace-delimited file whose columns
            are gid, gene symbol, unique read count.
        splicefilecount: path to a whitespace-delimited file with the gid in
            column 1 and the observed splice count in column 3.
        outfilename: path of the tab-separated output file (overwritten).

    Raises:
        ValueError: a non-blank input line has fewer than three columns.
        KeyError: a gid from the unique-count file is missing from the
            splice-count file.
    """
    hg = Genome(genome)

    symbolDict = {}
    uniqueCountDict = {}
    for fields in _readCountFile(uniquefilecount):
        symbolDict[fields[0]] = fields[1]
        uniqueCountDict[fields[0]] = int(fields[2])

    spliceCountDict = {}
    for fields in _readCountFile(splicefilecount):
        spliceCountDict[fields[0]] = int(fields[2])

    outfile = open(outfilename, 'w')
    try:
        # iterate the dict keys so a gid repeated in the input is reported once
        for gid in sorted(uniqueCountDict):
            symbol = symbolDict[gid]
            uniqueCount = uniqueCountDict[gid]
            spliceCount = spliceCountDict[gid]

            numSpans, featuresizesum = _distinctSpanStats(hg.getGeneFeatures((genome, gid)))
            # avoid dividing by zero below for genes without features
            if featuresizesum < 1:
                featuresizesum = 1

            splicearea = (numSpans - 1) * splicelead
            if splicearea < splicelead:
                splicearea = 0

            fractionCoverage = featuresizesum / float(splicearea + featuresizesum)
            expectedSpliceCount = int(round(uniqueCount / fractionCoverage)) - uniqueCount

            # this p-value is based on the observed unique count, not the expected
            # total count nor the multi-read adjusted count.  The per-read
            # probability is clamped so that genes shorter than splicelead cannot
            # push the p-value outside [0, 1].
            hitProbability = min(1.0, max(0.0, float(splicelead) / featuresizesum))
            pvalue = 1 - pow(1 - hitProbability, uniqueCount)

            print('%s %s %f %d %d' % (gid, symbol, pvalue, expectedSpliceCount, spliceCount))
            outfile.write('%s\t%s\t%f\t%d\t%d\n' % (gid, symbol, pvalue, expectedSpliceCount, spliceCount))
    finally:
        outfile.close()


def _readCountFile(filename):
    """Return the whitespace-split fields of every non-blank line of *filename*.

    Raises ValueError, naming the file and 1-based line number, for a
    non-blank line with fewer than three columns.
    """
    rows = []
    infile = open(filename)
    try:
        for lineIndex, line in enumerate(infile):
            fields = line.split()
            if not fields:
                # tolerate blank lines, e.g. a trailing empty line
                continue
            if len(fields) < 3:
                raise ValueError('%s line %d: expected at least 3 columns, found %d'
                                 % (filename, lineIndex + 1, len(fields)))
            rows.append(fields)
    finally:
        infile.close()
    return rows


def _distinctSpanStats(featureList):
    """Return (count, total inclusive length) of the distinct (start, stop) spans.

    Each feature is unpacked as (ftype, chrom, start, stop, sense); features
    sharing the same start and stop are counted once, and each span
    contributes stop - start + 1 to the total.
    """
    spans = set()
    for (ftype, chrom, start, stop, sense) in featureList:
        spans.add((start, stop))
    totalLength = 0
    for (start, stop) in spans:
        totalLength += stop - start + 1
    return len(spans), totalLength


if __name__ == "__main__":
    # main() reads sys.argv itself when called without an argument list
    main()