##################################
#                                #
# Last modified 06/02/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _gtf_attribute(attributes, key):
    """Return the quoted value of `key` from a GTF attribute column.

    `attributes` is the raw text of the ninth GTF column, e.g.
    'gene_id "G1"; gene_type "miRNA";'.  Raises IndexError when `key`
    does not occur in it.
    """
    return attributes.split(key + ' "')[1].split('";')[0]


def _load_annotation(gtf_path):
    """Collect miRNA transcripts and protein-coding exons from a GTF file.

    Returns a pair (mirnas_by_chrom, exons_by_transcript):

    * mirnas_by_chrom maps chromosome -> {(left, right, strand):
      (geneName, geneID, transcriptID)} for 'transcript' records whose
      gene_type is 'miRNA'.  Every chromosome seen on any non-comment line
      gets an entry (possibly empty) and is printed once when first seen.
    * exons_by_transcript maps (geneName, geneID, transcriptID) -> list of
      (chromosome, left, right, strand) for 'exon' records whose gene_type
      is 'protein_coding'.
    """
    # The only (biotype, feature) combinations that are kept.
    wanted_feature = {'miRNA': 'transcript', 'protein_coding': 'exon'}

    mirnas_by_chrom = {}
    exons_by_transcript = {}

    with open(gtf_path) as gtf_file:
        for line in gtf_file:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            attributes = fields[8]
            biotype = _gtf_attribute(attributes, 'gene_type')

            if chrom not in mirnas_by_chrom:
                print(chrom)
                mirnas_by_chrom[chrom] = {}

            if biotype not in wanted_feature:
                continue
            if fields[2] != wanted_feature[biotype]:
                continue

            # The remaining attributes are parsed only for records that are
            # actually kept, so discarded records (e.g. 'gene' rows without
            # a transcript_id) cannot abort the run.
            transcript_key = (_gtf_attribute(attributes, 'gene_name'),
                              _gtf_attribute(attributes, 'gene_id'),
                              _gtf_attribute(attributes, 'transcript_id'))
            left = int(fields[3])
            right = int(fields[4])
            strand = fields[6]

            if biotype == 'miRNA':
                mirnas_by_chrom[chrom][(left, right, strand)] = transcript_key
            else:
                exons_by_transcript.setdefault(transcript_key, []).append(
                    (chrom, left, right, strand))

    return mirnas_by_chrom, exons_by_transcript


def _intronic_mirna_lines(mirnas_by_chrom, exons_by_transcript):
    """Yield one tab-separated report line per (miRNA, containing intron) pair.

    An intron is the gap between two consecutive exons of a transcript after
    sorting its exon tuples; a miRNA is reported when it lies strictly
    between the end of the upstream exon and the start of the downstream one.
    The strand of the miRNA is reported but not compared to the host strand.
    """
    for transcript_key in exons_by_transcript:
        exons = exons_by_transcript[transcript_key]
        exons.sort()
        chrom = exons[0][0]
        strand = exons[0][3]
        geneName, geneID, transcriptID = transcript_key
        # Pair each exon with its successor; single-exon transcripts yield
        # no pairs and therefore no introns.
        for upstream, downstream in zip(exons, exons[1:]):
            intronLeft = upstream[2]
            intronRight = downstream[1]
            for mirna_coords, mirna_ids in mirnas_by_chrom[chrom].items():
                miRNAleft, miRNAright, miRNAstrand = mirna_coords
                if miRNAleft > intronLeft and miRNAright < intronRight:
                    miRNAgeneName, miRNAgeneID, miRNAtranscriptID = mirna_ids
                    yield '\t'.join([
                        miRNAgeneName, miRNAgeneID, miRNAtranscriptID,
                        chrom, str(miRNAleft), str(miRNAright), miRNAstrand,
                        geneName, geneID, transcriptID,
                        chrom, str(intronLeft), str(intronRight), strand])


def run():
    """Report miRNAs located inside introns of protein-coding transcripts.

    Command line: python <script> GTF outfilename

    Reads the GTF named by sys.argv[1] and writes one tab-separated line per
    (miRNA, intron) match to sys.argv[2], echoing each line to stdout:

        miRNA name, miRNA gene ID, miRNA transcript ID, chromosome,
        miRNA left, miRNA right, miRNA strand, host gene name, host gene ID,
        host transcript ID, chromosome, intron left, intron right,
        host strand

    Exits with status 1 after printing usage when fewer than two arguments
    are supplied.  Raises IndexError if a non-comment GTF line has fewer than
    nine columns or lacks a gene_type attribute.
    """
    # Both positional arguments are required, so argv needs three entries.
    if len(sys.argv) < 3:
        print('usage: python %s GTF outfilename' % sys.argv[0])
        print('       GTF file must have biotypes in it')
        print('       the script will only consider the introns of protein coding genes')
        sys.exit(1)

    GTF = sys.argv[1]
    outputfilename = sys.argv[2]

    # The output file is opened before parsing (as before) so an unwritable
    # path fails immediately; `with` closes it even if parsing raises.
    with open(outputfilename, 'w') as outfile:
        mirnas_by_chrom, exons_by_transcript = _load_annotation(GTF)
        for outline in _intronic_mirna_lines(mirnas_by_chrom,
                                             exons_by_transcript):
            print(outline)
            outfile.write(outline + '\n')

# Script entry point: run() is invoked unconditionally when this file is
# executed (there is no __main__ guard, so importing the file also runs it).
run()

