##################################
#                                #
# Last modified 07/22/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import numpy
from sets import Set

def chrCV(CoverageDict,ProcessedGeneDict,currentChr,minCoverage):
    """Return coefficients of variation (std / mean) of coverage for one chromosome.

    For every gene in ``ProcessedGeneDict[currentChr]`` the per-base coverage is
    looked up in ``CoverageDict``; positions absent from it count as 0.

    * A gene with a ``'wholetranscript'`` entry contributes one CV, computed over
      the concatenated bases of all of its intervals.
    * A gene with an ``'exons'`` entry contributes one CV per exon.

    A feature is skipped when it covers no bases, or when its mean coverage is
    not strictly greater than ``minCoverage``. With ``minCoverage >= 0`` the
    second rule also prevents dividing by a zero mean.

    Args:
        CoverageDict: dict mapping base position (int) to coverage score.
        ProcessedGeneDict: dict chromosome -> gene ID -> dict holding a
            ``'wholetranscript'`` and/or ``'exons'`` list of
            ``(left, right, strand)`` tuples.
        currentChr: chromosome key into ``ProcessedGeneDict``.
        minCoverage: threshold the mean coverage must exceed.

    Returns:
        list of CV values, one per retained feature.
    """
    def _featureCV(intervals):
        # Concatenate coverage over all intervals; `right` is excluded by range().
        # NOTE(review): GTF coordinates are conventionally 1-based inclusive and
        # bedGraph 0-based half-open; confirm range(left, right) is the intended
        # conversion (it may be off by one base).
        coverage = [CoverageDict.get(i, 0)
                    for (left, right, strand) in intervals
                    for i in range(left, right)]
        if not coverage:
            # Empty feature: numpy.mean([]) would be nan and poison the result.
            return None
        meanCov = numpy.mean(coverage)
        if meanCov <= minCoverage:
            return None
        return numpy.std(coverage) / meanCov

    CVs = []
    for gene in ProcessedGeneDict[currentChr].values():
        # Each feature is a list of intervals: the whole merged transcript is one
        # feature, every individual exon is its own single-interval feature.
        features = []
        if 'wholetranscript' in gene:
            features.append(gene['wholetranscript'])
        for exon in gene.get('exons', []):
            features.append([exon])
        for intervals in features:
            cv = _featureCV(intervals)
            if cv is not None:
                CVs.append(cv)

    return CVs

def _option_value(flag, convert=str):
    """Return the command-line token following ``flag``, converted with ``convert``."""
    position = sys.argv.index(flag) + 1
    if position >= len(sys.argv):
        sys.exit('missing value for option %s' % flag)
    return convert(sys.argv[position])


def _read_gene_models(gtf, BioType=None, GeneType=None):
    """Parse exon lines of a GTF file into chromosome -> gene -> transcript -> exons.

    Args:
        gtf: path of the GTF file.
        BioType: if not None, keep only lines whose source column (field 2) equals it.
        GeneType: if not None, keep only lines whose ``gene_type`` attribute equals it.

    Returns:
        dict chrom -> gene_id -> transcript_id -> list of (left, right, strand).
    """
    GeneDict = {}
    with open(gtf) as gtffile:
        for line in gtffile:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Skip blank/truncated lines instead of crashing on fields[2].
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            if BioType is not None and fields[1] != BioType:
                continue
            if GeneType is not None:
                gene_type = fields[8].split('gene_type "')[1].split('";')[0]
                if gene_type != GeneType:
                    continue
            chrom = fields[0]
            strand = fields[6]
            left = int(fields[3])
            right = int(fields[4])
            geneID = fields[8].split('gene_id "')[1].split('";')[0]
            transcriptID = fields[8].split('transcript_id "')[1].split('";')[0]
            transcripts = GeneDict.setdefault(chrom, {}).setdefault(geneID, {})
            transcripts.setdefault(transcriptID, []).append((left, right, strand))
    return GeneDict


