##################################
#                                #
# Last modified 2017/05/30       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math

def _read_significant_clusters(path, max_pvalue):
    """Return the set of cluster IDs whose p-value is <= max_pvalue.

    The file at `path` is read as tab-delimited text.  Lines starting with
    'cluster' (the header) and blank lines are skipped, as are rows whose
    fifth column is the literal string 'NA'.  The cluster ID is the second
    ':'-separated token of the first column; the p-value is the fifth column.
    """
    significant = set()
    with open(path) as handle:
        for line in handle:
            # A blank (e.g. trailing) line has no fifth column to parse.
            if not line.strip() or line.startswith('cluster'):
                continue
            fields = line.strip().split('\t')
            if fields[4] == 'NA':
                continue
            pval = float(fields[4])
            cluster = fields[0].split(':')[1]
            if pval <= max_pvalue:
                significant.add(cluster)
    return significant


def _write_passing_introns(path, outfile, significant, min_delta_psi):
    """Copy qualifying rows from the effect-size file at `path` to `outfile`.

    The header line (starting with 'intron') is written prefixed with '#'.
    A data row is written unchanged when its cluster ID (fourth ':'-separated
    token of the first column) is in `significant` and its fifth column,
    parsed as a float, is >= min_delta_psi.

    Returns the set of cluster IDs that had at least one row written.
    """
    passing = set()
    with open(path) as handle:
        for line in handle:
            if not line.strip():
                continue
            # Guarantee a line terminator so that a final row lacking one is
            # not fused with the summary line appended by the caller.
            if not line.endswith('\n'):
                line += '\n'
            if line.startswith('intron'):
                outfile.write('#' + line)
                continue
            fields = line.strip().split('\t')
            deltapsi = float(fields[4])
            cluster = fields[0].split(':')[3]
            if cluster not in significant:
                continue
            # NOTE(review): this comparison is signed, so rows with a large
            # negative deltapsi are never reported -- confirm whether an
            # absolute-value test was intended.
            if deltapsi >= min_delta_psi:
                passing.add(cluster)
                outfile.write(line)
    return passing


def run():
    """Filter an effect-size table down to significant, large-effect introns.

    Command-line arguments (read from sys.argv):
        1. cluster significance file (tab-delimited)
        2. effect sizes file (tab-delimited)
        3. maximum p-value for a cluster to be considered
        4. minimum deltapsi for an intron row to be kept
        5. output filename

    Prints the number of clusters with at least one retained row and appends
    the same count to the output file as a '#'-prefixed summary line.
    Exits with status 1 after printing usage if arguments are missing.
    """
    # Five user arguments plus the script name are required; the output
    # filename is sys.argv[5], so anything shorter than 6 is incomplete.
    if len(sys.argv) < 6:
        print('usage: python %s cluster_significance.txt leafcutter_effect_sizes.txt p-value minDeltaPsi outputfilename' % sys.argv[0])
        # NOTE(review): the DESeq2 remark below looks carried over from a
        # different script; nothing in this function reads DESeq2 output.
        print('       Note: the script relies on the DESeq2 output having been passed through the fixDESeq2Ouput.py script, i.e. tab-delimited and with a header line starting with the # sign')
        sys.exit(1)

    cluster_significance = sys.argv[1]
    effect_sizes = sys.argv[2]
    max_pvalue = float(sys.argv[3])
    min_delta_psi = float(sys.argv[4])

    significant = _read_significant_clusters(cluster_significance, max_pvalue)

    # The context manager closes the output even if parsing fails midway.
    with open(sys.argv[5], 'w') as outfile:
        passing = _write_passing_introns(effect_sizes, outfile, significant, min_delta_psi)
        count = len(passing)
        print('found %d clusters passing requirements' % count)
        outfile.write('# found ' + str(count) + ' clusters passing requirements' + '\n')
   
# Script entry point: run() is invoked unconditionally, so it also executes
# (and reads sys.argv) if this file is imported rather than run directly.
run()