##################################
#                                #
# Last modified 2017/07/10       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import string
import os
from sets import Set

# Bin width (bp) of the per-chromosome interval index. It only affects speed:
# two intervals whose overlap min(right) - max(left) is >= 0 always share at
# least one bin, so results do not depend on this value for minOverlap >= 0.
_BIN_SIZE = 10000


def _gtf_attribute(attributes, key):
    """Return the value of ``key "value";`` in a GTF attribute column, or None if absent."""
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker)[1].split('";')[0]


def _read_gtf_transcripts(gtf):
    """Group the exon lines of a GTF file by transcript.

    Returns a dict mapping (geneID, geneName, transcriptID, transcriptName)
    to the list of (chr, left, right, strand) exon tuples in file order.
    gene_name / transcript_name fall back to gene_id / transcript_id when
    absent. Comment lines ('#'), blank/short lines and non-exon features are
    skipped.

    Raises ValueError for an exon line lacking gene_id or transcript_id.
    """
    transcripts = {}
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            attrs = fields[8]
            geneID = _gtf_attribute(attrs, 'gene_id')
            transcriptID = _gtf_attribute(attrs, 'transcript_id')
            if geneID is None or transcriptID is None:
                raise ValueError('exon line without gene_id/transcript_id: %r' % line)
            geneName = _gtf_attribute(attrs, 'gene_name')
            if geneName is None:
                geneName = geneID
            transcriptName = _gtf_attribute(attrs, 'transcript_name')
            if transcriptName is None:
                transcriptName = transcriptID
            key = (geneID, geneName, transcriptID, transcriptName)
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            transcripts.setdefault(key, []).append(exon)
    return transcripts


def _annotation_intervals(transcripts, wg_ext=None):
    """Build the annotation intervals that repeats are tested against.

    Each interval is (chr, left, right, strand, geneID, geneName,
    transcriptID, transcriptName). With wg_ext None there is one interval
    per exon; otherwise one per transcript, spanning its outermost exon
    coordinates widened by wg_ext on both sides (chr/strand from its last exon).
    """
    intervals = []
    for key, exons in transcripts.items():
        if wg_ext is None:
            intervals.extend(exon + key for exon in exons)
        else:
            coords = [c for exon in exons for c in (exon[1], exon[2])]
            chrom, strand = exons[-1][0], exons[-1][3]
            intervals.append((chrom, min(coords) - wg_ext, max(coords) + wg_ext, strand) + key)
    return intervals