def _trim_terminal_exons(exons, no3Exon, no5Exon):
    """Return ``exons`` (in transcript order) without the last and/or first exon."""
    kept = []
    for i, exon in enumerate(exons, 1):
        if no3Exon and i == len(exons):
            continue
        if no5Exon and i == 1:
            continue
        kept.append(exon)
    return kept


def _process_gene_models(GeneDict, doSingleModel, no3Exon=False, no5Exon=False,
                         minExonLength=None, maxExonLength=None,
                         minTranscriptLength=None, maxTranscriptLength=None):
    """Turn parsed gene models into the features whose coverage CV is measured.

    Exons of every transcript are sorted by position (and reversed for '-'
    strand transcripts, so index 1 is the 5' exon); this sorts GeneDict in place.
    With ``doSingleModel`` only single-transcript genes are kept and their exons
    are stored together under ``'wholetranscript'``, filtered by total length.
    Otherwise the distinct exons of all transcripts are stored under ``'exons'``,
    filtered by individual exon length. Length bounds of None are not applied.

    Returns:
        dict chrom -> gene_id -> {'wholetranscript' or 'exons': [...]}
        (a gene's dict is empty when it was filtered out).
    """
    ProcessedGeneDict = {}
    for chrom, genes in GeneDict.items():
        ProcessedGeneDict[chrom] = {}
        for geneID, transcripts in genes.items():
            ProcessedGeneDict[chrom][geneID] = {}
            for exons in transcripts.values():
                exons.sort()
                if exons[0][2] == '-':
                    exons.reverse()
            if doSingleModel:
                if len(transcripts) != 1:
                    continue
                wholetranscript = []
                for exons in transcripts.values():
                    wholetranscript.extend(_trim_terminal_exons(exons, no3Exon, no5Exon))
                TotalLength = sum(math.fabs(right - left)
                                  for (left, right, strand) in wholetranscript)
                if maxTranscriptLength is not None and TotalLength > maxTranscriptLength:
                    continue
                if minTranscriptLength is not None and TotalLength < minTranscriptLength:
                    continue
                ProcessedGeneDict[chrom][geneID]['wholetranscript'] = wholetranscript
            else:
                # A set de-duplicates exons shared between transcripts.
                uniqueExons = set()
                for exons in transcripts.values():
                    for (left, right, strand) in _trim_terminal_exons(exons, no3Exon, no5Exon):
                        length = math.fabs(right - left)
                        if maxExonLength is not None and length > maxExonLength:
                            continue
                        if minExonLength is not None and length < minExonLength:
                            continue
                        uniqueExons.add((left, right, strand))
                ProcessedGeneDict[chrom][geneID]['exons'] = list(uniqueExons)
    return ProcessedGeneDict


def _collect_cvs(wig, ProcessedGeneDict, minCoverage):
    """Stream a bedGraph-style coverage file and return the CVs of all features.

    Data lines are ``chrom start end score`` separated by whitespace. Coverage is
    buffered for one chromosome at a time; when the chromosome name changes,
    chrCV() is run on the finished chromosome and the buffer is discarded.
    NOTE(review): assumes lines are grouped by chromosome; a chromosome that
    reappears later in the file is evaluated again with only its later coverage.
    """
    CVlist = []
    currentChr = ''
    CoverageDict = {}
    with open(wig) as wigfile:
        for line in wigfile:
            if line.startswith(('track', 'browser', '#')):
                continue
            fields = line.split()
            if not fields:
                # Blank line (e.g. trailing newline): nothing to record.
                continue
            chrom = fields[0]
            if chrom != currentChr:
                if currentChr != '' and currentChr in ProcessedGeneDict:
                    CVlist.extend(chrCV(CoverageDict, ProcessedGeneDict, currentChr, minCoverage))
                currentChr = chrom
                CoverageDict = {}
                print(currentChr)
            left = int(fields[1])
            right = int(fields[2])
            score = float(fields[3])
            for j in range(left, right):
                CoverageDict[j] = score

    # Flush the last chromosome of the file.
    if currentChr in ProcessedGeneDict:
        CVlist.extend(chrCV(CoverageDict, ProcessedGeneDict, currentChr, minCoverage))
    return CVlist


