##################################
#                                #
# Last modified 2018/02/09       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import os
from sets import Set

def _read_labels(filename):
    """Return {label: wiggle filename} read from the tab-separated inputs file.

    Each line is ``label <tab> wiggle filename <tab> sense/antisense``; only
    the first two columns are used. Blank or short lines are ignored.
    """
    labelDict = {}
    with open(filename) as infile:
        for line in infile:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                continue
            labelDict[fields[0]] = fields[1]
    return labelDict


def _read_genes(filename, chrFieldID, leftFieldID, rightFieldID, strandFieldID,
                radius, ML, geneStrand=None):
    """Return {(chr, left, right, strand): positions} for every usable gene.

    ``positions`` holds 4*radius + ML genomic coordinates:
      * the TSS window [left - radius, left + radius),
      * ML coordinates evenly spaced across the gene body
        [left + radius, right - radius),
      * the TTS window [right - radius, right + radius).
    The list is reversed for genes on the '-' / 'R' strand so that index 0
    is always upstream of the gene. Lines starting with '#' or with fewer
    than 3 tab-separated fields are skipped, as are genes whose body is
    shorter than ML bp and (when ``geneStrand`` is given) genes on another
    strand.
    """
    GeneDict = {}
    with open(filename) as infile:
        for line in infile:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 3:
                continue
            chrom = fields[chrFieldID]
            left = int(fields[leftFieldID])
            right = int(fields[rightFieldID])
            strand = fields[strandFieldID]
            if geneStrand is not None and strand != geneStrand:
                continue
            # Gene body between the end of the TSS window and the start of
            # the TTS window; this is the part rescaled to ML bins.
            internal = (right - radius) - (left + radius)
            if internal < ML:
                continue
            # Float step: integer division would floor the step and leave the
            # 3' end of the gene body unsampled.
            step = float(internal) / ML if ML > 0 else 0.0
            positions = list(range(left - radius, left + radius))
            positions.extend(left + radius + int(k * step) for k in range(ML))
            positions.extend(range(right - radius, right + radius))
            if strand in ('-', 'R'):
                positions.reverse()
            GeneDict[(chrom, left, right, strand)] = positions
    return GeneDict


def _iter_wiggle_lines(path):
    """Yield the text lines of ``path``, decompressing .bz2/.gz/.zip on the fly.

    Compressed files are piped through the ``bzip2``/``gunzip``/``unzip``
    command-line tools, called with an argument list (no shell), so paths
    containing spaces or shell metacharacters are handled safely.
    """
    import subprocess

    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        cmd = ['gunzip', '-c', path]
    elif path.endswith('.zip'):
        cmd = ['unzip', '-p', path]
    else:
        cmd = None

    if cmd is None:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        proc.wait()


def _load_scores(path, label, wanted):
    """Read one wiggle/bedGraph file and return {chr: {pos: score}}.

    Only positions listed in ``wanted`` ({chr: set of positions}) are kept;
    positions never covered by the file are absent (treated as score 0).
    Data lines are ``chr start stop score`` separated by tabs or spaces; a
    line with start == stop covers the single base ``start``. When intervals
    overlap, the later line wins. Lines that do not parse as such (e.g.
    track/step header lines) are skipped.
    """
    scores = dict((chrom, {}) for chrom in wanted)
    seen = 0
    for raw in _iter_wiggle_lines(path):
        line = raw.strip()
        if not line:
            # A blank line is skipped, not treated as end of file.
            continue
        seen += 1
        if seen % 1000000 == 0:
            print('%s %d lines processed' % (label, seen))
        fields = line.split()
        chrom_wanted = wanted.get(fields[0])
        if chrom_wanted is None:
            continue
        try:
            start = int(fields[1])
            stop = int(fields[2])
            score = float(fields[3])
        except (IndexError, ValueError):
            continue
        if start == stop:
            stop += 1
        chrom_scores = scores[fields[0]]
        for pos in range(start, stop):
            if pos in chrom_wanted:
                chrom_scores[pos] = score
    return scores


