##################################
#                                #
# Last modified 2018/10/08       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import numpy as np
import math
from sets import Set

def run():

    if len(sys.argv) < 11:
        print 'usage: python %s SingleMoleculeCorrelation_output TSS.bed chrFieldID leftFiled RightFieldID strandFieldID labeFieldID expression_matrix labelFieldID scoreFieldIDs outfile' % sys.argv[0]
        print '\tNote: the script assumes that the input is the output from SingleMoleculeCorrelation.py, i.e. regions are sorted by coordinates!!!'
        print '\tscore fields can be any combination of commas and from:to'
        sys.exit(1)

    SMC = sys.argv[1]
    peaks = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    leftFieldID = int(sys.argv[4])
    rightFieldID = int(sys.argv[5])
    strandFieldID = int(sys.argv[6])
    labelFieldIDTSS = int(sys.argv[7])
    matrix = sys.argv[8]
    labelFieldIDMatrix = int(sys.argv[9])
    scoreFields = []
    for block in sys.argv[10].split(','):
        [ID1,ID2] = block.split(':')
        for ID in range(int(ID1),int(ID2)):
            scoreFields.append(int(ID))
    outfilename = sys.argv[11]

    TSSDict = {}
    if peaks.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + peaks
    elif peaks.endswith('.gz') or peaks.endswith('.bgz'):
        cmd = 'zcat ' + peaks
    elif peaks.endswith('.zip'):
        cmd = 'unzip -p ' + peaks
    else:
        cmd = 'cat ' + peaks
    RN = 0
    P = os.popen(cmd, "r")
    line = 'line'
    while line != '':
        line = P.readline().strip()
        if line == '':
            break
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        chr = fields[chrFieldID]
        left = int(fields[leftFieldID])
        right = int(fields[rightFieldID])
        strand = fields[strandFieldID]
        label = fields[labelFieldIDTSS]
        if TSSDict.has_key(chr):
            pass
        else:
            TSSDict[chr] = {}
        TSSDict[chr][(left,right)] = (strand,label)

    print 'finished inputting peaks'

    ExprDict = {}
    if matrix.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + matrix
    elif matrix.endswith('.gz') or matrix.endswith('.bgz'):
        cmd = 'zcat ' + matrix
    elif matrix.endswith('.zip'):
        cmd = 'unzip -p ' + matrix
    else:
        cmd = 'cat ' + matrix
    RN = 0
    P = os.popen(cmd, "r")
    line = 'line'
    while line != '':
        line = P.readline().strip()
        if line == '':
            break
        RN += 1
        if line.startswith('#'):
            continue
        if line.startswith('YORF\t'):
            continue
        if line.startswith('EWEIGHT\t'):
            continue
        fields = line.strip().split('\t')
        label = fields[labelFieldIDMatrix]
        scores = []
        for ID in scoreFields:
            scores.append(float(fields[ID]))
        M = np.mean(scores)
        scores = (scores - M)/M
        ExprDict[label] = scores

    print 'finished inputting peaks'

    outfile = open(outfilename,'w')

    if SMC.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + SMC
    elif SMC.endswith('.gz') or SMC.endswith('.bgz'):
        cmd = 'zcat ' + SMC
    elif SMC.endswith('.zip'):
        cmd = 'unzip -p ' + SMC
    else:
        cmd = 'cat ' + SMC
    RN = 0
    P = os.popen(cmd, "r")
    line = 'line'
    while line != '':
        line = P.readline().strip()
        if line == '':
            break
        if line.startswith('#'):
            outline = line.strip() + '\t' + 'expression_correlation'
            outfile.write(outline + '\n')
            continue
        fields = line.strip().split('\t')
        chr = fields[0]
        L1 = int(fields[1])
        R1 = int(fields[2])
        L2 = int(fields[6])
        R2 = int(fields[7])
        (strand1,label1) = TSSDict[chr][(L1,R1)]
        (strand2,label2) = TSSDict[chr][(L2,R2)]
        if ExprDict.has_key(label1) and ExprDict.has_key(label2):
            CORR = float(np.corrcoef(ExprDict[label1],ExprDict[label2])[0,1])
        else:
            CORR = 'nan'
        outline = line.strip() + '\t' + str(CORR)
        outfile.write(outline + '\n')

    outfile.close()
            
# Only execute when invoked as a script, so the module can be imported
# (e.g. by tests) without immediately parsing sys.argv and exiting.
if __name__ == '__main__':
    run()

