##################################
#                                #
# Last modified 08/24/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import numpy
from scipy.stats import norm
import math
import random
import string
from sets import Set

# Overlap types the script tracks; lines with any other type are ignored.
_OVERLAP_TYPES = ('exonic', 'intronic-only')


def _parse_overlap_file(path, transcriptIDFieldIDs, ElementFieldIDs, overlapTypeFieldID):
    """Parse one tab-delimited transcript/repeat overlap file.

    Returns ``(overlaps, repeats_seen)``:
      * ``overlaps`` maps each overlap type in ``_OVERLAP_TYPES`` to
        ``{repeat_tuple: [transcript_tuple, ...]}``, with one list entry per
        input line (duplicates are kept so callers can choose how to count);
      * ``repeats_seen`` is the set of repeat tuples found on any data line,
        whatever its overlap type.

    Transcript and repeat identifiers are tuples of the columns listed in
    ``transcriptIDFieldIDs`` / ``ElementFieldIDs``. Lines starting with '#'
    and blank lines are skipped.
    """
    overlaps = {overlap_type: {} for overlap_type in _OVERLAP_TYPES}
    repeats_seen = set()
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            transcript = tuple(fields[ID] for ID in transcriptIDFieldIDs)
            repeat = tuple(fields[ID] for ID in ElementFieldIDs)
            repeats_seen.add(repeat)
            overlap_type = fields[overlapTypeFieldID]
            if overlap_type in overlaps:
                overlaps[overlap_type].setdefault(repeat, []).append(transcript)
    return overlaps, repeats_seen


def _count_transcripts(transcripts, count_each_copy):
    """Return how many overlaps to count for one repeat.

    With ``count_each_copy`` every overlap line counts (a transcript hitting
    several copies of the repeat is counted once per copy); otherwise each
    distinct transcript counts once.
    """
    if count_each_copy:
        return len(transcripts)
    return len(set(transcripts))


def run():
    """Score transcript/repeat overlap enrichment against shuffled controls.

    Command line (read from ``sys.argv``)::

        list-of-shuffled-overlap-files overlap transcriptIDfieldIDs
        overlapTypeFieldID elementIDs outfilename [-countRepeatsMoreThanOnce]

    For every repeat that appears in the observed ``overlap`` file, the number
    of overlapping transcripts is counted separately for 'exonic' and
    'intronic-only' overlaps. The same count is taken in each shuffled overlap
    file listed (first tab-separated column) in ``list-of-shuffled-overlap-files``.
    The observed count is then compared with a normal distribution fitted
    (mean, population stdev) to the shuffled counts. The output is a
    tab-delimited table with the upper-tail p-value for each observed
    (overlap type, repeat) pair.

    Exits with status 1 and a usage message if fewer than six arguments are given.
    """
    # Six positional arguments are required: argv[1] .. argv[6].
    if len(sys.argv) < 7:
        print('usage: python %s list-of-shuffled-overlap-files overlap transcriptIDfieldIDs overlapTypeFieldID elementIDs outfilename [-countRepeatsMoreThanOnce]' % sys.argv[0])
        print('\tNote: by default, the script will count each lncRNAs once for each repeat, even if it overlaps multiple copies of it; the [-countRepeatsMoreThanOnce] option will make it count it once for each repeat copy')
        sys.exit(1)

    list_of_files = sys.argv[1]
    overlap_file = sys.argv[2]
    transcriptIDFieldIDs = sorted(int(ID) for ID in sys.argv[3].split(','))
    overlapTypeFieldID = int(sys.argv[4])
    ElementFieldIDs = sorted(int(ID) for ID in sys.argv[5].split(','))
    outputfilename = sys.argv[6]

    doCRMTO = '-countRepeatsMoreThanOnce' in sys.argv
    if doCRMTO:
        print('[-countRepeatsMoreThanOnce] option enabled')

    observed, repeats_to_consider = _parse_overlap_file(
        overlap_file, transcriptIDFieldIDs, ElementFieldIDs, overlapTypeFieldID)
    print('finished inputting overlap file')

    # shuffled[overlap_type][repeat] -> one overlap count per shuffled file.
    shuffled = {overlap_type: {repeat: [] for repeat in repeats_to_consider}
                for overlap_type in _OVERLAP_TYPES}

    n_files = 0
    with open(list_of_files) as file_list:
        for line in file_list:
            if line.startswith('#') or not line.strip():
                continue
            shuffled_path = line.strip().split('\t')[0]
            n_files += 1
            print('%d shuffled files processed %s' % (n_files, line.strip()))
            per_file, _ = _parse_overlap_file(
                shuffled_path, transcriptIDFieldIDs, ElementFieldIDs, overlapTypeFieldID)
            for overlap_type in _OVERLAP_TYPES:
                for repeat in repeats_to_consider:
                    # A repeat absent from this shuffle contributes a count of 0.
                    transcripts = per_file[overlap_type].get(repeat, [])
                    shuffled[overlap_type][repeat].append(
                        _count_transcripts(transcripts, doCRMTO))
    print('finished parsing repeats')

    header = ['#type-overlap']
    header.extend('repeat_level_%d' % (i + 1) for i in range(len(ElementFieldIDs)))
    header.extend(['N_transcripts', 'shuffled_mean', 'shuffled_stdev', 'p-value'])

    with open(outputfilename, 'w') as outfile:
        outfile.write('\t'.join(header) + '\n')
        # Section order matches the historical output: intronic-only, then exonic.
        for overlap_type in ('intronic-only', 'exonic'):
            for repeat, transcripts in observed[overlap_type].items():
                # Count the observed overlaps with the same rule as the shuffles,
                # so -countRepeatsMoreThanOnce applies to both sides of the test.
                NT = _count_transcripts(transcripts, doCRMTO)
                counts = shuffled[overlap_type][repeat]
                shuffled_mean = numpy.mean(counts)
                shuffled_std = numpy.std(counts)
                # Upper-tail probability P(X > NT) under the fitted normal; sf()
                # avoids the precision loss of 1 - cdf() for tiny p-values.
                # NOTE: a zero stdev (identical shuffles) yields nan from scipy.
                pval = norm.sf(NT, loc=shuffled_mean, scale=shuffled_std)
                row = [overlap_type] + list(repeat) + [
                    str(NT), str(shuffled_mean), str(shuffled_std), str(pval)]
                outfile.write('\t'.join(row) + '\n')

# Run only when executed as a script, so the module can be imported safely.
if __name__ == '__main__':
    run()

