##################################
#                                #
# Last modified 2017/06/15       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random
import string
import numpy as np

def _gtf_attribute(attributes, key):
    """Return the quoted value stored under ``key`` in a GTF attribute column.

    The value is the text between ``key "`` and the following ``";``.
    Raises IndexError when ``key "`` does not occur in ``attributes``.
    """
    return attributes.split(key + ' "')[1].split('";')[0]


def _read_exons_by_transcript(gtf):
    """Collect the ``exon`` records of a GTF file, grouped by transcript.

    Returns ``(transcripts, exon_count)`` where ``transcripts`` maps
    ``(geneID, geneName, transcriptName, transcriptID)`` to a list of
    ``(chromosome, left, right, strand)`` tuples in file order, and
    ``exon_count`` is the total number of exon records read.
    """
    transcripts = {}
    exon_count = 0
    with open(gtf) as handle:
        for line_number, line in enumerate(handle, 1):
            if line_number % 100000 == 0:
                print('%d lines processed' % line_number)
            if line.startswith('#'):
                continue
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            fields = stripped.split('\t')
            if fields[2] != 'exon':
                continue
            exon_count += 1
            attributes = fields[8]
            gene_id = _gtf_attribute(attributes, 'gene_id')
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            # The *_name attributes are optional; fall back to the IDs.
            if 'gene_name "' in attributes:
                gene_name = _gtf_attribute(attributes, 'gene_name')
            else:
                gene_name = gene_id
            if 'transcript_name "' in attributes:
                transcript_name = _gtf_attribute(attributes, 'transcript_name')
            else:
                transcript_name = transcript_id
            transcript = (gene_id, gene_name, transcript_name, transcript_id)
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            transcripts.setdefault(transcript, []).append(exon)
    return transcripts, exon_count


def _intron_lengths_by_position(transcripts, min_exons):
    """Compute intron lengths indexed by intron position along each transcript.

    Transcripts with fewer than ``min_exons`` exons are skipped, as are
    transcripts whose strand (taken from the first exon read) is neither
    ``+`` nor ``-``.  Position 1 is the intron after the first exon in
    transcription order, i.e. the rightmost gap on the ``-`` strand.

    Returns ``(by_position, all_lengths, non_first_lengths)``.
    """
    by_position = {}
    all_lengths = []
    non_first_lengths = []
    for exons in transcripts.values():
        if len(exons) < min_exons:
            continue
        strand = exons[0][3]
        if strand not in ('+', '-'):
            continue
        ordered = sorted(exons)
        # NOTE(review): the length is taken as next exon start minus previous
        # exon end.  If the coordinates are 1-based inclusive (as the GTF
        # format is documented to be) this is one larger than the number of
        # intronic bases -- confirm before changing; kept as-is so results
        # stay comparable with earlier runs.
        gaps = [downstream[1] - upstream[2]
                for upstream, downstream in zip(ordered, ordered[1:])]
        if strand == '-':
            # Transcription runs right-to-left, so the first intron is the
            # last genomic gap.
            gaps.reverse()
        for position, length in enumerate(gaps, 1):
            by_position.setdefault(position, []).append(length)
            all_lengths.append(length)
            if position != 1:
                non_first_lengths.append(length)
    return by_position, all_lengths, non_first_lengths


def _write_summary(outputfilename, by_position, all_lengths, non_first_lengths):
    """Write one tab-separated statistics row per intron position."""
    header = '#Intron_position\tnumber\tMean_intron_length\tstdev\tRatio_to_mean_intron_length\tRatio_to_mean_non-first_intron_length\tMedian_intron_length\tRatio_to_median_intron_length\tRatio_to_median_non-first_intron_length'
    with open(outputfilename, 'w') as outfile:
        outfile.write(header + '\n')

        overall_mean = np.mean(all_lengths)
        non_first_mean = np.mean(non_first_lengths)
        overall_median = np.median(all_lengths)
        non_first_median = np.median(non_first_lengths)

        for position in sorted(by_position):
            lengths = by_position[position]
            median_length = np.median(lengths)
            mean_length = np.mean(lengths)
            row = (
                position,
                len(lengths),
                mean_length,
                np.std(lengths),
                mean_length / overall_mean,
                mean_length / non_first_mean,
                median_length,
                median_length / overall_median,
                median_length / non_first_median,
            )
            outfile.write('\t'.join(str(value) for value in row) + '\n')


def run():
    """Summarise intron lengths by intron position for a GTF annotation.

    Command line (read from ``sys.argv``)::

        python <script> gtf minExonNumber outfilename

    ``gtf`` is the annotation file, ``minExonNumber`` the minimum number of
    exons a transcript must have to be included, and ``outfilename`` the
    tab-separated report to write.  For every intron position the report
    lists the number of introns, their mean, standard deviation (``np.std``
    with its defaults) and median length, and the ratios of the mean/median
    to the mean/median over all introns and over all non-first introns.

    Prints the usage message and exits with status 1 when fewer than three
    arguments are supplied.
    """
    # Three positional arguments are required, so argv must hold at least
    # four entries including the script name.
    if len(sys.argv) < 4:
        print('usage: python %s gtf minExonNumber outfilename' % sys.argv[0])
        sys.exit(1)

    gtf = sys.argv[1]
    minExNum = int(sys.argv[2])
    outputfilename = sys.argv[3]

    print(outputfilename)

    transcripts, exon_count = _read_exons_by_transcript(gtf)

    print('Found %d transcripts, and  %d  exons' % (len(transcripts), exon_count))

    by_position, all_lengths, non_first_lengths = _intron_lengths_by_position(
        transcripts, minExNum)

    _write_summary(outputfilename, by_position, all_lengths, non_first_lengths)

# Run the analysis only when executed as a script, so that importing this
# module (e.g. for testing) does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

