##################################
#                                #
# Last modified 10/31/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set

def _parse_gtf_exons(lines):
    """Group the exon records of a GTF stream by gene and transcript.

    Parameters
    ----------
    lines : iterable of str
        Raw GTF lines. Comment lines (starting with '#'), blank or truncated
        lines (fewer than nine tab-separated columns) and non-'exon'
        features are skipped.

    Returns
    -------
    dict
        ``{gene_id: {transcript_id: [(chrom, left, right, strand), ...]}}``
        with the exons of each transcript in file order.

    Raises
    ------
    IndexError
        If an exon line lacks a ``gene_id "..."`` or ``transcript_id "..."``
        attribute.
    ValueError
        If the start/end columns are not integers.
    """
    genes = {}
    for line_number, line in enumerate(lines, 1):
        # Progress indicator for very large annotation files.
        if line_number % 100000 == 0:
            print(line_number)
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        # Guard the column lookups below against blank/truncated lines.
        if len(fields) < 9 or fields[2] != 'exon':
            continue
        attributes = fields[8]
        # Assumes the usual `key "value";` attribute syntax.
        gene_id = attributes.split('gene_id "')[1].split('";')[0]
        transcript_id = attributes.split('transcript_id "')[1].split('";')[0]
        # GTF columns: 0 seqname, 3 start, 4 end, 6 strand.
        exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
        genes.setdefault(gene_id, {}).setdefault(transcript_id, []).append(exon)
    return genes


def _count_shared_terminal_exons(genes):
    """Count multi-transcript genes whose transcripts agree on terminal exons.

    Start/End are taken in transcription orientation. For '+' transcripts the
    first exon is the leftmost one and its start is its left coordinate. For
    '-' transcripts the first exon is the rightmost one and its start is its
    right coordinate. Transcripts whose strand is neither '+' nor '-'
    contribute no coordinates, but their gene is still counted in the total.

    Parameters
    ----------
    genes : dict
        Output of ``_parse_gtf_exons``.

    Returns
    -------
    tuple of int
        ``(total, shared_first_starts, shared_first_ends, shared_end_starts,
        shared_end_ends)``. Here ``total`` is the number of genes with more
        than one transcript. Each ``shared_*`` count is the number of those
        genes where every contributing transcript has the same coordinate.
    """
    total = 0
    shared_first_starts = 0
    shared_first_ends = 0
    shared_end_starts = 0
    shared_end_ends = 0

    for transcripts in genes.values():
        # Single-transcript genes trivially "share" everything; skip them.
        if len(transcripts) == 1:
            continue
        first_starts = set()
        first_ends = set()
        end_starts = set()
        end_ends = set()
        for exons in transcripts.values():
            ordered = sorted(exons)
            leftmost, rightmost = ordered[0], ordered[-1]
            strand = leftmost[3]
            if strand == '+':
                first_starts.add(leftmost[1])
                first_ends.add(leftmost[2])
                end_starts.add(rightmost[1])
                end_ends.add(rightmost[2])
            elif strand == '-':
                first_starts.add(rightmost[2])
                first_ends.add(rightmost[1])
                end_starts.add(leftmost[2])
                end_ends.add(leftmost[1])
        total += 1
        # "Shared" means exactly one distinct coordinate across transcripts.
        if len(first_starts) == 1:
            shared_first_starts += 1
        if len(end_starts) == 1:
            shared_end_starts += 1
        if len(first_ends) == 1:
            shared_first_ends += 1
        if len(end_ends) == 1:
            shared_end_ends += 1

    return (total, shared_first_starts, shared_first_ends,
            shared_end_starts, shared_end_ends)


def run():
    """Command-line entry point: ``python <script> gtf outfilename``.

    Reads the exon records of the GTF file. For every gene with more than one
    transcript, it checks whether all transcripts share the start/end
    coordinate of their first and last exon (strand-aware). It then writes
    five ``label:<TAB>count`` lines to ``outfilename``.

    Exits with status 1 after printing a usage line if either argument is
    missing.
    """
    # Both sys.argv[1] and sys.argv[2] are read below, so require both.
    if len(sys.argv) < 3:
        print('usage: python %s gtf outfilename ' % sys.argv[0])
        sys.exit(1)

    gtf = sys.argv[1]
    outputfilename = sys.argv[2]

    with open(gtf) as gtf_file:
        genes = _parse_gtf_exons(gtf_file)

    (total, shared_first_starts, shared_first_ends,
     shared_end_starts, shared_end_ends) = _count_shared_terminal_exons(genes)

    rows = [
        ('Total:', total),
        ('SharedFirstExonsStarts:', shared_first_starts),
        ('SharedFirstExonsEnds:', shared_first_ends),
        ('SharedEndExonsStarts:', shared_end_starts),
        ('SharedEndExonsEnds:', shared_end_ends),
    ]
    with open(outputfilename, 'w') as outfile:
        outfile.write(''.join('%s\t%d\n' % row for row in rows))

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

