##################################
#                                #
# Last modified 2017/09/15       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import os
import numpy as np
import scipy
import random
from sets import Set
from scipy.stats import norm

def GCContent(GenomeDict,chr,left,right):
    """Return the G+C fraction of the slice GenomeDict[chr][left:right].

    The slice is upper-cased before counting, so soft-masked (lower-case)
    bases are counted too.  Only A, C, G and T contribute to the
    denominator; any other character (e.g. N) is ignored.

    Raises KeyError if chr is not in GenomeDict, and ZeroDivisionError if
    the slice contains no A/C/G/T characters at all.
    """
    region = GenomeDict[chr][left:right].upper()
    strong = 0
    for base in ('G', 'C'):
        strong += region.count(base)
    weak = 0
    for base in ('A', 'T'):
        weak += region.count(base)
    # force float division regardless of interpreter version
    return float(strong)/float(strong + weak)

def _parseFieldIDs(spec):
    """Expand a field specification such as '3,5:7' into [3, 5, 6, 7].

    Items are comma-separated; 'a:b' denotes the inclusive range a..b.
    """
    IDs = []
    for F in spec.split(','):
        if ':' in F:
            F1 = int(F.split(':')[0])
            F2 = int(F.split(':')[1])
            IDs.extend(range(F1,F2+1))
        else:
            IDs.append(int(F))
    return IDs


def _openPossiblyCompressed(path):
    """Return a readable pipe streaming path through bzip2/gunzip/cat.

    The decompressor is chosen from the file name suffix (.bz2, .gz, other).
    NOTE(review): the command is built by plain string concatenation and run
    through the shell, so paths containing spaces or shell metacharacters
    will not work as intended.
    """
    if path.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + path
    elif path.endswith('.gz'):
        cmd = 'gunzip -c ' + path
    else:
        cmd = 'cat ' + path
    return os.popen(cmd, "r")


def _parseRegion(fields,chrFieldID,doSFR):
    """Return (chromosome, left, right) for one tab-split data matrix line.

    With doSFR the region is a single 'chr:left-right' field at chrFieldID;
    otherwise it occupies the three consecutive fields starting at chrFieldID.
    """
    if doSFR:
        (chrom,span) = fields[chrFieldID].split(':')[0:2]
        left = int(span.split('-')[0])
        right = int(span.split('-')[1])
    else:
        chrom = fields[chrFieldID]
        left = int(fields[chrFieldID + 1])
        right = int(fields[chrFieldID + 2])
    return (chrom,left,right)


def _findBin(value,minValue,maxValue,binSize,nBins):
    """Return the index (0..nBins-1) of the half-open bin containing value.

    Bin b covers [minValue + b*binSize, minValue + (b+1)*binSize).  The
    extremes are special-cased so that maxValue lands in the last bin.
    The scan is bounded: if no bin matches (floating point rounding just
    below maxValue, or non-finite bin edges) the last bin is returned
    instead of running past nBins.
    """
    if value == minValue:
        return 0
    if value == maxValue:
        return nBins - 1
    for b in range(nBins):
        if value >= (minValue + b*binSize) and value < (minValue + (b+1)*binSize):
            return b
    return nBins - 1


