##################################
#                                #
# Last modified 04/25/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import time
import math
import random
from sets import Set

def _load_reads(filename, seq_field, counts_field, chr_field, pos_field,
                read_end_at_position, strand_field, min_length, max_length):
    """Parse one tab-delimited read file into a list of (chr, pos, strand) tuples.

    Each data line contributes its tuple once per count in column
    `counts_field`. Lines starting with '#', blank lines, lines whose count is
    < 1, and sequences shorter than `min_length` or longer than `max_length`
    are skipped. When `read_end_at_position` is 3, minus-strand positions are
    shifted by the sequence length (plus-strand positions are never shifted).
    """
    reads = []
    with open(filename) as infile:
        for line in infile:
            # Blank lines would otherwise fail on int() / field indexing.
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            counts = int(fields[counts_field])
            if counts < 1:
                continue
            seq = fields[seq_field]
            if len(seq) < min_length or len(seq) > max_length:
                continue
            chrom = fields[chr_field]
            strand = fields[strand_field]
            pos = int(fields[pos_field])
            if read_end_at_position == 3 and strand == '-':
                pos += len(seq)
            # One entry per count so random.sample draws reads, not distinct sites.
            reads.extend([(chrom, pos, strand)] * counts)
    return reads


def _build_coverage(reads):
    """Count reads as {chr: {'+': {pos: n}, '-': {pos: n}}}.

    Reads whose strand is neither '+' nor '-' still create the chromosome
    entry but are not counted.
    """
    coverage = {}
    for chrom, pos, strand in reads:
        by_strand = coverage.setdefault(chrom, {'+': {}, '-': {}})
        if strand in by_strand:
            by_strand[strand][pos] = by_strand[strand].get(pos, 0) + 1
    return coverage


def _count_pingpong_reads(coverage1, coverage2, overlap):
    """Return (as a float) how many library-1 reads have a library-2 partner.

    A '+' read at pos pairs with any '-' library-2 read at pos + overlap; a
    '-' read at pos pairs with any '+' library-2 read at pos - overlap.
    Chromosomes absent from `coverage2` simply contribute nothing.
    """
    no_reads = {'+': {}, '-': {}}
    total = 0.0  # float keeps the original "N.0" output formatting
    for chrom, by_strand in coverage1.items():
        partner = coverage2.get(chrom, no_reads)
        for pos, n in by_strand['+'].items():
            if pos + overlap in partner['-']:
                total += n
        for pos, n in by_strand['-'].items():
            if pos - overlap in partner['+']:
                total += n
    return total


def run():
    """Command-line entry point; arguments are read from sys.argv.

    Loads reads from two tab-delimited files, then for each iteration randomly
    subsamples up to number_reads_to_sample reads from each and writes one
    tab-separated row to outfilename: iteration, total reads 1/2, sampled
    reads 1/2, sampled file-1 reads that have an opposite-strand file-2 read
    offset by Overlap, and that count divided by the file-1 sample size
    ('nan' when the file-1 sample is empty). With -collapseDups (accepted
    anywhere on the command line) identical (chr, pos, strand) reads are
    collapsed in both inputs before sampling. Exits with status 1 and a
    usage message when fewer than 21 positional arguments are given.
    """
    collapse_flag = '-collapseDups'
    # Strip the optional flag before positional parsing so it can appear
    # anywhere and is never mistaken for a positional (e.g. the output name).
    args = [a for a in sys.argv[1:] if a != collapse_flag]
    if len(args) < 21:
        print('usage: python %s file1 seqFieldID1 CountsFieldID1 chrField1 PositionField1 ReadEndAtPosition1 StrandFieldID1 number_reads_to_sample1 file2 seqFieldID2 CountsFieldID2 chrField2 PositionField2 ReadEndAtPosition2 StrandFieldID2 number_reads_to_sample2 iterations minLength maxLength Overlap outfilename [-collapseDups]' % sys.argv[0])
        print('\tNote: the script will discard sequences with 0 counts')
        print('\tReadEndAtPosition1: 5 or 3, refers to minus strand reads only')
        sys.exit(1)
    doCollapseDups = collapse_flag in sys.argv[1:]

    file1 = args[0]
    (seqFieldID1, CountsFieldID1, chrFieldID1, PositionFieldID1,
     ReadEndAtPosition1, StrandFieldID1, numReads1) = [int(x) for x in args[1:8]]
    file2 = args[8]
    (seqFieldID2, CountsFieldID2, chrFieldID2, PositionFieldID2,
     ReadEndAtPosition2, StrandFieldID2, numReads2) = [int(x) for x in args[9:16]]
    iterations, minLength, maxLength, Overlap = [int(x) for x in args[16:20]]
    outfilename = args[20]

    ReadList1 = _load_reads(file1, seqFieldID1, CountsFieldID1, chrFieldID1,
                            PositionFieldID1, ReadEndAtPosition1, StrandFieldID1,
                            minLength, maxLength)
    ReadList2 = _load_reads(file2, seqFieldID2, CountsFieldID2, chrFieldID2,
                            PositionFieldID2, ReadEndAtPosition2, StrandFieldID2,
                            minLength, maxLength)

    if doCollapseDups:
        # The flag is global, so both libraries are collapsed symmetrically.
        ReadList1 = list(set(ReadList1))
        ReadList2 = list(set(ReadList2))

    with open(outfilename, 'w') as outfile:
        outfile.write('#Iteration\tTotalReads1\tTotalReads2\tSampleReads1\tSampleReads2\tSampledReadsInPingPongPairs-1\tFraction-1\n')
        for i in range(iterations):
            print(i)
            start = time.time()
            sampled1 = random.sample(ReadList1, min(numReads1, len(ReadList1)))
            sampled2 = random.sample(ReadList2, min(numReads2, len(ReadList2)))
            in_pairs = _count_pingpong_reads(_build_coverage(sampled1),
                                             _build_coverage(sampled2),
                                             Overlap)
            # Guard against an empty file-1 sample (all reads filtered out).
            if sampled1:
                fraction = in_pairs / len(sampled1)
            else:
                fraction = float('nan')
            row = [i, len(ReadList1), len(ReadList2), len(sampled1),
                   len(sampled2), in_pairs, fraction]
            outfile.write('\t'.join(str(v) for v in row) + '\n')
            print(time.time() - start)
        
# Only run when executed as a script, so the module can be imported safely.
if __name__ == '__main__':
    run()

