##################################
#                                #
# Last modified 06/06/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import time
import math
import random
import scipy.stats
from sets import Set

def reverseComplement(sequence,DNA):
    """Return the reverse complement of a sequence.

    sequence -- string of bases
    DNA      -- dict mapping each base character to its complement

    The sequence is walked from its last character to its first and every
    character is replaced by its entry in DNA. A character that is missing
    from DNA raises KeyError. An empty sequence gives an empty string.
    """

    # join() builds the result in one pass instead of growing a string
    # one character at a time
    return ''.join([DNA[base] for base in reversed(sequence)])

def binomialProportionPvalueCI(pp1,numReads1,pp2,numReads2,numLabels):
    """Compare two proportions with a pooled two-proportion z-test.

    pp1, pp2             -- the two observed proportions
    numReads1, numReads2 -- the number of observations behind each proportion
    numLabels            -- multiplier applied to the p-value to obtain the
                            adjusted p-value (capped at 1)

    Returns the tuple (pval, padj, CI):
        pval -- two-sided p-value from the standard normal distribution
        padj -- min(1, pval*numLabels)
        CI   -- 1.96 times the pooled standard error of the difference

    All three entries are the string 'nan' when the test cannot be computed:
    either sample has no observations, or the pooled standard error is zero
    (the pooled proportion is exactly 0 or exactly 1).

    A one-line summary of the inputs and results is printed to stdout.
    """

    (pval,padj,CI) = ('nan','nan','nan')

    if numReads1 != 0 and numReads2 != 0:
        # float() keeps the pooled proportion exact even if every input is an int
        q = (pp1*numReads1 + pp2*numReads2)/float(numReads1 + numReads2)
        variance = q*(1-q)*(1./numReads1 + 1./numReads2)
        # the variance is zero when q is 0 or 1 (e.g. pp1 == pp2 == 0 or
        # pp1 == pp2 == 1); dividing by the standard error would then fail
        if variance > 0:
            SE = math.sqrt(variance)
            Z = math.fabs(pp1 - pp2)/SE
            pval = scipy.stats.norm.sf(Z)*2
            padj = min(1,pval*numLabels)
            CI = 1.96*SE

    # the items are joined into a single string so that the line is printed
    # identically by the print statement and by the print function
    report = ['pp1,pp2: ', pp1, pp2, 'numReads:', numReads1, numLabels, 'p-values: ', pval, padj, 'diff: ', pp1 - pp2, CI]
    print(' '.join([str(item) for item in report]))

    return (pval,padj,CI)

def _readFastaReads(fastaFiles,minLength,maxLength,doCollapseDups):
    """Collect the reads of the listed FASTA files that pass the length filter.

    Lines starting with '>' are skipped. Every other line is stripped, has
    'U' replaced by 'T', and is kept if its length lies within
    [minLength, maxLength]. If doCollapseDups is set, identical reads are
    collapsed to a single copy.
    """

    reads = []
    for fasta in fastaFiles:
        # the with-block closes each input file as soon as it has been read
        with open(fasta) as lineslist:
            for line in lineslist:
                if line[0]=='>':
                    continue
                read = line.strip().replace('U','T')
                if len(read) < minLength or len(read) > maxLength:
                    continue
                reads.append(read)
    if doCollapseDups:
        reads = list(set(reads))

    return reads

def _pingPongFraction(reads,Overlap,DNA):
    """Return the fraction of reads whose first Overlap bases equal the
    reverse complement of the first Overlap bases of some read in the same
    list (possibly the read itself). Returns 0 for an empty list."""

    if len(reads) == 0:
        return 0
    revKmers = set([reverseComplement(read[0:Overlap],DNA) for read in reads])
    inPingPongPairs = sum([1 for read in reads if read[0:Overlap] in revKmers])

    return float(inPingPongPairs)/len(reads)

