##################################
#                                #
# Last modified 2018/07/06       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import math
from sets import Set

def run():

    if len(sys.argv) < 7:
        print 'usage: python %s SingleMoleculeCorrelation_output TSS.bed chrFieldID leftFiled RightFieldID strandFieldID outfile' % sys.argv[0]
        print 'Note: the script assumes that the input is the output from SingleMoleculeCorrelation.py, i.e. regions are sorted by coordinates!!!'
        sys.exit(1)

    SMC = sys.argv[1]
    peaks = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    leftFieldID = int(sys.argv[4])
    rightFieldID = int(sys.argv[5])
    strandFieldID = int(sys.argv[6])
    outfilename = sys.argv[7]

    TSSDict = {}
    if peaks.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + peaks
    elif peaks.endswith('.gz') or peaks.endswith('.bgz'):
        cmd = 'zcat ' + peaks
    elif peaks.endswith('.zip'):
        cmd = 'unzip -p ' + peaks
    else:
        cmd = 'cat ' + peaks
    RN = 0
    P = os.popen(cmd, "r")
    line = 'line'
    while line != '':
        line = P.readline().strip()
        if line == '':
            break
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        chr = fields[chrFieldID]
        left = int(fields[leftFieldID])
        right = int(fields[rightFieldID])
        strand = fields[strandFieldID]
        if TSSDict.has_key(chr):
            pass
        else:
            TSSDict[chr] = {}
        TSSDict[chr][(left,right)] = strand

    print 'finished inputting peaks'

    outfile = open(outfilename,'w')

    SMCDict = {}
    if SMC.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + SMC
    elif SMC.endswith('.gz') or SMC.endswith('.bgz'):
        cmd = 'zcat ' + SMC
    elif SMC.endswith('.zip'):
        cmd = 'unzip -p ' + SMC
    else:
        cmd = 'cat ' + SMC
    RN = 0
    P = os.popen(cmd, "r")
    line = 'line'
    while line != '':
        line = P.readline().strip()
        if line == '':
            break
        if line.startswith('#'):
            outline = line.strip() + '\t' + 'strand1' + '\t' + 'strand2' + '\t' + 'TSS_array_N' + '\t' + 'distance'
            outfile.write(outline + '\n')
            continue
        fields = line.strip().split('\t')
        chr = fields[0]
        L1 = int(fields[1])
        R1 = int(fields[2])
        L2 = int(fields[6])
        R2 = int(fields[7])
        if SMCDict.has_key(chr):
            pass
        else:
            SMCDict[chr] = {}
        if SMCDict[chr].has_key((L1,R1)):
            pass
        else:
            SMCDict[chr][(L1,R1)] = []
        SMCDict[chr][(L1,R1)].append((L2,R2))
        N = SMCDict[chr][(L1,R1)].index((L2,R2))
        S1 = TSSDict[chr][(L1,R1)]
        S2 = TSSDict[chr][(L2,R2)]
        outline = line.strip() + '\t' + S1 + '\t' + S2 + '\t' + str(N+1) + '\t' + str(R2 - R1)
        outfile.write(outline + '\n')

    outfile.close()
            
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