def _bin_index(intervals, bin_size=_BIN_SIZE):
    """Map chr -> bin number -> list of intervals that touch that bin."""
    index = {}
    for interval in intervals:
        bins = index.setdefault(interval[0], {})
        for b in range(interval[1] // bin_size, interval[2] // bin_size + 1):
            bins.setdefault(b, []).append(interval)
    return index


def _candidate_intervals(index, chrom, left, right, bin_size=_BIN_SIZE):
    """Yield indexed intervals sharing a bin with [left, right] on chrom (may yield duplicates)."""
    bins = index.get(chrom, {})
    for b in range(left // bin_size, right // bin_size + 1):
        for interval in bins.get(b, ()):
            yield interval


def _read_repeats(repeatmask, chrFieldID, labelFieldIDs):
    """Yield (label, chr, left, right) for each record of a repeat table.

    Plain files are opened directly; .bz2 / .gz files are decompressed with
    the external bzip2 / gzip tools (argument list, no shell). Runs of spaces
    are collapsed and spaces then act as tab separators; blank lines and
    RepeatMasker header lines (starting with 'SW' or 'score') are skipped.
    chr/left/right come from columns chrFieldID, +1, +2 (0-based); the label
    joins the labelFieldIDs columns with '::'.

    Raises IOError if the decompression command exits with non-zero status.
    """
    import subprocess

    if repeatmask.endswith('.bz2'):
        cmd = ['bzip2', '-cd', repeatmask]
    elif repeatmask.endswith('.gz'):
        cmd = ['gzip', '-cd', repeatmask]
    else:
        cmd = None

    if cmd is None:
        print('reading ' + repeatmask)
        proc = None
        stream = open(repeatmask)
    else:
        print(' '.join(cmd))
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
        stream = proc.stdout

    try:
        for LK, line in enumerate(stream, 1):
            if LK % 100000 == 0:
                print('%d lines processed' % LK)
            while '  ' in line:
                line = line.replace('  ', ' ')
            line = line.strip().replace(' ', '\t')
            if not line or line.startswith('SW') or line.startswith('score'):
                continue
            fields = line.split('\t')
            label = '::'.join(fields[L] for L in labelFieldIDs)
            yield (label, fields[chrFieldID], int(fields[chrFieldID + 1]), int(fields[chrFieldID + 2]))
    finally:
        stream.close()
        status = proc.wait() if proc is not None else 0
    # Only reached after the whole stream was consumed without error.
    if status != 0:
        raise IOError('%s exited with status %d' % (' '.join(cmd), status))


def run():
    """Report overlaps between GTF transcripts and a repeat annotation table.

    Command line (read from sys.argv):
        gtf repeatmask chrFieldID labelFieldID(s) outprefix
        [-minOverlap bp] [-wholeGene extension(bp)]

    gtf           GTF file; only 'exon' features are used.
    repeatmask    whitespace-delimited repeat table (plain, .gz or .bz2).
    chrFieldID    0-based column of the chromosome; start/end are the next two.
    labelFieldIDs comma-separated 0-based columns joined with '::' as label.
    outprefix     prefix of the two output files.
    -minOverlap   minimum overlap, computed as
                  min(right, rright) - max(left, rleft); default 1. Values
                  below 0 are not meaningful.
    -wholeGene    test each transcript's full span (widened by the given bp)
                  instead of individual exons.

    Writes <outprefix>.individual_overlaps (one row per annotation interval /
    repeat pair, then one 'nan' row per transcript without any overlap) and
    <outprefix>.repeat_counts (number of distinct gene IDs per repeat label,
    including labels with 0). Every transcript sharing an exon is reported,
    and intervals that fully contain a repeat count as overlapping.
    Prints usage and exits with status 1 when fewer than 5 arguments are given.
    """
    if len(sys.argv) < 6:
        print('usage: python %s gtf repeatmask chrFieldID labelFieldID(s) outprefix [-minOverlap bp] [-wholeGene extension(bp)]' % sys.argv[0])
        sys.exit(1)

    gtf = sys.argv[1]
    repeatmask = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    labelFieldIDs = [int(L) for L in sys.argv[4].split(',')]
    outprefix = sys.argv[5]

    WGext = None
    if '-wholeGene' in sys.argv:
        WGext = int(sys.argv[sys.argv.index('-wholeGene') + 1])

    minOL = 1
    if '-minOverlap' in sys.argv:
        minOL = int(sys.argv[sys.argv.index('-minOverlap') + 1])

    TranscriptDict = _read_gtf_transcripts(gtf)
    print('finished inputting annotation')
    index = _bin_index(_annotation_intervals(TranscriptDict, WGext))

    OLSet = set()  # de-duplicates repeated hits (shared bins, duplicate lines)
    OLDict = {}    # label -> set of geneIDs with a qualifying overlap
    for label, rchr, rleft, rright in _read_repeats(repeatmask, chrFieldID, labelFieldIDs):
        OLDict.setdefault(label, set())  # labels without overlaps are still reported
        for interval in _candidate_intervals(index, rchr, rleft, rright):
            left, right = interval[1], interval[2]
            if min(right, rright) - max(left, rleft) >= minOL:
                OLSet.add(interval + (label, rleft, rright))
    OLList = sorted(OLSet)

    print('Total transcripts: %d' % len(TranscriptDict))

    overlapped = set()
    with open(outprefix + '.individual_overlaps', 'w') as outfile1:
        # Column order matches the rows written below.
        outfile1.write('#geneID\tgeneName\ttranscriptID\ttranscriptName\tchr\texon_left\texon_right\tstrand\tlabel\tleft\tright\n')
        for (chrom, left, right, strand, geneID, geneName, transcriptID, transcriptName, label, rleft, rright) in OLList:
            outfile1.write('\t'.join([geneID, geneName, transcriptID, transcriptName, chrom,
                                      str(left), str(right), strand, label,
                                      str(rleft), str(rright)]) + '\n')
            OLDict[label].add(geneID)
            overlapped.add((geneID, geneName, transcriptID, transcriptName))

        remaining = sorted(key for key in TranscriptDict if key not in overlapped)
        print('Total transcripts without an overlap: %d' % len(remaining))
        for key in remaining:
            outfile1.write('\t'.join(list(key) + ['nan'] * 7) + '\n')

    with open(outprefix + '.repeat_counts', 'w') as outfile2:
        outfile2.write('#repeat\tnumber_overlapping_genes\n')
        for label in sorted(OLDict):
            outfile2.write('%s\t%d\n' % (label, len(OLDict[label])))

# Only run the command-line tool when executed as a script, not on import.
if __name__ == '__main__':
    run()

