##################################
#                                #
# Last modified 2017/09/28       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import math
import os
import string
import subprocess
import sys

def _gtf_attribute(attributes, key):
    """Return the quoted value of `key` in a GTF attribute column, or None."""
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker)[1].split('";')[0]


def _read_exons(gtf):
    """Collect the exon intervals of every transcript in a GTF file.

    Returns a pair (TranscriptDict, phastConsDict):
      TranscriptDict[chrom][transcript] -> list of (left, right) exon
          coordinates exactly as written in the GTF, where transcript is the
          tuple (geneID, geneName, transcriptName, transcriptID);
      phastConsDict[chrom][position] -> 0 for every position from left-10 to
          right+9 of each exon, i.e. the positions whose scores are kept.

    Raises ValueError for an exon record that has fewer than nine columns or
    lacks a gene_id / transcript_id attribute.
    """
    TranscriptDict = {}
    phastConsDict = {}
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Blank or truncated lines cannot be exon records.
            if len(fields) < 3 or fields[2] != 'exon':
                continue
            if len(fields) < 9:
                raise ValueError('malformed GTF exon line: ' + line.strip())
            chrom = fields[0]
            attributes = fields[8]
            geneID = _gtf_attribute(attributes, 'gene_id')
            transcriptID = _gtf_attribute(attributes, 'transcript_id')
            if geneID is None or transcriptID is None:
                raise ValueError('GTF exon line without gene_id/transcript_id: '
                                 + line.strip())
            # Names are optional; fall back to the corresponding ID.
            geneName = _gtf_attribute(attributes, 'gene_name')
            if geneName is None:
                geneName = geneID
            transcriptName = _gtf_attribute(attributes, 'transcript_name')
            if transcriptName is None:
                transcriptName = transcriptID
            transcript = (geneID, geneName, transcriptName, transcriptID)
            left = int(fields[3])
            right = int(fields[4])
            chromTranscripts = TranscriptDict.setdefault(chrom, {})
            chromTranscripts.setdefault(transcript, []).append((left, right))
            chromScores = phastConsDict.setdefault(chrom, {})
            for pos in range(left - 10, right + 10):
                chromScores[pos] = 0
    return TranscriptDict, phastConsDict


def _iter_conservation_lines(path):
    """Yield the lines of a wiggle file, decompressing .bz2 / .gz on the fly.

    Compressed files are streamed through the external `bzip2` / `gunzip`
    programs. The command is passed as an argument list (no shell), so paths
    containing spaces or shell metacharacters are handled correctly.

    Raises IOError if the decompression program exits with a non-zero status.
    """
    if path.endswith('.bz2'):
        command = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        command = ['gunzip', '-c', path]
    else:
        command = None

    if command is None:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(command, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        returncode = proc.wait()
    if returncode != 0:
        raise IOError('%s exited with status %d while reading %s'
                      % (command[0], returncode, path))


def _load_scores(ConsFile, phastConsDict):
    """Fill phastConsDict in place with scores from a fixedStep wiggle file.

    Only positions already present in phastConsDict[chrom] are stored; all
    other values are skipped while the running position is still advanced.

    Raises ValueError if a data line occurs before any fixedStep declaration.
    """
    currentChr = None
    currentPos = 0
    step = 1
    chromScores = None
    lineCount = 0
    for line in _iter_conservation_lines(ConsFile):
        lineCount += 1
        if lineCount % 1000000 == 0:
            print(str(lineCount // 1000000) + 'M lines processed')
        if line.startswith('fixedStep'):
            currentChr = line.split('chrom=')[1].split(' ')[0].strip()
            currentPos = int(line.split('start=')[1].split(' ')[0])
            step = int(line.split('step=')[1].split(' ')[0])
            chromScores = phastConsDict.get(currentChr)
            continue
        value = line.strip()
        # Blank lines and track/browser/comment headers carry no score and
        # must not advance the position.
        if not value or value.startswith(('track', 'browser', '#')):
            continue
        if currentChr is None:
            raise ValueError('wiggle data found before any fixedStep '
                             'declaration in ' + ConsFile)
        if chromScores is not None and currentPos in chromScores:
            chromScores[currentPos] = float(value)
        currentPos += step


def _max_window_score(bases, chromScores, window):
    """Return the highest mean score over `window` consecutive entries of bases.

    Returns the integer 0 when there are fewer than `window` bases or when no
    window has a positive mean.
    """
    values = [chromScores[base] for base in bases]
    best = 0
    # The +1 makes the final window (ending on the last base) count as well.
    for start in range(len(values) - window + 1):
        score = sum(values[start:start + window], 0.0) / window
        if best < score:
            best = score
    return best


def run():
    """Write the maximum windowed phastCons score of every GTF transcript.

    Command line: GTF windowsize phastConsNway.wig[.bz2|.gz] outputfilename

    For each transcript the exonic positions are concatenated in coordinate
    order and a window of `windowsize` bases is slid along them; the highest
    mean score is written to a tab-separated output file. Exon start and end
    coordinates are both treated as inclusive. Positions without a score in
    the wiggle file count as 0.

    Exits with status 1 if arguments are missing or windowsize is below 1.
    """
    if len(sys.argv) < 5:
        print('usage: python %s GTF windowsize phastConsNway.wig.bz2 outputfilename' % sys.argv[0])
        sys.exit(1)

    gtf = sys.argv[1]
    window = int(sys.argv[2])
    ConsFile = sys.argv[3]
    outfilename = sys.argv[4]

    if window < 1:
        print('windowsize must be a positive integer')
        sys.exit(1)

    TranscriptDict, phastConsDict = _read_exons(gtf)
    _load_scores(ConsFile, phastConsDict)

    maxScoreDict = {}
    for chrom in TranscriptDict:
        chromScores = phastConsDict[chrom]
        print('%s %d' % (chrom, len(chromScores)))
        for transcript, exons in TranscriptDict[chrom].items():
            bases = []
            for left, right in exons:
                # GTF end coordinates are inclusive, hence right + 1.
                bases.extend(range(left, right + 1))
            bases.sort()
            maxScoreDict[transcript] = _max_window_score(bases, chromScores, window)

    with open(outfilename, 'w') as outfile:
        outfile.write('#geneID\tgeneName\ttranscriptName\ttranscriptID\tmaxConsScore\n')
        for transcript in maxScoreDict:
            (geneID, geneName, transcriptName, transcriptID) = transcript
            outline = '\t'.join([geneID, geneName, transcriptName, transcriptID,
                                 str(maxScoreDict[transcript])])
            outfile.write(outline + '\n')
   
# Run only when executed as a script, so that importing this module does not
# parse sys.argv or start processing files.
if __name__ == '__main__':
    run()
