##################################
#                                #
# Last modified 01/29/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _gtf_attribute(attributes, key, default=None):
    """Return the quoted value of `key` in a GTF attribute column.

    `attributes` is the raw ninth GTF column; the value is the text between
    `key "` and the next `";`. Returns `default` when the key is absent.
    """
    marker = key + ' "'
    if marker not in attributes:
        return default
    return attributes.split(marker)[1].split('";')[0]


def _parse_gtf(inputfilename, feature):
    """Collect the `feature` records ('exon' or 'CDS') of a GTF file per gene.

    Returns a dict of the form
    {gene_id: {'name': ..., 'transcripts': {transcript_id:
    {'name': ..., 'exons': [(chrom, left, right, strand), ...]}}}}.

    Raises ValueError when a matching record has no gene_id/transcript_id.
    """
    genes = {}
    processed = 0
    with open(inputfilename) as infile:
        for line in infile:
            if line.startswith('#'):
                continue
            processed += 1
            if processed % 100000 == 0:
                print('%d lines processed' % processed)
            fields = line.split('\t')
            # Blank or truncated lines carry no usable record; skip them
            # instead of failing on a missing column.
            if len(fields) < 9 or fields[2] != feature:
                continue
            attributes = fields[8]
            gene_id = _gtf_attribute(attributes, 'gene_id')
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            if gene_id is None or transcript_id is None:
                raise ValueError(
                    'record without gene_id/transcript_id: %r' % line)
            if 'gene_name "' in attributes:
                gene_name = _gtf_attribute(attributes, 'gene_name')
                transcript_name = _gtf_attribute(
                    attributes, 'transcript_name', transcript_id)
            else:
                # Without a gene name, fall back to the IDs for both names.
                gene_name = gene_id
                transcript_name = transcript_id
            if gene_id not in genes:
                genes[gene_id] = {'name': gene_name, 'transcripts': {}}
            transcripts = genes[gene_id]['transcripts']
            if transcript_id not in transcripts:
                transcripts[transcript_id] = {
                    'name': transcript_name, 'exons': []}
            transcripts[transcript_id]['exons'].append(
                (fields[0], int(fields[3]), int(fields[4]), fields[6]))
    return genes


def _adjacent_pairs(exons):
    """Return the set of (position, next position) pairs of one transcript.

    All positions covered by the exons are sorted, and every pair of
    neighbours in that sorted list is emitted, so a pair is either two
    consecutive bases or the jump from one exon to the next.
    """
    positions = []
    for (_chrom, left, right, _strand) in exons:
        # NOTE(review): range() leaves out `right` itself, although GTF end
        # coordinates are inclusive -- kept as in the original; confirm intent.
        positions.extend(range(left, right))
    positions.sort()
    return set(zip(positions, positions[1:]))


def _format_unique_coordinates(pos_to_next):
    """Render a non-empty {position: next position} map as a coordinate string.

    Runs of consecutive positions are written as `start-end`, separate runs
    are joined with `,`, and a pair whose second position is not the
    immediate successor is written as `pos|next`.
    """
    order = sorted(pos_to_next)
    final_pos = order[-1]
    outline = str(order[0])
    # Position that continues the current run; anything else opens a new run.
    expected = order[0]
    for pos in order:
        nxt = pos_to_next[pos]
        if pos == expected:
            if pos == final_pos:
                # Close the last run unless it consists of this position only.
                if not outline.endswith(str(pos)):
                    outline += '-' + str(pos)
                continue
            if pos + 1 != nxt:
                # A jump: close the run (if it is open) and record the target.
                if not outline.endswith(str(pos)):
                    outline += '-' + str(pos)
                outline += '|' + str(nxt)
        else:
            # Gap in the unique positions: close the previous run at the
            # position it was expected to continue from, then start a new one.
            if not outline.endswith(str(expected)):
                outline += '-' + str(expected)
            outline += ',' + str(pos)
            if pos + 1 != nxt:
                outline += '|' + str(nxt)
        expected = nxt
    return outline


def run():
    """Write the transcript-specific coordinates of every gene in a GTF file.

    Command line: `GENCODE-gtf outfile [-CDS]`. Exon records (CDS records
    with `-CDS`) are grouped by gene and transcript; for each transcript the
    position pairs not shared with any other transcript of the same gene are
    written as one tab-separated row, or `none` when there are none.
    Exits with status 1 when fewer than two arguments are given.
    """
    # Both the input and the output path are required.
    if len(sys.argv) < 3:
        print('usage: python %s GENCODE-gtf outfileprefix [-CDS]' % sys.argv[0])
        sys.exit(1)

    inputfilename = sys.argv[1]
    outputfilename = sys.argv[2]

    feature = 'exon'
    if '-CDS' in sys.argv:
        feature = 'CDS'
        print('will only consider the CDS regions')

    genes = _parse_gtf(inputfilename, feature)

    # The header lists exactly the columns written for every transcript.
    header = '#GeneID\tGeneName\tTranscriptID\tTranscriptName\tchr\tstrand\tUniqueCoordinates\n'
    with open(outputfilename, 'w') as outfile:
        outfile.write(header)
        loci = 0
        for gene_id in genes:
            loci += 1
            if loci % 1000 == 0:
                print('%d loci processed' % loci)
            gene = genes[gene_id]
            transcripts = gene['transcripts']
            # NOTE(review): genes with a single transcript are reported too
            # (all of their positions are unique); the original had a check
            # here that could never be true -- confirm whether such genes
            # were meant to be skipped.
            pairs = {}
            for transcript_id in transcripts:
                pairs[transcript_id] = _adjacent_pairs(
                    transcripts[transcript_id]['exons'])
            for transcript_id in transcripts:
                unique = set(pairs[transcript_id])
                for other_id in transcripts:
                    if other_id != transcript_id:
                        unique -= pairs[other_id]
                transcript = transcripts[transcript_id]
                # Names and location come from this gene/transcript's own
                # records, not from whichever input line was parsed last.
                chrom, _left, _right, strand = transcript['exons'][0]
                prefix = '\t'.join([gene_id, gene['name'], transcript_id,
                                    transcript['name'], chrom, strand]) + '\t'
                if not unique:
                    outfile.write(prefix + 'none\n')
                    continue
                # Sorting first makes the largest successor win when one
                # position occurs in several pairs.
                pos_to_next = dict(sorted(unique))
                outfile.write(
                    prefix + _format_unique_coordinates(pos_to_next) + '\n')

# Run only when executed as a script, so that importing this module does not
# parse the command line or call sys.exit().
if __name__ == '__main__':
    run()

