##################################
#                                #
# Last modified 11/29/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _gtf_attribute(attributes, key, default=None):
    """Return the quoted value of `key` from a GTF attribute column (column 9).

    The value is the text between the first occurrence of `key "` and the
    following `";`. If `key` does not occur, `default` is returned; when no
    default is supplied the attribute is treated as required and ValueError
    is raised.
    """
    marker = key + ' "'
    if marker not in attributes:
        if default is None:
            raise ValueError('GTF exon line is missing the %s attribute' % key)
        return default
    return attributes.split(marker, 1)[1].split('";', 1)[0]


def _read_transcript_exons(path):
    """Group the exon records of a GTF file by transcript.

    Returns a dict mapping (geneID, geneName, transcriptID, transcriptName)
    to a list of (chrom, left, right, strand) tuples, one per exon line, with
    left/right taken from GTF columns 4 and 5. Comment lines, lines with
    fewer than nine tab-separated columns (e.g. blank lines) and non-exon
    features are skipped. A missing gene_name / transcript_name falls back
    to gene_id / transcript_id respectively.
    """
    transcripts = {}
    with open(path) as gtf:
        for line in gtf:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            attributes = fields[8]
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            transcript_name = _gtf_attribute(attributes, 'transcript_name', transcript_id)
            gene_id = _gtf_attribute(attributes, 'gene_id')
            # Presence of gene_name is checked on gene_name itself, independent
            # of whether transcript_name is present.
            gene_name = _gtf_attribute(attributes, 'gene_name', gene_id)
            key = (gene_id, gene_name, transcript_id, transcript_name)
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            transcripts.setdefault(key, []).append(exon)
    return transcripts


def _collect_splice_sites(transcripts, n, min_first_exon_length):
    """Map each splice site to the transcripts whose n-th exon ends there.

    For every transcript with more than `n` exons, exons are sorted by
    coordinate and the n-th exon in transcript orientation is selected: the
    n-th from the left on the '+' strand (site = its right coordinate) or the
    n-th from the right on the '-' strand (site = its left coordinate).
    Transcripts whose selected exon has right - left < min_first_exon_length,
    or whose strand is neither '+' nor '-', are skipped.

    Returns a dict mapping (chrom, pos, strand) to a list of transcript keys.
    """
    sites = {}
    for key, exons in transcripts.items():
        if len(exons) <= n:
            continue
        exons = sorted(exons)
        chrom, strand = exons[0][0], exons[0][3]
        if strand == '+':
            exon = exons[n - 1]
            pos = exon[2]
        elif strand == '-':
            exon = exons[-n]
            pos = exon[1]
        else:
            # Unstranded transcripts have no orientation, hence no well-defined
            # n-th exon; skip them instead of reporting an undefined position.
            continue
        if exon[2] - exon[1] < min_first_exon_length:
            continue
        sites.setdefault((chrom, pos, strand), []).append(key)
    return sites


def _write_splice_sites(path, sites):
    """Write splice sites as a tab-separated table, one row per site.

    Rows are sorted by (chrom, pos, strand). The transcripts sharing a site
    are sorted and their gene IDs, gene names, transcript IDs and transcript
    names are each comma-joined into one column.
    """
    with open(path, 'w') as out:
        out.write('#chr\tpos\tstrand\tgeneID\tgeneName\ttranscriptID\ttranscriptName' + '\n')
        for site in sorted(sites):
            chrom, pos, strand = site
            keys = sorted(sites[site])
            columns = [chrom, str(pos), strand]
            columns.extend([','.join([k[i] for k in keys]) for i in range(4)])
            out.write('\t'.join(columns) + '\n')


def run():
    """Command-line entry point: report the N-th splice site of each transcript.

    Usage: python <script> gtf outputfilename [-spliceNumber N] [-firstExonLength bp]

    Reads the exon records of the GTF file and, for every transcript with
    more than N exons (N defaults to 1), reports the coordinate where its
    N-th exon in transcript orientation ends. With -firstExonLength bp,
    transcripts whose selected exon has right - left < bp are skipped; with
    1-based inclusive GTF coordinates this keeps only exons longer than bp.
    Transcripts that share a site are merged into one output row.

    Prints usage and exits with status 1 when fewer than two command-line
    arguments are given.
    """
    if len(sys.argv) < 3:
        print('usage: python %s gtf outputfilename [-spliceNumber N] [-firstExonLength bp]' % sys.argv[0])
        print('\tuse the -spliceNumber option if you want the n-th (for transcripts with n+1 exons only) exons instead of the first')
        sys.exit(1)

    inputfilename = sys.argv[1]
    outfilename = sys.argv[2]

    # NOTE(review): N is expected to be >= 1; N <= 0 selects meaningless exons.
    N = 1
    if '-spliceNumber' in sys.argv:
        N = int(sys.argv[sys.argv.index('-spliceNumber') + 1])
        print('will output the splice site number %d' % N)

    FirstExonLength = 0
    if '-firstExonLength' in sys.argv:
        FirstExonLength = int(sys.argv[sys.argv.index('-firstExonLength') + 1])
        print('Will only consider first exons with longer than %d bp' % FirstExonLength)

    transcripts = _read_transcript_exons(inputfilename)
    sites = _collect_splice_sites(transcripts, N, FirstExonLength)
    _write_splice_sites(outfilename, sites)
   
# Run the CLI only when executed as a script, so importing this module
# does not parse sys.argv or call sys.exit.
if __name__ == '__main__':
    run()
