##################################
#                                #
# Last modified 12/23/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _read_transcripts(gff_lines, feature_type):
    """Group GFF features of one type by the content of their ninth column.

    gff_lines    -- iterable of text lines from a GFF file
    feature_type -- value required in the third column (e.g. 'exon' or 'CDS')

    Returns a dict mapping the raw ninth-column string to a list of
    (chromosome, start, stop, strand) tuples, in file order.  Coordinates
    are kept as the strings read from the file.
    """
    transcripts = {}
    for line in gff_lines:
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        # Blank lines and other non-feature lines do not have the nine
        # tab-separated columns indexed below; skip them instead of
        # failing with an IndexError.
        if len(fields) < 9:
            continue
        if fields[2] != feature_type:
            continue
        record = (fields[0], fields[3], fields[4], fields[6])
        transcripts.setdefault(fields[8], []).append(record)
    return transcripts


def _splice_junctions(transcripts):
    """Return one (chromosome, left, right, strand) tuple per adjacent feature pair.

    For each transcript, 'left' is the stop of a feature and 'right' is the
    start of the feature that follows it in file order.  Chromosome and
    strand are taken from the transcript's first feature.
    """
    junctions = []
    for features in transcripts.values():
        chrom = features[0][0]
        strand = features[0][3]
        for upstream, downstream in zip(features, features[1:]):
            junctions.append((chrom, upstream[2], downstream[1], strand))
    return junctions


def run():
    """Write the unique junctions of a GFF file to a tab-separated file.

    Command line: python <script> gff outputfilename [-CDS]

    Reads 'exon' features (or 'CDS' features when -CDS is given), pairs
    consecutive features of each transcript into junctions, removes
    duplicates, sorts them, and writes one line per junction as
    chromosome, left-1, right-1, strand.  The total and unique junction
    counts are printed to standard output.

    Exits with status 1 after printing the usage message when fewer than
    two positional arguments are supplied.
    """
    # Both the input and the output file name are required.
    if len(sys.argv) < 3:
        print('usage: python %s gff outputfilename [-CDS]' % sys.argv[0])
        sys.exit(1)

    inputfilename = sys.argv[1]
    outfilename = sys.argv[2]
    feature_type = 'CDS' if '-CDS' in sys.argv else 'exon'

    with open(inputfilename) as gff:
        transcripts = _read_transcripts(gff, feature_type)

    junctions = _splice_junctions(transcripts)
    print('total junctions: %d' % len(junctions))

    # Coordinates are still strings at this point, so the sort is
    # lexicographic on the coordinate columns; kept that way so the output
    # order does not change.
    junctions = sorted(set(junctions))
    print('unique junctions: %d' % len(junctions))

    with open(outfilename, 'w') as outfile:
        for chrom, left, right, strand in junctions:
            columns = (chrom, str(int(left) - 1), str(int(right) - 1), strand)
            outfile.write('\t'.join(columns) + '\n')
   
# Script entry point.
# NOTE(review): this call is unconditional, so it also executes when the file
# is imported as a module, not only when it is run directly.
run()