def _accumulate_profile(GeneDict, scores, span):
    """Sum, index by index, the scores at every gene's sampled positions."""
    profile = [0.0] * span
    for (chrom, _left, _right, _strand), positions in GeneDict.items():
        chrom_scores = scores.get(chrom, {})
        for k, pos in enumerate(positions):
            profile[k] += chrom_scores.get(pos, 0)
    return profile


def _smooth(values, SR):
    """Return the moving average of ``values`` over the inclusive window
    [k - SR, k + SR], truncated at both ends of the profile."""
    SR = max(SR, 0)
    last = len(values) - 1
    smoothed = []
    for k in range(len(values)):
        window = values[max(k - SR, 0):min(k + SR, last) + 1]
        smoothed.append(sum(window) / len(window))
    return smoothed


def run():
    """Build metagene profiles of wiggle/bedGraph signal over a set of genes.

    Command line (see the usage message):
      inputs genesfilename chrFieldID leftFieldID rightFieldID strandFieldID
      radius outputfilename [-geneStrand + | -] [-normalize]
      [-smooth radius] [-middleLen bp]

    For every gene, a window of ``radius`` bp on each side of the TSS and TTS
    is sampled base by base, and the gene body between them is rescaled to
    ``-middleLen`` bins (default 2*radius). The output has a '#' header row
    with the sorted labels, then one row per profile coordinate
    i in [-radius, 3*radius + middleLen) with, per label, the summed signal
    over genes (divided by the gene count with -normalize, and moving-averaged
    over +/- the -smooth radius when requested).
    """
    if len(sys.argv) < 9:
        print('usage: python %s inputs genesfilename chrFieldID leftFieldID rightFieldID strandFieldID radius outputfilename [-geneStrand + | -] [-normalize] [-smooth radius] [-middleLen bp]' % sys.argv[0])
        print('list of inputs format: label <tab> wiggle filename <tab> sense/antisense')
        print('the radius is around the TSS and the TTS, the rest of the gene will be rescaled to twice that')
        sys.exit(1)

    inputs = sys.argv[1]
    genes = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    leftFieldID = int(sys.argv[4])
    rightFieldID = int(sys.argv[5])
    strandFieldID = int(sys.argv[6])
    radius = int(sys.argv[7])
    outfilename = sys.argv[8]

    doNormalize = '-normalize' in sys.argv

    SR = None
    if '-smooth' in sys.argv:
        SR = int(sys.argv[sys.argv.index('-smooth') + 1])
        print('will smooth values over a radius of %d' % SR)

    ML = 2 * radius
    if '-middleLen' in sys.argv:
        ML = int(sys.argv[sys.argv.index('-middleLen') + 1])

    geneStrand = None
    if '-geneStrand' in sys.argv:
        geneStrand = sys.argv[sys.argv.index('-geneStrand') + 1]

    labelDict = _read_labels(inputs)
    GeneDict = _read_genes(genes, chrFieldID, leftFieldID, rightFieldID,
                           strandFieldID, radius, ML, geneStrand)

    print('Imported %d genes' % len(GeneDict))
    print('Importing wiggle scores')

    # Per chromosome, the set of positions any gene needs a score for.
    wanted = {}
    for (chrom, _left, _right, _strand), positions in GeneDict.items():
        wanted.setdefault(chrom, set()).update(positions)

    span = 4 * radius + ML
    labels = sorted(labelDict)
    # Dividing by 1 when no gene passed the filters keeps the all-zero
    # profile instead of raising ZeroDivisionError.
    divisor = len(GeneDict) or 1

    # Opened before the (slow) wiggle parsing so a bad output path fails fast.
    with open(outfilename, 'w') as outfile:
        columns = {}
        for label in labels:
            scores = _load_scores(labelDict[label], label, wanted)
            profile = _accumulate_profile(GeneDict, scores, span)
            if SR is not None:
                profile = _smooth(profile, SR)
            if doNormalize:
                profile = [value / divisor for value in profile]
            columns[label] = profile

        outfile.write('#' + ''.join('\t' + label for label in labels) + '\n')
        for k in range(span):
            row = [str(k - radius)] + [str(columns[label][k]) for label in labels]
            outfile.write('\t'.join(row) + '\n')
   
# Script entry point: run only when executed directly, not when imported
# (run() reads sys.argv and exits with a usage message on missing arguments).
if __name__ == '__main__':
    run()