def run():
    """Command-line entry point: ``gtf wig outfile [options]``.

    Writes one line ``<wig>\\t<mean CV>`` to ``outfile`` (and stdout), where the
    mean is taken over the coverage CVs of all retained annotation features.
    """
    if len(sys.argv) < 4:
        print('usage: python %s gtf wig outfile [-field1 biotype] [-gene_type genetype] [-wholeSingleModelTranscripts] [-minExonLength bp] [-maxExonLength bp] [-minCoverage Nx] [-no3exon] [-no5exon] [-minTranscriptLength bp] [-maxTranscriptLength bp]' % sys.argv[0])
        print('Note: the [-maxTranscriptLength] and [-minTranscriptLength] options only work together with the [-wholeSingleModelTranscripts] option')
        sys.exit(1)

    gtf = sys.argv[1]
    wig = sys.argv[2]
    outputfilename = sys.argv[3]

    BioType = None
    if '-field1' in sys.argv:
        BioType = _option_value('-field1')
        print('will only consider %s genes' % BioType)

    GeneType = None
    if '-gene_type' in sys.argv:
        GeneType = _option_value('-gene_type')
        print('will only consider %s genes' % GeneType)

    maxExonLength = None
    if '-maxExonLength' in sys.argv:
        maxExonLength = _option_value('-maxExonLength', int)
        print('will only consider exons shorter than %s' % maxExonLength)

    minExonLength = None
    if '-minExonLength' in sys.argv:
        minExonLength = _option_value('-minExonLength', int)
        print('will only consider exons longer than %s' % minExonLength)

    minCoverage = 0
    if '-minCoverage' in sys.argv:
        minCoverage = _option_value('-minCoverage', float)
        print('will only consider exons with coverage greater than %s' % minCoverage)

    # Transcript-length bounds only apply in whole-transcript mode.
    doSingleModel = '-wholeSingleModelTranscripts' in sys.argv
    maxTranscriptLength = None
    minTranscriptLength = None
    if doSingleModel:
        print('will merge the exons of single-transcript model genes')
        if '-maxTranscriptLength' in sys.argv:
            maxTranscriptLength = _option_value('-maxTranscriptLength', int)
            print('will only consider Transcripts shorter than %s' % maxTranscriptLength)
        if '-minTranscriptLength' in sys.argv:
            minTranscriptLength = _option_value('-minTranscriptLength', int)
            print('will only consider Transcripts longer than %s' % minTranscriptLength)

    no3Exon = '-no3exon' in sys.argv
    if no3Exon:
        print('will not consider the last exons of transcripts')

    no5Exon = '-no5exon' in sys.argv
    if no5Exon:
        print('will not consider the first exons of transcripts')

    GeneDict = _read_gene_models(gtf, BioType, GeneType)
    print('finished inputting annotation')

    ProcessedGeneDict = _process_gene_models(
        GeneDict, doSingleModel, no3Exon, no5Exon,
        minExonLength, maxExonLength, minTranscriptLength, maxTranscriptLength)
    print('finished processing annotation')

    CVlist = _collect_cvs(wig, ProcessedGeneDict, minCoverage)

    # numpy.mean([]) also yields nan, but with a RuntimeWarning; be explicit.
    meanCV = numpy.mean(CVlist) if CVlist else float('nan')
    outline = wig + '\t' + str(meanCV)
    print(outline)
    with open(outputfilename, 'w') as outfile:
        outfile.write(outline + '\n')

# Run only when executed as a script, so the module can be imported without side effects.
if __name__ == '__main__':
    run()
