##################################
#                                #
# Last modified 12/29/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _read_peaks(peaks_path, chrom_col, mode, pos_col):
    """Load peak positions from a tab-delimited file.

    Lines starting with '#' or 'track type', and blank lines, are skipped.

    peaks_path -- path of the peak file
    chrom_col  -- 0-based column holding the chromosome name
    mode       -- 'narrowpeak' (position = column 2 + column 10, i.e. start
                  plus summit offset), 'bed' (position = integer midpoint of
                  the two columns following chrom_col), or anything else to
                  read the position from pos_col
    pos_col    -- 0-based column holding the position (ignored for the
                  'narrowpeak' and 'bed' modes)

    Returns a dict mapping (chromosome, position) -> 'IG'; identical
    coordinates collapse into a single entry.
    """
    labels = {}
    with open(peaks_path) as handle:
        for line in handle:
            if line.startswith(('#', 'track type')):
                continue
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[chrom_col]
            if mode == 'narrowpeak':
                position = int(fields[1]) + int(fields[9])
            elif mode == 'bed':
                # Floor division keeps the integer midpoint on Python 2 and 3.
                position = (int(fields[chrom_col + 1]) + int(fields[chrom_col + 2])) // 2
            else:
                position = int(fields[pos_col])
            labels[(chrom, position)] = 'IG'
    return labels


