##################################
#                                #
# Last modified 06/23/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set

def _gtf_attribute(attributes, key):
    """Return the quoted value of ``key`` in a GTF attribute column, or None.

    ``attributes`` is column 9 of a GTF line, e.g.
    ``gene_id "g1"; transcript_id "t1";``. The first occurrence of
    ``key "`` is used.
    """
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker, 1)[1].split('"', 1)[0]


def _report_progress(count):
    """Print a progress line every 100000 processed (non-comment) lines."""
    if count % 100000 == 0:
        sys.stdout.write('%d lines processed\n' % count)


def _read_exons(gtf):
    """Parse the exon records of ``gtf`` into {gene_id: {transcript_id: [exon, ...]}}.

    Each exon is a ``(chrom, left, right, strand)`` tuple built from GTF
    columns 1, 4, 5 and 7. The following are skipped:
    - comment lines;
    - lines with fewer than nine tab-separated columns (e.g. blank lines);
    - non-exon features;
    - exon lines missing a gene_id or transcript_id attribute.
    """
    genes = {}
    count = 0
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            count += 1
            _report_progress(count)
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            gene = _gtf_attribute(fields[8], 'gene_id')
            transcript = _gtf_attribute(fields[8], 'transcript_id')
            if gene is None or transcript is None:
                continue
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            genes.setdefault(gene, {}).setdefault(transcript, []).append(exon)
    return genes


def _leading_exons(exons, distance):
    """Return the shortest prefix of ``exons`` whose total length reaches ``distance``.

    The whole list is returned if it never reaches ``distance``; at least one
    exon is always taken from a non-empty list.
    """
    taken = []
    covered = 0
    for exon in exons:
        taken.append(exon)
        # GTF coordinates are 1-based and inclusive, so an exon spans right-left+1 bp.
        covered += exon[2] - exon[1] + 1
        if covered >= distance:
            break
    return taken


def _shares_terminal_exons(transcripts, distance):
    """Return True if all transcripts agree on their ends and terminal exons.

    ``transcripts`` maps transcript_id -> list of exon tuples. The check has two steps:
    1. Every transcript must have the same 5' and the same 3' end coordinate.
       The strand is taken from each transcript's last listed exon; '+'/'F'
       and '-'/'R' are recognised. Transcripts with any other strand
       contribute no ends, so a gene with no recognised strand is rejected.
    2. After sorting exons by coordinate, every transcript must have the same
       exons as the reference transcript over the first and over the last
       ``distance`` bp. The exon that crosses the boundary must match in full.
    """
    ends5 = set()
    ends3 = set()
    for exons in transcripts.values():
        strand = exons[-1][3]
        coordinates = [c for exon in exons for c in (exon[1], exon[2])]
        low, high = min(coordinates), max(coordinates)
        if strand in ('+', 'F'):
            ends5.add(low)
            ends3.add(high)
        elif strand in ('-', 'R'):
            ends5.add(high)
            ends3.add(low)
    if len(ends5) != 1 or len(ends3) != 1:
        return False

    ordered = [sorted(exons) for exons in transcripts.values()]
    reference = ordered[0]
    head = _leading_exons(reference, distance)
    tail = _leading_exons(reference[::-1], distance)
    for exons in ordered:
        # Slicing (not indexing) means a transcript with fewer exons than the
        # reference simply fails to match instead of raising IndexError.
        if exons[:len(head)] != head or exons[::-1][:len(tail)] != tail:
            return False
    return True


def _select_genes(genes, distance):
    """Return the set of gene ids to keep.

    A gene is kept if it has exactly one transcript, or if
    ``_shares_terminal_exons`` holds for its transcripts.
    """
    selected = set()
    for gene, transcripts in genes.items():
        if len(transcripts) == 1 or _shares_terminal_exons(transcripts, distance):
            selected.add(gene)
    return selected


def _write_selected(gtf, selected, outfile):
    """Copy comment lines and lines of the selected genes from ``gtf`` to ``outfile``.

    Lines with fewer than nine columns, or without a gene_id, are dropped.
    """
    count = 0
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#'):
                outfile.write(line)
                continue
            count += 1
            _report_progress(count)
            fields = line.strip().split('\t')
            if len(fields) < 9:
                continue
            if _gtf_attribute(fields[8], 'gene_id') in selected:
                outfile.write(line)


def run():
    """Filter a GTF file down to genes whose transcripts share their terminal exons.

    Command line: ``python <script> gtf min_distance_from_end outfilename``

    What is kept:
    - a gene with a single transcript;
    - a gene whose transcripts all have the same 5'/3' end coordinates and
      identical exons over at least ``min_distance_from_end`` bp from each end.

    Every line of a kept gene, plus every '#' comment line, is written to
    ``outfilename``. Progress is printed to stdout every 100000 non-comment
    lines in each pass.

    Exits with status 1 when arguments are missing or the distance is not an
    integer.
    """
    if len(sys.argv) < 4:
        sys.stdout.write('usage: python %s gtf min_distance_from_end outfilename\n' % sys.argv[0])
        sys.stdout.write("       Note: the min_distance_from_end parameter is the distance from the 3' and 5' end that the transcript has to be unique\n")
        sys.stdout.write('       for example, if the first 3 exons are 200bp long each, and the third one is alternatively spliced, if the value is 300, the gene will be picked, but if it is 500, it will be discarded\n')
        sys.exit(1)

    gtf = sys.argv[1]
    try:
        # Convert explicitly: comparing an int length against the raw argv
        # string is always False in Python 2 (and a TypeError in Python 3).
        distance = int(sys.argv[2])
    except ValueError:
        sys.stdout.write('min_distance_from_end must be an integer, got %r\n' % sys.argv[2])
        sys.exit(1)

    # Opened before the two passes so an unwritable path fails fast; `with`
    # guarantees the file is closed even if parsing raises.
    with open(sys.argv[3], 'w') as outfile:
        genes = _read_exons(gtf)
        selected = _select_genes(genes, distance)
        _write_selected(gtf, selected, outfile)

# Only run the filter when executed as a script, not when imported.
if __name__ == '__main__':
    run()