def run():
    """Compute per-sample motif accessibility deviations and Z-scores.

    Positional command line arguments (sys.argv[1:9]):
        DataMatrix chrFieldID FieldIDs motifsTrack motChrFieldID
        motiflabelFieldsIDs genome.fa outprefix
    Optional flags: -singleFieldRegions, -diffTest config, -GCAccBins N.

    Steps:
      1. Load the genome FASTA into memory.
      2. Read the data matrix once to get, per region, the GC content and
         log10 of the mean accessibility over the selected fields, and bin
         the regions on a B50 x B50 grid (B50 defaults to 50).
      3. Transform the bin-count matrix and derive, for every GC bin and
         every accessibility bin, a weighted list of bins to sample from.
      4. Read the data matrix again, recording every region, its bins and
         the per-base coverage lookup used to assign motifs to regions.
      5. Read the motif track; a motif occurrence is assigned to a region
         when both its left and right coordinates fall inside that region.
      6. For every motif compute, per sample:
             expected  = (motif signal over all samples / total signal) * sample total
             deviation = (observed - expected)/expected
         repeat the same for 50 randomly sampled background region sets, then
             bias-corrected deviation = deviation - mean(background deviations)
             Z-score = bias-corrected deviation / std(background deviations)
      7. Optionally run Welch's t-test on the bias-corrected deviations of
         the sample groups listed in the -diffTest config file.

    Output files: outprefix.GC-AvAcc-matrix,
    outprefix.GC-AvAcc-matrix.transformed,
    outprefix.deviations_and_z-scores and, with -diffTest,
    outprefix.diff_accessibility.A-vs-B for every configured pair.
    """

    # outprefix is read from sys.argv[8], so eight positional arguments are needed
    if len(sys.argv) < 9:
        print('usage: python %s DataMatrix chrFieldID FieldIDs motifsTrack chrFieldID motiflabelFieldsID genome.fa outprefix [-singleFieldRegions] [-diffTest config] [-GCAccBins N]' % sys.argv[0])
        print('\tboth input files can be .gz or .bz2')
        print('\tDifferential accessibility testing config file format: labelA <tab> fieldIDsA(comma-separated) labelB <tab> fieldIDsB(comma-separated)')
        sys.exit(1)

    DataMatrix = sys.argv[1]
    chrFieldID = int(sys.argv[2])
    fieldIDs = _parseFieldIDs(sys.argv[3])
    motifs = sys.argv[4]
    motChrFieldID = int(sys.argv[5])
    motiflabelFieldsIDs = _parseFieldIDs(sys.argv[6])
    fasta = sys.argv[7]
    outprefix = sys.argv[8]

    doSFR = '-singleFieldRegions' in sys.argv

    doDiffTest = '-diffTest' in sys.argv
    DiffDictOutfile = {}
    DiffDict = {}
    if doDiffTest:
        config = sys.argv[sys.argv.index('-diffTest') + 1]
        with open(config) as linelist:
            for line in linelist:
                if line.strip() == '':
                    # tolerate blank lines in the config file
                    continue
                fields = line.strip().split('\t')
                A = fields[0]
                B = fields[2]
                AfieldIDs = [int(ID) for ID in fields[1].split(',')]
                BfieldIDs = [int(ID) for ID in fields[3].split(',')]
                DiffDictOutfile[(A,B)] = open(outprefix + '.diff_accessibility.' + A + '-vs-' + B, 'w')
                outline = '#motif\tMeanDeviationA\tMeanDeviationB\tDeivationDifference\tt-statistic\tp-val'
                DiffDictOutfile[(A,B)].write(outline + '\n')
                DiffDict[(A,B)] = (AfieldIDs,BfieldIDs)

    B50 = 50
    if '-GCAccBins' in sys.argv:
        B50 = int(sys.argv[sys.argv.index('-GCAccBins') + 1])
        print('will use %s bins for GC-Accessibility background sampling' % B50)

    # ---- genome sequence ----
    GenomeDict = {}
    chrom = None
    sequence = []
    p = _openPossiblyCompressed(fasta)
    for line in p:
        if line[0] == '>':
            if chrom is not None:
                GenomeDict[chrom] = ''.join(sequence)
            chrom = line.strip().split('>')[1]
            sequence = []
        else:
            sequence.append(line.strip())
    p.close()
    if chrom is not None:
        GenomeDict[chrom] = ''.join(sequence)

    print('finished parsing genome sequence')

    # ---- first pass over the data matrix: value ranges for binning ----
    GCAcc = []
    minGC = None
    maxGC = None
    minAvAcc = None
    maxAvAcc = None

    p = _openPossiblyCompressed(DataMatrix)
    K = 0
    for line in p:
        if line.startswith('#'):
            continue
        K += 1
        if K % 10000 == 0:
            print('%s lines processed' % K)
        fields = line.strip().split('\t')
        (chrom,left,right) = _parseRegion(fields,chrFieldID,doSFR)
        GC = GCContent(GenomeDict,chrom,left,right)
        Accessibility = [float(fields[ID]) for ID in fieldIDs]
        AvAcc = np.log10(np.mean(Accessibility))
        GCAcc.append([AvAcc,GC])
        if minAvAcc is None:
            minAvAcc = AvAcc
            maxAvAcc = AvAcc
        else:
            minAvAcc = min(minAvAcc,AvAcc)
            maxAvAcc = max(maxAvAcc,AvAcc)
        if minGC is None:
            minGC = GC
            maxGC = GC
        else:
            minGC = min(minGC,GC)
            maxGC = max(maxGC,GC)
    p.close()

    print('finished parsing data')

    if len(GCAcc) == 0:
        sys.exit('no regions found in ' + DataMatrix)

    # counts of regions per (accessibility bin i, GC bin j)
    GCAccMatrixDict = {}
    for i in range(B50):
        GCAccMatrixDict[i] = {}
        for j in range(B50):
            GCAccMatrixDict[i][j] = 0

    GCBinSize = (maxGC - minGC)/B50
    AvAccBinSize = (maxAvAcc - minAvAcc)/B50

    for (AvAcc,GC) in GCAcc:
        j = _findBin(GC,minGC,maxGC,GCBinSize,B50)
        i = _findBin(AvAcc,minAvAcc,maxAvAcc,AvAccBinSize,B50)
        GCAccMatrixDict[i][j] += 1

    # rows of GCAccMatrix are GC bins, columns are accessibility bins
    GCAccMatrix = []
    outfile = open(outprefix + '.GC-AvAcc-matrix', 'w')
    outline = '#'
    for i in range(B50):
        outline = outline + '\t' + str(minAvAcc + i*AvAccBinSize)
    outfile.write(outline + '\n')
    for j in range(B50):
        J = []
        outline = str(minGC + j*GCBinSize)
        for i in range(B50):
            J.append(GCAccMatrixDict[i][j])
            outline = outline + '\t' + str(GCAccMatrixDict[i][j])
        GCAccMatrix.append(J)
        outfile.write(outline + '\n')
    outfile.close()

    # ---- transform the count matrix ----
    GCAccMatrix = np.array(GCAccMatrix)
    Sigma = np.cov(GCAccMatrix)
    (L,G) = np.linalg.eig(Sigma)
    # taking the absolute value to prevent floating point errors on the covariance matrix
    LNegSqrt = np.power(np.abs(L),-0.5)
    # NOTE(review): 1/LNegSqrt equals sqrt(|L|), so the matrix below is built
    # from the square roots of the eigenvalues rather than their inverse
    # square roots, despite its name.  Left unchanged because it alters the
    # background model; confirm the intended transform.
    SigmaNegDsqrt = np.matmul(np.matmul(G,np.diag(1/LNegSqrt)),G.T)

    GCAccMatrixTransformed = []
    for j in range(B50):
        a = GCAccMatrix[j,:]
        z = np.matmul(SigmaNegDsqrt,(a - np.mean(a)))
        GCAccMatrixTransformed.append(z)

    # drop negative entries, then rescale so the total equals the number of regions
    GCAccMatrixTransformed = np.array(GCAccMatrixTransformed)
    GCAccMatrixTransformed = GCAccMatrixTransformed.clip(0)
    GCAccMatrixTransformed = GCAccMatrixTransformed/(np.sum(GCAccMatrixTransformed))
    GCAccMatrixTransformed = GCAccMatrixTransformed*(np.sum(GCAccMatrix))

    outfile = open(outprefix + '.GC-AvAcc-matrix.transformed', 'w')
    outline = '#'
    for i in range(B50):
        outline = outline + '\t' + str(minAvAcc + i*AvAccBinSize)
    outfile.write(outline + '\n')
    for j in range(B50):
        outline = str(minGC + j*GCBinSize)
        for i in range(B50):
            outline = outline + '\t' + str(GCAccMatrixTransformed[j,i])
        outfile.write(outline + '\n')
    outfile.close()

    print('finished transforming GC vs accessibiltiy matrix')

    # For every GC bin i, a list of GC bins j to sample from; each j is
    # repeated in proportion to the row total times the mass of a
    # Normal(0, 0.01) distribution over the offset between the two bins.
    PBiasCorrectedDictI = {}
    for i in range(B50):
        PI = sum(GCAccMatrixTransformed[i,:])
        PBiasCorrectedDictI[i] = []
        for j in range(B50):
            d = (i - j - 0.5)*GCBinSize
            N = norm.cdf(d + GCBinSize, 0, 0.01) - norm.cdf(d, 0, 0.01)
            PBiasCorrectedDictI[i].extend([j]*int(10*N*PI))

    # Same construction for the accessibility bins, using column totals.
    PBiasCorrectedDictJ = {}
    for i in range(B50):
        PI = sum(GCAccMatrixTransformed[:,i])
        PBiasCorrectedDictJ[i] = []
        for j in range(B50):
            d = (i - j - 0.5)*AvAccBinSize
            N = norm.cdf(d + AvAccBinSize, 0, 0.01) - norm.cdf(d, 0, 0.01)
            PBiasCorrectedDictJ[i].extend([j]*int(10*N*PI))

    # from here on GCAccMatrixDict[GCbin][AvAccbin] lists the region numbers in that bin
    GCAccMatrixDict = {}
    for i in range(B50):
        GCAccMatrixDict[i] = {}
        for j in range(B50):
            GCAccMatrixDict[i][j] = []

    print('finished binning transformed matrix')

    # ---- second pass over the data matrix: store regions ----
    HeaderFields = {}
    for ID in fieldIDs:
        HeaderFields[ID] = str(ID)

    PeakDict = {}
    CovDict = {}
    TotalRPMDict = dict((ID,0) for ID in fieldIDs)
    p = _openPossiblyCompressed(DataMatrix)
    K = 0
    for line in p:
        fields = line.strip().split('\t')
        if line.startswith('#'):
            # header line: use its column names as sample labels
            for ID in fieldIDs:
                HeaderFields[ID] = fields[ID]
            continue
        K += 1
        if K % 10000 == 0:
            print('%s lines processed' % K)
        (chrom,left,right) = _parseRegion(fields,chrFieldID,doSFR)
        GC = GCContent(GenomeDict,chrom,left,right)
        Accessibility = []
        for ID in fieldIDs:
            Accessibility.append(float(fields[ID]))
            TotalRPMDict[ID] += float(fields[ID])
        AvAcc = np.log10(np.mean(Accessibility))
        GCbin = _findBin(GC,minGC,maxGC,GCBinSize,B50)
        AvAccbin = _findBin(AvAcc,minAvAcc,maxAvAcc,AvAccBinSize,B50)
        GCAccMatrixDict[GCbin][AvAccbin].append(K)
        if chrom not in CovDict:
            CovDict[chrom] = {}
        # per-base lookup: position -> region number (right end exclusive)
        for pos in range(left,right):
            CovDict[chrom][pos] = K
        PeakDict[K] = (chrom,left,right,Accessibility,GCbin,AvAccbin)
    p.close()

    print('finished second parsing of data')

    print('TotalRPMs per sample:')
    TotalRPM = 0
    for ID in TotalRPMDict:
        print('%s %s' % (ID, TotalRPMDict[ID]))
        TotalRPM += TotalRPMDict[ID]

    # ---- motif track ----
    p = _openPossiblyCompressed(motifs)
    M = 0
    MotifOverlapDict = {}
    for line in p:
        M += 1
        if M % 500000 == 0:
            print(str(M/1000000.) + 'M lines processed')
        fields = line.strip().split('\t')
        if line.startswith('#'):
            continue
        chrom = fields[motChrFieldID]
        left = int(fields[motChrFieldID+1])
        right = int(fields[motChrFieldID+2])
        motif = tuple([fields[ID] for ID in motiflabelFieldsIDs])
        # NOTE(review): the label is derived from the tuple's printed form, so
        # parentheses and quotes inside labels are stripped and a single label
        # field keeps a trailing comma; kept as is so labels stay unchanged.
        motif = str(motif).replace('(','').replace(')','').replace(', ','|').replace("'",'')
        if motif not in MotifOverlapDict:
            MotifOverlapDict[motif] = {}
        if chrom in CovDict:
            if left in CovDict[chrom] and right in CovDict[chrom]:
                # both ends must fall inside the same region
                if CovDict[chrom][left] == CovDict[chrom][right]:
                    K = CovDict[chrom][left]
                    MotifOverlapDict[motif][K] = 1
    p.close()

    print('finished parsing motifs')

    # ---- deviations and Z-scores ----
    outfile = open(outprefix + '.deviations_and_z-scores', 'w')
    outline = '#motif'
    for ID in fieldIDs:
        outline = outline + '\t' + HeaderFields[ID] + ' raw_deviation' + '\t' + HeaderFields[ID] + ' background_mean_deviation' + '\t' + HeaderFields[ID] + ' background_mean_deviation_std'  + '\t' + HeaderFields[ID] + ' bias-corrected deviation' + '\t' + HeaderFields[ID] + ' Z-score'
    outfile.write(outline + '\n')

    for motif in sorted(MotifOverlapDict):
        print(motif)
        motifPeaks = list(MotifOverlapDict[motif].keys())
        if len(motifPeaks) == 0:
            # no signal at all: the expected value would be zero and every
            # deviation would be a division by zero
            print('warning: no regions contain motif ' + motif + ', skipping')
            continue

        ObservedMotDict = {}
        TotalObservedMot = 0
        for ID in fieldIDs:
            ObservedMotDict[ID] = 0
        for K in motifPeaks:
            Accessibility = PeakDict[K][3]
            for (i,ID) in enumerate(fieldIDs):
                ObservedMotDict[ID] += Accessibility[i]
                TotalObservedMot += Accessibility[i]

        BackgroundDeviations = []
        for S in range(50):
            Devs = []
            STotalObservedMot = 0
            SObservedMotDict = {}
            for ID in fieldIDs:
                SObservedMotDict[ID] = 0
            for K in motifPeaks:
                GCbin = PeakDict[K][4]
                AvAccbin = PeakDict[K][5]
                # bins whose sampling list is empty fall back to a neighbour
                if len(PBiasCorrectedDictI[GCbin]) == 0:
                    print('warning: empty bin found, switching to next one GCbin %s' % GCbin)
                    GCbin = GCbin + 1
                    if GCbin == B50:
                        GCbin = GCbin - 2
                    if len(PBiasCorrectedDictI[GCbin]) == 0:
                        GCbin = GCbin - 2
                if len(PBiasCorrectedDictJ[AvAccbin]) == 0:
                    print('warning: empty bin found, switching to next one AvAccbin %s' % AvAccbin)
                    AvAccbin = AvAccbin + 1
                    if AvAccbin == B50:
                        AvAccbin = AvAccbin - 2
                    if len(PBiasCorrectedDictJ[AvAccbin]) == 0:
                        AvAccbin = AvAccbin - 2
                sampleGCbin = random.choice(PBiasCorrectedDictI[GCbin])
                sampleAvAccbin = random.choice(PBiasCorrectedDictJ[AvAccbin])
                if len(GCAccMatrixDict[sampleGCbin][sampleAvAccbin]) == 0:
                    print('emtpy GCAccMatrix bin found: (' + str(sampleGCbin) + ',' + str(sampleAvAccbin))
                    print('sampling from nearest non-empty bin, which is:')
                    (nI,nJ) = (sampleGCbin,sampleAvAccbin)
                    nIJ = 0
                    nIJdist = 2*B50
                    for I in range(B50):
                        for J in range(B50):
                            if len(GCAccMatrixDict[I][J]) > nIJ:
                                # Manhattan distance to the current candidate bin
                                dist = math.fabs(I - nI) + math.fabs(J - nJ)
                                if dist <= nIJdist:
                                    (nI,nJ) = (I,J)
                                    nIJdist = dist
                                    nIJ = len(GCAccMatrixDict[I][J])
                    sampleGCbin = nI
                    sampleAvAccbin = nJ
                    print('(' + str(sampleGCbin) + ',' + str(sampleAvAccbin) + ')')
                newK = random.choice(GCAccMatrixDict[sampleGCbin][sampleAvAccbin])
                Accessibility = PeakDict[newK][3]
                for (i,ID) in enumerate(fieldIDs):
                    SObservedMotDict[ID] += Accessibility[i]
                    STotalObservedMot += Accessibility[i]
            for ID in fieldIDs:
                SExpectedMot = (STotalObservedMot/TotalRPM)*TotalRPMDict[ID]
                SRawDeviationMot = (SObservedMotDict[ID] - SExpectedMot)/SExpectedMot
                Devs.append(SRawDeviationMot)
            BackgroundDeviations.append(Devs)
        BackgroundDeviations = np.array(BackgroundDeviations)

        outline = motif
        BiasCorrectedDeviations = {}
        for (i,ID) in enumerate(fieldIDs):
            ExpectedMot = (TotalObservedMot/TotalRPM)*TotalRPMDict[ID]
            RawDeviationMot = (ObservedMotDict[ID] - ExpectedMot)/ExpectedMot
            BackgroundMean = np.mean(BackgroundDeviations[:,i])
            BackgroundStd = np.std(BackgroundDeviations[:,i])
            BiasCorrectedDeviationMot = RawDeviationMot - BackgroundMean
            BiasCorrectedDeviations[ID] = BiasCorrectedDeviationMot
            # the background mean is already removed above; do not subtract it twice
            Z = BiasCorrectedDeviationMot/BackgroundStd
            outline = outline + '\t' + str(RawDeviationMot) + '\t' + str(BackgroundMean) + '\t' + str(BackgroundStd)
            outline = outline + '\t' + str(BiasCorrectedDeviationMot) + '\t' + str(Z)
        outfile.write(outline + '\n')

        if doDiffTest:
            for (A,B) in DiffDictOutfile:
                (AfieldIDs,BfieldIDs) = DiffDict[(A,B)]
                Avalues = np.array([BiasCorrectedDeviations[ID] for ID in AfieldIDs])
                Bvalues = np.array([BiasCorrectedDeviations[ID] for ID in BfieldIDs])
                # Welch's t-test (unequal variances)
                (t,pval) = scipy.stats.ttest_ind(Avalues,Bvalues,equal_var=False)
                outline = motif + '\t' + str(np.mean(Avalues)) + '\t' + str(np.mean(Bvalues)) + '\t' + str(np.mean(Bvalues) - np.mean(Avalues)) + '\t' + str(t) + '\t' + str(pval)
                DiffDictOutfile[(A,B)].write(outline + '\n')

    outfile.close()

    if doDiffTest:
        for (A,B) in DiffDictOutfile:
            DiffDictOutfile[(A,B)].close()

# Script entry point: run the whole analysis using the arguments in sys.argv.
run()
