##################################
#                                #
# Last modified 06/02/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random
from sets import Set

def _gtf_attribute(attributes, key, fallback=None):
    """Return the double-quoted value stored under `key` in a GTF attribute column.

    attributes -- the raw text of the attribute column (GTF column 9)
    key        -- attribute name, e.g. 'transcript_id'
    fallback   -- value returned when `key` is absent; when no fallback is
                  given a missing key raises IndexError
    """
    marker = key + ' "'
    if fallback is not None and marker not in attributes:
        return fallback
    return attributes.split(marker)[1].split('";')[0]


def _load_exons(gtf):
    """Read the 'exon' records of a GTF file, grouped by transcript.

    Returns a dict mapping 'geneName:geneID:transcriptName:transcriptID' to a
    sorted, duplicate-free list of (chr, start, stop, strand) tuples.  Names
    that are absent from the annotation fall back to the matching ID.
    """
    exons_by_transcript = {}
    with open(gtf) as gtf_lines:
        for line in gtf_lines:
            # Comment lines and blank lines carry no record.
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'exon':
                continue
            attributes = fields[8]
            transcriptID = _gtf_attribute(attributes, 'transcript_id')
            geneID = _gtf_attribute(attributes, 'gene_id')
            transcriptName = _gtf_attribute(attributes, 'transcript_name', transcriptID)
            geneName = _gtf_attribute(attributes, 'gene_name', geneID)
            transcript = ':'.join((geneName, geneID, transcriptName, transcriptID))
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            # A set drops exon records that are repeated in the annotation.
            exons_by_transcript.setdefault(transcript, set()).add(exon)
    return dict((transcript, sorted(exons))
                for (transcript, exons) in exons_by_transcript.items())


def _index_junctions(exons_by_transcript):
    """Locate every exon-exon junction in transcript coordinates.

    Returns a pair (positions_by_transcript, junction_counts):
      positions_by_transcript -- transcript -> {offset along transcript: junction}
      junction_counts         -- junction -> 0, ready to be incremented
    A junction is the tuple (chr, end of left exon, start of right exon, strand).
    """
    positions_by_transcript = {}
    junction_counts = {}
    for (transcript, exons) in exons_by_transcript.items():
        positions = positions_by_transcript[transcript] = {}
        # Chromosome and strand are taken from the first exon in sorted order.
        chrom = exons[0][0]
        strand = exons[0][3]
        # Exon length is computed as stop - start, matching the offsets used
        # in the read headers (NOTE(review): confirm against the read simulator).
        lengths = [stop - start for (_, start, stop, _) in exons]
        if strand == '+':
            pos, direction = 0, 1
        elif strand == '-':
            # Minus-strand transcripts are walked from their far end backwards.
            pos, direction = sum(lengths), -1
        else:
            # Unknown orientation: the junction is still reported, but it gets
            # no transcript offset, so no read can be assigned to it.
            pos, direction = None, 0
        for (left_exon, right_exon, length) in zip(exons, exons[1:], lengths):
            junction = (chrom, left_exon[2], right_exon[1], strand)
            junction_counts[junction] = 0
            if pos is None:
                continue
            pos += direction * length
            positions[pos] = junction
    return positions_by_transcript, junction_counts


def _count_junction_reads(read_lines, positions_by_transcript, junction_counts,
                          minOverhang):
    """Tally, in place, the reads that span each junction.

    Only FASTA header lines (starting with '>') are used.  A header has the
    form '>geneName:geneID:transcriptName:transcriptID:...:pos1-pos2'.  A read
    is counted for a junction when the junction offset lies strictly between
    pos1 + minOverhang and pos2 - minOverhang.

    Returns the number of reads whose transcript is not in the annotation.
    """
    skipped = 0
    for line in read_lines:
        if not line.startswith('>'):
            continue
        fields = line.strip().split('>')[1].split(':')
        transcript = ':'.join((fields[0], fields[1], fields[2], fields[3]))
        (first, second) = fields[-1].split('-')[:2]
        pos1 = int(first)
        pos2 = int(second)
        positions = positions_by_transcript.get(transcript)
        if positions is None:
            skipped += 1
            continue
        for (pos, junction) in positions.items():
            if pos1 + minOverhang < pos < pos2 - minOverhang:
                junction_counts[junction] += 1
    return skipped


def run():
    """Count simulated reads spanning each annotated splice junction.

    Command line: reads gtf minOverhang outfile
      reads       -- FASTA of transcriptome reads, or '-' to read stdin
      gtf         -- annotation with 'exon' records
      minOverhang -- minimum number of bases required on each side of a junction
      outfile     -- tab-separated table of junctions and their read counts

    Prints the usage text and exits with status 1 when arguments are missing.
    Prints the number of skipped reads to stdout when done.
    """
    # Four user arguments are required, i.e. five entries in sys.argv.
    if len(sys.argv) < 5:
        print('usage: python %s transcriptome_make_reads_fasta_output gtf minOverhang outfile' % sys.argv[0])
        print('\t use - to use stdin')
        sys.exit(1)

    reads = sys.argv[1]
    gtf = sys.argv[2]
    minOverhang = int(sys.argv[3])
    outfilename = sys.argv[4]

    TranscriptDict = _load_exons(gtf)
    (TranscriptToJuncDict, JunctionsDict) = _index_junctions(TranscriptDict)

    if reads == '-':
        # stdin is owned by the interpreter, so it is not closed here.
        skipped = _count_junction_reads(sys.stdin, TranscriptToJuncDict,
                                        JunctionsDict, minOverhang)
    else:
        with open(reads) as read_lines:
            skipped = _count_junction_reads(read_lines, TranscriptToJuncDict,
                                            JunctionsDict, minOverhang)

    with open(outfilename, 'w') as outfile:
        outfile.write('#Skipped reads: ' + str(skipped) + '\n')
        outfile.write('#chr\tleft\tright\tstrand\tCounts' + '\n')
        for junction in sorted(JunctionsDict):
            (chrom, left, right, strand) = junction
            # Both junction coordinates are written shifted down by one.
            columns = [chrom, str(left - 1), str(right - 1), strand,
                       str(JunctionsDict[junction])]
            outfile.write('\t'.join(columns) + '\n')

    print(skipped)

# Script entry point: executed unconditionally at module load, so importing
# this file also parses sys.argv and runs the whole pipeline.
run()

