##################################
#                                #
# Last modified 05/10/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
from sets import Set
import scipy.stats

def _parse_bed_line(line):
    """Parse one tab-delimited ``chrom  start  end  score`` record.

    Returns ``(chrom, start, end, score)`` with integer coordinates and a
    float score, or ``None`` for lines that carry no data: blank lines,
    ``#`` comments, and UCSC ``track``/``browser`` header lines.
    """
    stripped = line.strip()
    if not stripped or stripped.startswith('#'):
        return None
    if stripped.split(None, 1)[0] in ('track', 'browser'):
        return None
    fields = stripped.split('\t')
    return fields[0], int(fields[1]), int(fields[2]), float(fields[3])


def _read_mappability(path, max_score):
    """Load the mappability track.

    Returns ``(mappable, max_range)``:
      * ``mappable`` maps every chromosome seen in the file to the set of
        positions whose score is >= ``max_score``. When records overlap, the
        later record wins. Bases absent from the file count as unmappable.
      * ``max_range`` is the largest position covered by any record (0 if
        there are none). Downstream code scores ``range(max_range)``, so the
        final covered position itself is excluded. This matches the original
        script's behaviour for position-sorted input.
    """
    mappable = {}
    max_range = 0
    with open(path) as handle:
        for line in handle:
            record = _parse_bed_line(line)
            if record is None:
                continue
            chrom, start, end, score = record
            positions = mappable.setdefault(chrom, set())
            span = range(start, end)
            if score >= max_score:
                positions.update(span)
            else:
                positions.difference_update(span)
            if end > start:
                max_range = max(max_range, end - 1)
    return mappable, max_range


def _read_track(path):
    """Return the data records of a signal track as ``(chrom, start, end, score)`` tuples."""
    records = []
    with open(path) as handle:
        for line in handle:
            record = _parse_bed_line(line)
            if record is not None:
                records.append(record)
    return records


def _mappable_scores(records, mappable, max_range):
    """Return one score per perfectly mappable base covered by ``records``.

    Only bases in ``[0, max_range)`` on chromosomes present in ``mappable``
    are considered. An interval's score is repeated once for each of its
    mappable bases.
    """
    scores = []
    for chrom, start, end, score in records:
        positions = mappable.get(chrom)
        if positions is None:
            continue  # chromosome absent from the mappability track
        for pos in range(max(start, 0), min(end, max_range)):
            if pos in positions:
                scores.append(score)
    return scores


def _gamma_zscores(records, mappable, max_range, mean, std):
    """Z-score every covered base against a fitted Gamma ``mean``/``std``.

    Returns ``(zscores, highest, lowest)``:
      * ``zscores`` is a list with one entry per position in
        ``range(max_range)``. The entry is 0 for bases the track does not
        cover, for bases that are not perfectly mappable, and for every base
        when ``std`` is 0.
      * ``highest``/``lowest`` are the ``(zscore, chrom, interval_start)``
        tuples that builtin ``max()``/``min()`` would pick over every scored
        base, or ``None`` if no base was scored.
    """
    zscores = [0] * max_range
    highest = lowest = None
    for chrom, start, end, score in records:
        positions = mappable.get(chrom)
        if positions is None:
            continue  # chromosome absent from the mappability track
        for pos in range(max(start, 0), min(end, max_range)):
            if std == 0 or pos not in positions:
                z = 0
            else:
                z = (score - mean) / std
            entry = (z, chrom, start)
            # Strict comparisons keep the first extreme seen, exactly like
            # max()/min() over the full list, without storing every base.
            if highest is None or entry > highest:
                highest = entry
            if lowest is None or entry < lowest:
                lowest = entry
            zscores[pos] = z
    return zscores, highest, lowest


def _window_means(zscores, average):
    """Average ``zscores`` over consecutive windows of ``average`` bases.

    Each mean is formatted to two decimals. A trailing partial window is
    still divided by ``average``, as in the original output format.
    """
    return ['{0:.2f}'.format(sum(zscores[i:i + average]) / float(average))
            for i in range(0, len(zscores), average)]