def _read_transcripts(gtf_path):
    """Collect the exon records of every transcript in a GTF file.

    Only lines whose third column is 'exon' are used; comment lines and
    lines with fewer than nine columns are skipped.

    Returns a dict mapping transcript_id -> list of
    (chromosome, left, right, strand) tuples in file order.
    """
    transcripts = {}
    with open(gtf_path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            transcript_id = fields[8].split('transcript_id "')[1].split('";')[0]
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            transcripts.setdefault(transcript_id, []).append(exon)
    return transcripts


def _orient_transcript(exons):
    """Return (exons in transcription order, TSS coordinate).

    Exons are sorted ascending; for a '-' transcript the list is reversed and
    the TSS is the right end of the last sorted exon, for a '+' transcript the
    TSS is the left end of the first exon. For any other strand value the
    exons stay in ascending order and the TSS is None.
    """
    ordered = sorted(exons)
    strand = ordered[0][3]
    if strand == '-':
        tss = ordered[-1][2]
        ordered.reverse()
    elif strand == '+':
        tss = ordered[0][1]
    else:
        tss = None
    return ordered, tss


def run():
    """Classify peaks relative to a GTF annotation and write a summary table.

    Command-line arguments (sys.argv[1:]):
        gtf peaks chrFieldID peakFieldID TSSradius outputfilename

    Every distinct (chromosome, position) peak starts as intergenic ('IG') and
    is relabelled in three passes over the transcripts:

      1. 'TSS' if it lies in [TSS - radius, TSS + radius) of any transcript;
      2. 'E1' / 'EL' / 'E' if it lies in [left, right) of the first / last /
         any other exon (in transcription order; single-exon transcripts
         always give 'E'); peaks already labelled 'TSS' are left alone;
      3. 'I1' / 'IL' / 'I' if it is still 'IG' and lies between two
         consecutive exons (first / last / other intron; a two-exon
         transcript gives 'I').

    GTF coordinates are compared to peak positions as-is, without any
    coordinate-system conversion.

    The output file has a '#Class\\tNumber\\tFraction' header followed by one
    row per class. With no peaks at all every fraction is written as 0.0.

    Exits with status 1 after printing usage when arguments are missing.
    Raises ValueError for non-integer numeric arguments or coordinates.
    """
    # Six positional arguments are read below, so argv needs seven entries.
    if len(sys.argv) < 7:
        print('usage: python %s gtf peaks chrFieldID peakFieldID TSSradius outputfilename' % sys.argv[0])
        print('\tNote: use "narrowpeak" for the peakFieldID parameter if peaks are in narrowpeak format')
        print('\tNote: use "bed" for the peakFieldID parameter if not single peak coordinate is specified; the middle of the region will be used then')
        sys.exit(1)

    gtf_path = sys.argv[1]
    peaks_path = sys.argv[2]
    chrom_col = int(sys.argv[3])
    peak_mode = sys.argv[4]
    tss_radius = int(sys.argv[5])
    outfilename = sys.argv[6]

    pos_col = None
    if peak_mode not in ('narrowpeak', 'bed'):
        pos_col = int(peak_mode)

    peak_labels = _read_peaks(peaks_path, chrom_col, peak_mode, pos_col)
    transcripts = _read_transcripts(gtf_path)

    print('finished parsing GTF file')

    # Pass 1: orient each transcript once and mark peaks around its TSS.
    oriented = []
    for transcript_id in transcripts:
        exons, tss = _orient_transcript(transcripts[transcript_id])
        oriented.append(exons)
        if tss is None:
            # Unknown strand: no TSS can be derived, so nothing to mark.
            continue
        chrom = exons[0][0]
        for pos in range(tss - tss_radius, tss + tss_radius):
            if (chrom, pos) in peak_labels:
                peak_labels[(chrom, pos)] = 'TSS'

    print('finished parsing TSSs')

    # Pass 2: exonic peaks. Later transcripts overwrite earlier exon labels,
    # but a 'TSS' label is never replaced.
    for exons in oriented:
        exon_count = len(exons)
        for rank, (chrom, left, right, _strand) in enumerate(exons, 1):
            if exon_count > 1 and rank == 1:
                label = 'E1'
            elif exon_count > 1 and rank == exon_count:
                label = 'EL'
            else:
                label = 'E'
            for pos in range(left, right):
                key = (chrom, pos)
                if key in peak_labels and peak_labels[key] != 'TSS':
                    peak_labels[key] = label

    print('finished parsing exons')

    # Pass 3: intronic peaks; only peaks still labelled 'IG' are touched.
    transcript_total = len(oriented)
    for done, exons in enumerate(oriented, 1):
        if done % 1000 == 0:
            print(transcript_total - done)
        exon_count = len(exons)
        if exon_count == 1:
            continue
        chrom = exons[0][0]
        for rank in range(1, exon_count):
            upstream = exons[rank - 1]
            downstream = exons[rank]
            # Take the gap between the two exons in genomic order. For '-'
            # transcripts the exon list runs from high to low coordinates, so
            # range(upstream right, downstream left) would always be empty.
            start = min(upstream[2], downstream[2])
            stop = max(upstream[1], downstream[1])
            if exon_count > 2 and rank == 1:
                label = 'I1'
            elif exon_count > 2 and rank == exon_count - 1:
                label = 'IL'
            else:
                label = 'I'
            for pos in range(start, stop):
                key = (chrom, pos)
                if peak_labels.get(key) == 'IG':
                    peak_labels[key] = label

    print('finished parsing introns')

    counts = dict.fromkeys(('IG', 'I', 'E', 'TSS', 'I1', 'E1', 'IL', 'EL'), 0)
    for label in peak_labels.values():
        counts[label] += 1

    total = float(len(peak_labels))
    rows = [
        ('Intergenic', 'IG'),
        ('Intronic', 'I'),
        ('Exonic', 'E'),
        ('TSS_' + str(tss_radius) + '_bp', 'TSS'),
        ('Intronic_first', 'I1'),
        ('Exonic_first', 'E1'),
        ('Intronic_last', 'IL'),
        ('Exonic_last', 'EL'),
    ]

    with open(outfilename, 'w') as outfile:
        outfile.write('#Class\tNumber\tFraction\n')
        for name, label in rows:
            count = counts[label]
            # Avoid ZeroDivisionError when the peak file held no records.
            fraction = count / total if total else 0.0
            outfile.write(name + '\t' + str(count) + '\t' + str(fraction) + '\n')
   
# Run the pipeline only when executed as a script, so that importing this
# module does not immediately parse sys.argv and start processing files.
if __name__ == '__main__':
    run()
