##################################
#                                #
# Last modified 06/08/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _attribute_value(attributes, preferred, fallback):
    """Return the quoted value of a GTF attribute.

    attributes -- the raw attribute column (9th field) of a GTF record
    preferred  -- attribute name to use when present (e.g. 'gene_name')
    fallback   -- attribute name to use otherwise (e.g. 'gene_id')

    Raises ValueError when neither attribute is present.
    """
    for key in (preferred, fallback):
        # Match the name together with the opening quote so that the text
        # we test for is exactly the text we split on.
        marker = key + ' "'
        if marker in attributes:
            return attributes.split(marker)[1].split('";')[0]
    raise ValueError('neither %s nor %s found in attributes: %s'
                     % (preferred, fallback, attributes))


def _read_transcripts(gtf):
    """Collect the 'transcript' records of a GTF file, grouped by gene.

    Returns a (GeneDict, count) pair: GeneDict maps each gene name (or
    gene_id when no gene_name attribute exists) to a list of
    (transcript, chromosome, left, right, strand) tuples, and count is the
    total number of transcript records read.

    Raises ValueError for a transcript record with fewer than nine columns.
    """
    GeneDict = {}
    count = 0
    with open(gtf) as lineslist:
        for lineno, line in enumerate(lineslist, 1):
            # Comment/header lines carry no record.
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Blank lines and every feature type other than 'transcript'
            # are ignored.
            if len(fields) < 3 or fields[2] != 'transcript':
                continue
            if len(fields) < 9:
                raise ValueError(
                    '%s, line %d: transcript record has %d columns, expected 9'
                    % (gtf, lineno, len(fields)))
            chrom = fields[0]
            left = int(fields[3])
            right = int(fields[4])
            strand = fields[6]
            gene = _attribute_value(fields[8], 'gene_name', 'gene_id')
            transcript = _attribute_value(
                fields[8], 'transcript_name', 'transcript_id')
            GeneDict.setdefault(gene, []).append(
                (transcript, chrom, left, right, strand))
            count += 1
            if count % 10000 == 0:
                print('%d transcripts processed' % count)
    return GeneDict, count


def _five_prime_ends(transcripts):
    """Return the 5' end coordinate of every stranded transcript.

    The 5' end is the left coordinate on the '+' strand and the right
    coordinate on the '-' strand; transcripts with any other strand value
    contribute nothing.
    """
    ends = []
    for (transcript, chrom, left, right, strand) in transcripts:
        if strand == '+':
            ends.append(left)
        elif strand == '-':
            ends.append(right)
    return ends


def run():
    """Count total and distinct transcript 5' ends per gene in a GTF file.

    Command line: gtf outputfilename [-expressed minFPKM transcript.expr]

    Reads the 'transcript' records of the GTF file named by sys.argv[1] and
    writes a tab-separated table to the file named by sys.argv[2], with a
    header line followed by one line per gene (sorted by gene name):
    gene, number of 5' ends, number of distinct 5' ends.

    Prints a usage message and exits with status 1 when the two required
    arguments are not both given.
    """
    # Both the GTF path and the output path are required.
    if len(sys.argv) < 3:
        print('usage: python %s gtf outputfilename [-expressed minFPKM transcript.expr]' % sys.argv[0])
        sys.exit(1)

    gtf = sys.argv[1]
    outfilename = sys.argv[2]
    # NOTE: the -expressed option shown in the usage text is still accepted
    # on the command line, but no expression filtering is implemented, so
    # its arguments have no effect on the output.

    GeneDict, count = _read_transcripts(gtf)

    print('found %d genes' % len(GeneDict))
    print('found %d transcripts' % count)

    with open(outfilename, 'w') as outfile:
        outfile.write('gene\t5-ends\tDistinct_5-ends\n')
        for gene in sorted(GeneDict):
            ends = _five_prime_ends(GeneDict[gene])
            outfile.write('%s\t%d\t%d\n' % (gene, len(ends), len(set(ends))))
   
# Run only when executed as a script, so that importing this module does
# not parse the command line or exit the interpreter.
if __name__ == '__main__':
    run()