def run():
    """Command-line entry point: Gamma-normalised Z-scores for signal tracks.

    Arguments (from ``sys.argv``):
      list-of-wiggle-tracks normalization_percentile mappability_track
      mappability_track_max_score outfilename_prefix [-average bp]

    For each ``label<TAB>track`` line of the track list, a Gamma distribution
    is fitted to the lowest ``normalization_percentile`` fraction of scores.
    Only perfectly mappable bases are used, i.e. those with mappability score
    >= ``mappability_track_max_score``. Every base in ``[0, max_range)`` is
    then Z-scored against that distribution's mean and standard deviation.
    A base gets 0 if it is not perfectly mappable, or if the std is 0.

    Writes:
      * ``<prefix>.ZScores.table``: a ``#Sample`` header with window start
        positions, then one row per track. Each row holds the Z-score
        averaged over windows of ``-average`` bases (default 1), formatted
        to two decimals.
      * ``<prefix>.MaxZScores``: one line per label with the max and min
        ``(zscore, chrom, interval_start)`` triples, tab-separated.

    Tracks with no data lines, or with no bases below the percentile cutoff,
    are reported on stdout and skipped. Exits with status 1 and a usage
    message if fewer than five positional arguments are given.
    """
    # Five positional arguments are required (argv[1]..argv[5]).
    if len(sys.argv) < 6:
        print('usage: python %s list-of-wiggle-tracks normalization_percentile mappability_track mappability_track_max_score outfilename_prefix [-average bp]' % sys.argv[0])
        print('\tlist-of-wiggle-tracks format: label\ttrack_name')
        print('\tnormalization_percentile: the percentile of scores (only perfectly mappable bases considered) for which the mean is to be calculated; i.e. if 0.90 is given, the Gamma distribution parameters for the first 90% of perfectly mappable bases (ranked by wiggle signal) will be taken, and the Z-score for all 100% of bases will be calculated using those')
        print('\tmappability_track_max_score: the read length for which the mappbility track was generated, i.e. 36 for 36bp reads')
        sys.exit(1)

    inputfilename = sys.argv[1]
    percentile = float(sys.argv[2])
    mappability = sys.argv[3]
    mappabilityMaxScore = int(sys.argv[4])
    outputfileprefix = sys.argv[5]

    average = 1
    if '-average' in sys.argv:
        average = int(sys.argv[sys.argv.index('-average') + 1])
        if average < 1:
            print('-average must be a positive integer')
            sys.exit(1)
        print('will average signal over windows of %d bp' % average)

    mappable, max_range = _read_mappability(mappability, mappabilityMaxScore)

    header = '\t'.join(['#Sample'] + [str(pos) for pos in range(0, max_range, average)])

    max_z = {}
    min_z = {}
    with open(outputfileprefix + '.ZScores.table', 'w') as table, \
            open(inputfilename) as track_list:
        table.write(header + '\n')
        for entry in track_list:
            if not entry.strip():
                continue  # tolerate blank lines in the track list
            fields = entry.strip().split('\t')
            label = fields[0]
            track_path = fields[1]

            records = _read_track(track_path)
            scores = sorted(_mappable_scores(records, mappable, max_range))
            n_mappable = len(scores)
            n_fit = int(percentile * n_mappable)
            print('%s %s %s %s' % (max_range, n_fit, n_mappable, label))
            if not records:
                print('no lines in track, skipping')
                continue
            if n_fit == 0:
                # gamma.fit cannot fit an empty sample.
                print('no perfectly mappable bases below the percentile cutoff, skipping')
                continue

            # Fit only the lowest `percentile` of mappable scores, presumably
            # so high-signal bases do not distort the background model.
            shape, loc, scale = scipy.stats.gamma.fit(scores[:n_fit])
            background = scipy.stats.gamma(shape, loc=loc, scale=scale)
            zscores, highest, lowest = _gamma_zscores(
                records, mappable, max_range, background.mean(), background.std())

            table.write('\t'.join([label] + _window_means(zscores, average)) + '\n')
            if highest is not None:
                max_z[label] = highest
                min_z[label] = lowest

    with open(outputfileprefix + '.MaxZScores', 'w') as summary:
        for label in sorted(max_z):
            triple = max_z[label] + min_z[label]
            summary.write('\t'.join([label] + [str(value) for value in triple]) + '\n')

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

