##################################
#                                #
# Last modified 2018/12/17       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import math
import os
import string
import subprocess
import sys
from sets import Set

def _iter_lines(filename):
    """Yield the raw lines of `filename`, decompressing it by extension.

    ``.bz2`` is read via ``bzip2 -cd``, ``.gz``/``.bgz`` via ``zcat`` and
    ``.zip`` via ``unzip -p``; any other file is opened directly.  The external
    tools are started with an argument list (no shell), so paths containing
    spaces or shell metacharacters are passed through safely.
    """
    if filename.endswith('.bz2'):
        cmd = ['bzip2', '-cd', filename]
    elif filename.endswith(('.gz', '.bgz')):
        cmd = ['zcat', filename]
    elif filename.endswith('.zip'):
        cmd = ['unzip', '-p', filename]
    else:
        with open(filename) as handle:
            for line in handle:
                yield line
        return
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        # Close the pipe and reap the child so no fd / zombie process leaks.
        proc.stdout.close()
        proc.wait()


def _nonblank_lines(filename):
    """Yield every whitespace-stripped, non-empty line of `filename`.

    Blank lines are skipped rather than treated as end-of-file, so records
    that follow an empty line are no longer silently dropped.
    """
    for raw in _iter_lines(filename):
        line = raw.strip()
        if line:
            yield line


def _load_expression(filename, label_field, tpm_field):
    """Return {label: TPM} from a tab-separated expression file.

    Lines starting with '#' are ignored.  TPM values are kept as the original
    text because they are written back out verbatim.
    """
    tpm = {}
    for line in _nonblank_lines(filename):
        if line.startswith('#'):
            continue
        fields = line.split('\t')
        tpm[fields[label_field]] = fields[tpm_field]
    return tpm


def _load_tss(filename, chr_field, left_field, right_field, strand_field,
              label_field):
    """Return {chromosome: {(left, right): [(strand, label), ...]}}.

    The label column may hold several comma-separated labels; each one gets
    its own (strand, label) entry.  A region listed twice replaces the earlier
    entry.  Lines starting with '#' are ignored.
    """
    tss = {}
    for line in _nonblank_lines(filename):
        if line.startswith('#'):
            continue
        fields = line.split('\t')
        region = (int(fields[left_field]), int(fields[right_field]))
        strand = fields[strand_field]
        labels = fields[label_field].split(',')
        tss.setdefault(fields[chr_field], {})[region] = [
            (strand, label) for label in labels]
    return tss


def run():
    """Annotate SingleMoleculeCorrelation output with TSS labels and TPMs.

    Arguments are read from ``sys.argv`` (see the usage message below).
    Fewer than 11 user arguments prints usage and exits with status 1.

    Each '#' line of the SMC file is copied to the output with eight extra
    header columns appended.  For each data line, the regions
    (col 1, col 2) and (col 6, col 7) on chromosome col 0 are looked up in
    the TSS file.  One output line is written per (label of region 1,
    label of region 2) pair.  Each output line holds:
    - the original line;
    - label1, strand1, label2, strand2;
    - TSS_array_N: the 1-based position of region 2 among the partners seen
      so far for region 1;
    - distance: R2 - R1;
    - the two TPM values.

    Raises KeyError if an SMC region is missing from the TSS file, or if a
    TSS label is missing from the expression file.
    """
    if len(sys.argv) < 12:
        print('usage: python %s SingleMoleculeCorrelation_output TSS.bed chrFieldID leftFieldID RightFieldID strandFieldID labelFieldID expression_file labelFieldID TPMFieldID outfile' % sys.argv[0])
        print('Note: the script assumes that the input is the output from SingleMoleculeCorrelation.py, i.e. regions are sorted by coordinates!!!')
        sys.exit(1)

    smc_file = sys.argv[1]
    peaks = sys.argv[2]
    chr_field = int(sys.argv[3])
    left_field = int(sys.argv[4])
    right_field = int(sys.argv[5])
    strand_field = int(sys.argv[6])
    label_field = int(sys.argv[7])
    expression_file = sys.argv[8]
    exp_label_field = int(sys.argv[9])
    tpm_field = int(sys.argv[10])
    outfilename = sys.argv[11]

    tpm_by_label = _load_expression(expression_file, exp_label_field,
                                    tpm_field)
    tss = _load_tss(peaks, chr_field, left_field, right_field, strand_field,
                    label_field)
    print('finished inputting peaks')

    header_suffix = '\t'.join(['label1', 'strand1', 'label2', 'strand2',
                               'TSS_array_N', 'distance',
                               'expression_1', 'expression_2'])
    # chromosome -> first region -> partner regions in order of appearance.
    partners = {}
    with open(outfilename, 'w') as outfile:
        for line in _nonblank_lines(smc_file):
            if line.startswith('#'):
                outfile.write(line + '\t' + header_suffix + '\n')
                continue
            fields = line.split('\t')
            chrom = fields[0]
            first = (int(fields[1]), int(fields[2]))
            second = (int(fields[6]), int(fields[7]))
            seen = partners.setdefault(chrom, {}).setdefault(first, [])
            seen.append(second)
            # index() gives the first occurrence, so a repeated partner
            # keeps the number it was assigned when first seen.
            rank = str(seen.index(second) + 1)
            distance = str(second[1] - first[1])
            for (strand1, id1) in tss[chrom][first]:
                for (strand2, id2) in tss[chrom][second]:
                    outfile.write('\t'.join([
                        line, id1, strand1, id2, strand2, rank, distance,
                        tpm_by_label[id1], tpm_by_label[id2]]) + '\n')
            
# Script entry point: run() executes unconditionally at module load.
# NOTE(review): there is no `if __name__ == '__main__':` guard, so importing
# this module also runs the script; consider adding one if it is ever imported.
run()

