##################################
#                                #
# Last modified 03/10/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import numpy
import operator

def _usage():
    """Print the command-line help text to stdout."""
    print('usage: python %s GCskewoutput outprefix [-NaN symbol] [-minLen bp] [-normalizeLength] [-normalizeSignal]' % sys.argv[0])
    print('Assumed input format:')
    print('\t#chr\twindow\tGCskew\tATskew\tGCskew_cum\tATskew_cum')
    print('\tthe -normalizeLength option will rescale everything to the size of the shortest chromosome')


def _option_value(flag, default):
    """Return the command-line token following `flag`, or `default` if `flag` is absent.

    Exits with an error message (status 1) when `flag` is the last token,
    instead of failing with a bare IndexError.
    """
    if flag not in sys.argv:
        return default
    index = sys.argv.index(flag) + 1
    if index >= len(sys.argv):
        sys.exit('error: option %s requires a value' % flag)
    return sys.argv[index]


def _read_cumulative_skews(path):
    """Parse the skew table into {chrom: {'GC': {pos: value}, 'AT': {pos: value}}}.

    Lines starting with '#' and blank lines are skipped. Column 1 is the
    chromosome, column 2 the (integer) window position, and columns 5 and 6
    (GCskew_cum, ATskew_cum) are the values kept.
    """
    chromosomes = {}
    with open(path) as infile:
        for line in infile:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            tracks = chromosomes.setdefault(fields[0], {'GC': {}, 'AT': {}})
            pos = int(fields[1])
            tracks['GC'][pos] = float(fields[4])
            tracks['AT'][pos] = float(fields[5])
    return chromosomes


def _rescale_windows(values, targetPositions):
    """Average one chromosome's windows into len(targetPositions) equal-count bins.

    Window k (0-based, in position order) of n windows goes into bin k * m // n,
    where m is the number of target bins; bin j is reported at targetPositions[j].
    For windows laid out on a regular grid this is the same proportional
    rescaling the script always intended. Exact integer arithmetic means no
    windows are lost to accumulated rounding, and a missing window or a zero
    first position cannot crash it. A bin that receives no window (the
    chromosome has fewer windows than there are targets) is omitted, so it is
    written out as the NaN symbol.
    """
    nBins = len(targetPositions)
    windows = sorted(values)
    nWindows = len(windows)
    bins = [[] for _ in range(nBins)]
    for k, pos in enumerate(windows):
        bins[k * nBins // nWindows].append(values[pos])
    return {target: numpy.mean(scores)
            for target, scores in zip(targetPositions, bins) if scores}


def _max_abs(values, positions):
    """Return the largest |value| over `positions` present in `values`, or 0.0 if there are none.

    NaN values are ignored, because `nan > x` is always False.
    """
    largest = 0.0
    for pos in positions:
        if pos in values and math.fabs(values[pos]) > largest:
            largest = math.fabs(values[pos])
    return largest


def _format_row(name, values, positions, scale, nanSymbol):
    """Build one tab-separated output row; positions missing from `values` become `nanSymbol`.

    Values are divided by `scale` when it is non-zero (signal normalization).
    A `scale` of None (no normalization) or 0.0 (an all-zero chromosome)
    leaves values unscaled rather than dividing by zero.
    """
    cells = [name]
    for pos in positions:
        if pos not in values:
            cells.append(nanSymbol)
        elif scale:
            cells.append(str(values[pos] / scale))
        else:
            cells.append(str(values[pos]))
    return '\t'.join(cells)


def run():
    """Command-line entry point: tabulate cumulative GC/AT skew per chromosome.

    Reads the table named by sys.argv[1] and writes two matrices,
    <outprefix>.GC.txt and <outprefix>.AT.txt (outprefix = sys.argv[2]). Each
    has one row per chromosome, ordered from shortest to longest, and one
    column per window position. The columns are the window positions of the
    longest chromosome, or of the shortest one with -normalizeLength.

    Options:
        -NaN symbol        text written for missing positions (default 'nan')
        -minLen bp         drop chromosomes whose last window position is < bp
        -normalizeLength   average every chromosome onto the shortest one's windows
        -normalizeSignal   divide each row by its own maximum absolute value

    Exits with status 1 on missing arguments, on an option given without its
    value, or when no chromosome remains to report.
    """
    if len(sys.argv) < 3:
        _usage()
        sys.exit(1)

    inputPath = sys.argv[1]
    outprefix = sys.argv[2]
    NaN = _option_value('-NaN', 'nan')
    doNormalizeLength = '-normalizeLength' in sys.argv
    doNormalizeSignal = '-normalizeSignal' in sys.argv
    minLen = _option_value('-minLen', None)
    if minLen is not None:
        minLen = int(minLen)

    ChromosomeDict = _read_cumulative_skews(inputPath)

    # A chromosome's length is taken to be its largest window position.
    chrlengths = []
    for chrom, tracks in ChromosomeDict.items():
        length = max(tracks['GC'])
        if minLen is None or length >= minLen:
            chrlengths.append((chrom, length))
    if not chrlengths:
        sys.exit('error: no chromosomes to report (empty input, or all shorter than -minLen)')
    chrlengths.sort(key=operator.itemgetter(1))

    # The column layout comes from the shortest chromosome when rescaling,
    # otherwise from the longest so that no position is cut off.
    referenceChr = chrlengths[0][0] if doNormalizeLength else chrlengths[-1][0]
    positions = sorted(ChromosomeDict[referenceChr]['AT'])
    header = '\t'.join(['chr'] + [str(pos) for pos in positions])

    gcLines = [header]
    atLines = [header]
    for chrom, length in chrlengths:
        gc = ChromosomeDict[chrom]['GC']
        at = ChromosomeDict[chrom]['AT']
        if doNormalizeLength:
            gc = _rescale_windows(gc, positions)
            at = _rescale_windows(at, positions)
        gcScale = _max_abs(gc, positions) if doNormalizeSignal else None
        atScale = _max_abs(at, positions) if doNormalizeSignal else None
        gcLines.append(_format_row(chrom, gc, positions, gcScale, NaN))
        atLines.append(_format_row(chrom, at, positions, atScale, NaN))

    with open(outprefix + '.GC.txt', 'w') as outfileGC:
        outfileGC.write('\n'.join(gcLines) + '\n')
    with open(outprefix + '.AT.txt', 'w') as outfileAT:
        outfileAT.write('\n'.join(atLines) + '\n')
   
# Run only when executed as a script, so importing this module does not
# parse sys.argv, write files, or call sys.exit().
if __name__ == '__main__':
    run()
