##################################
#                                #
# Last modified 08/23/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import bisect
import math
import random
import string
import sys

def _read_chrom_sizes(path):
    """Read a chrom.sizes file into (offset, chrom, length) tuples.

    ``offset`` is the summed length of all preceding chromosomes, so the
    chromosomes tile one concatenated coordinate space [0, total_length).
    Lines starting with '#' and blank lines are skipped.

    Returns:
        (chroms, total_length)
    """
    chroms = []
    total = 0
    with open(path) as sizes_file:
        for line in sizes_file:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            length = int(fields[1])
            chroms.append((total, fields[0], length))
            total += length
    return chroms, total


def _gtf_attribute(attributes, name):
    """Return the value of ``name "value";`` in a GTF attribute column, or None if absent."""
    marker = name + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker, 1)[1].split('";')[0]


def _transcript_key(attributes):
    """Build the (gene_id, gene_name, transcript_name, transcript_id) key of an exon.

    gene_name / transcript_name fall back to gene_id / transcript_id when absent.

    Raises:
        ValueError: if gene_id or transcript_id is missing.
    """
    gene_id = _gtf_attribute(attributes, 'gene_id')
    transcript_id = _gtf_attribute(attributes, 'transcript_id')
    if gene_id is None or transcript_id is None:
        raise ValueError('exon without gene_id/transcript_id: %s' % attributes)
    gene_name = _gtf_attribute(attributes, 'gene_name')
    if gene_name is None:
        gene_name = gene_id
    transcript_name = _gtf_attribute(attributes, 'transcript_name')
    if transcript_name is None:
        transcript_name = transcript_id
    return (gene_id, gene_name, transcript_name, transcript_id)


def _iter_exons(gtf_path):
    """Yield the tab-split fields of every 'exon' line of a GTF file.

    Prints a progress message every 100000 input lines. Comment lines and
    lines with fewer than 9 columns (e.g. blank lines) are skipped.
    """
    with open(gtf_path) as gtf_file:
        for line_number, line in enumerate(gtf_file, 1):
            if line_number % 100000 == 0:
                print('%d lines processed' % line_number)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            yield fields


def _random_placement(chroms, offsets, total_length, span):
    """Pick a random (chrom, new_left) such that [new_left, new_left + span] fits on chrom.

    ``new_left`` is 1-based. Genome positions are drawn uniformly and rejected
    until one leaves room for the whole span, so every valid placement is
    equally likely. The caller must ensure at least one chromosome is longer
    than ``span``; otherwise this would never return.
    """
    while True:
        pos = random.randrange(total_length)
        # Last chromosome whose offset is <= pos contains pos.
        offset, chrom, length = chroms[bisect.bisect_right(offsets, pos) - 1]
        new_left = pos - offset + 1  # GTF coordinates are 1-based
        if new_left + span <= length:
            return chrom, new_left


def run():
    """Shuffle every transcript of a GTF to a random location in the genome.

    Usage: python script.py gtf chrom.sizes outfilename

    Each transcript, keyed by (gene_id, gene_name, transcript_name,
    transcript_id), is moved as a rigid unit. Its exons keep their lengths
    and relative spacing. The whole transcript is placed uniformly at random
    at a position on a chrom.sizes chromosome where it fits completely. Only
    'exon' lines are written, with every column except chrom/start/end copied
    unchanged.

    Raises:
        ValueError: if a transcript is longer than every chromosome, or an
            exon lacks gene_id/transcript_id.
    """
    if len(sys.argv) < 4:
        print('usage: python %s gtf chrom.sizes outfilename' % sys.argv[0])
        sys.exit(1)

    gtf = sys.argv[1]
    chromsizes = sys.argv[2]
    outputfilename = sys.argv[3]

    chroms, total_length = _read_chrom_sizes(chromsizes)
    offsets = [offset for (offset, _chrom, _length) in chroms]
    longest = max([length for (_offset, _chrom, length) in chroms] or [0])

    # Pass 1: genomic extent (min, max coordinate) of every transcript.
    # NOTE: the key does not include the chromosome, so exons sharing a key
    # on different chromosomes are treated as one span.
    extents = {}
    for fields in _iter_exons(gtf):
        key = _transcript_key(fields[8])
        left = int(fields[3])
        right = int(fields[4])
        lo, hi = min(left, right), max(left, right)
        if key in extents:
            old_lo, old_hi = extents[key]
            extents[key] = (min(old_lo, lo), max(old_hi, hi))
        else:
            extents[key] = (lo, hi)

    # Choose a new chromosome and a coordinate shift for every transcript.
    placements = {}
    for key, (min_pos, max_pos) in extents.items():
        span = max_pos - min_pos
        if span + 1 > longest:
            # Without this check the rejection sampler would loop forever.
            raise ValueError('transcript %s spans %d bp, longer than the longest '
                             'chromosome (%d bp)' % (key[3], span + 1, longest))
        new_chrom, new_left = _random_placement(chroms, offsets, total_length, span)
        placements[key] = (new_chrom, new_left - min_pos)

    print('finished generating new transcript ends')

    # Pass 2: rewrite every exon with its transcript's new chromosome/shift.
    with open(outputfilename, 'w') as outfile:
        for fields in _iter_exons(gtf):
            new_chrom, shift = placements[_transcript_key(fields[8])]
            left_new = int(fields[3]) + shift
            right_new = int(fields[4]) + shift
            columns = [new_chrom, fields[1], fields[2], str(left_new), str(right_new)]
            outfile.write('\t'.join(columns + fields[5:9]) + '\n')

# Run only when executed as a script, so the module can be imported
# (e.g. for testing) without consuming sys.argv or touching files.
if __name__ == '__main__':
    run()