def run():
    """Command-line entry point.

    Reads its arguments from sys.argv:
        config sampling_factor iterations minLength maxLength Overlap outfilename [-collapseDups]

    The config file is tab-separated with the columns
        Label, Sample1_name, Sample1_file(s), Sample2_name, Sample2_file(s)
    where each file field may hold several comma-separated paths. Lines
    starting with '#' and blank lines are ignored. Exactly two distinct
    sample names must appear in the whole file; they are processed in
    sorted order.

    For every label the reads of both samples are loaded and filtered by
    length, and the ping-pong fraction of each complete read set is
    computed. Then, 'iterations' times per sample, a random subsample of
    int(sampling_factor * size of the smaller read set) reads is drawn and
    its ping-pong fraction computed. The mean subsampled fractions are
    compared with binomialProportionPvalueCI and one tab-separated line per
    label is written to outfilename.

    The usage text is printed and the script exits with status 1 when too few
    arguments are given. It also exits with status 1 when the sampling factor
    exceeds 1, iterations is below 1, a config line has fewer than five
    fields, or the config does not name exactly two samples.
    """

    # seven positional arguments are read below (sys.argv[1] .. sys.argv[7]),
    # so at least eight entries are needed
    if len(sys.argv) < 8:
        print('usage: python %s config sampling_factor iterations minLength maxLength Overlap outfilename [-collapseDups]' % sys.argv[0])
        print('\tconfig file format:')
        print('\tLabel\tSample1_name\tSample1_file(s)\tSample2_name\tSample2_file(s)')
        print('\tMultipel files can be entered in each Sample field, they should be comma-separated')
        print('\tNote: This script will not enforce the 1U rule')
        print('\tNote: For each label entry, the script will take the set of reads and subsample down to the number of reads in the sample with fewer reads')
        print('\t      multiplied by the sampling factor (which should be <= 1) It will then calculate the fraction of read in ping-pong pairs, and will')
        print('\t      output the binomial proportion p-value for the null hypothesis that the two are equal, as well as the 95% confidence interval for the difference')
        sys.exit(1)

    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n','-':'-'}

    doCollapseDups = '-collapseDups' in sys.argv

    config = sys.argv[1]
    SF = float(sys.argv[2])
    if SF > 1:
        print('the sampling factor cannot be larger than 1, exiting')
        sys.exit(1)
    iterations = int(sys.argv[3])
    if iterations < 1:
        # the subsampled fractions are averaged over the iterations
        print('the number of iterations must be at least 1, exiting')
        sys.exit(1)
    minLength = int(sys.argv[4])
    maxLength = int(sys.argv[5])
    Overlap = int(sys.argv[6])
    outfilename = sys.argv[7]

    LabelDict = {}
    SampleDict = {}
    with open(config) as lineslist:
        for line in lineslist:
            if line[0]=='#' or line.strip() == '':
                continue
            fields = line.strip().split('\t')
            if len(fields) < 5:
                print('config line has fewer than 5 tab-separated fields, exiting: ' + line.strip())
                sys.exit(1)
            print(fields)
            label = fields[0]
            sample1 = fields[1]
            sample2 = fields[3]
            SampleDict[sample1] = 1
            SampleDict[sample2] = 1
            LabelDict[label] = {}
            LabelDict[label][sample1] = fields[2].split(',')
            LabelDict[label][sample2] = fields[4].split(',')

    if len(SampleDict) > 2:
        print('more than two samples entered, exiting')
        sys.exit(1)
    if len(SampleDict) < 2:
        print('fewer than two samples entered, exiting')
        sys.exit(1)

    (sample1,sample2) = sorted(SampleDict)

    numLabels = len(LabelDict)

    with open(outfilename,'w') as outfile:
        outline = '#Label\tSample1\tSample2\tpp_fraction_1\tpp_fraction_2\tsubsampled_reads\tsubsampled_pp_fraction_1\tsubsmapled_pp_fraction_2\tp-val\tp-adj\tss_pp_fraction1-ss_pp_fraction2\tss_pp_fraction1-ss_pp_fraction2,95%CI'
        outfile.write(outline + '\n')

        for label in LabelDict:
            print(label)

            ReadList1 = _readFastaReads(LabelDict[label][sample1],minLength,maxLength,doCollapseDups)
            ReadList2 = _readFastaReads(LabelDict[label][sample2],minLength,maxLength,doCollapseDups)

            # both samples are subsampled to the same depth, set by the
            # smaller of the two read sets
            sampleSize = int(SF*min(len(ReadList1),len(ReadList2)))

            subsampled1 = []
            for i in range(iterations):
                subsampled1.append(_pingPongFraction(random.sample(ReadList1,sampleSize),Overlap,DNA))
            subsampled2 = []
            for i in range(iterations):
                subsampled2.append(_pingPongFraction(random.sample(ReadList2,sampleSize),Overlap,DNA))

            # fractions over the complete read sets; each one is guarded
            # against its own list being empty inside _pingPongFraction
            fullpp1 = _pingPongFraction(ReadList1,Overlap,DNA)
            fullpp2 = _pingPongFraction(ReadList2,Overlap,DNA)

            pp1 = sum(subsampled1)/iterations
            pp2 = sum(subsampled2)/iterations
            (pval,padj,CI) = binomialProportionPvalueCI(pp1,sampleSize,pp2,sampleSize,numLabels)

            outfields = [label, sample1, sample2, str(fullpp1), str(fullpp2),
                         str(sampleSize), str(pp1), str(pp2),
                         str(pval), str(padj), str(pp1 - pp2), '+/-' + str(CI)]
            outfile.write('\t'.join(outfields) + '\n')
        
# Run the analysis only when this file is executed as a script, so that
# importing the module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

