##################################
#                                #
# Last modified 05/06/2016       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math

def run():
    """Split a DESeq2 results table into up- and down-regulated row sets.

    Command line:
        python script.py input p-value outputfilename [-padj] [-switch]

    The input must be tab-delimited with a header line starting with '#'
    that contains a 'pvalue' column (or 'padj' when -padj is given) and a
    'log2FoldChange' column.  Every '#' line is copied to both outputs and
    (re)defines the column positions used for the data lines after it.

    A data row is kept when its p-value (adjusted p-value with -padj) is
    <= the given threshold.  Kept rows with log2FoldChange > 0 go to
    '<outputfilename>-up' and rows with log2FoldChange < 0 go to
    '<outputfilename>-down'.  Rows with log2FoldChange exactly 0 go to
    neither file.  -switch swaps the two destinations.  Rows with 'NA' in
    the p-value or log2FoldChange column, and blank lines, are skipped.

    Exits with status 1 (usage) when fewer than three positional
    arguments are given.  Exits with an error message when the header
    lacks the needed columns or a data line precedes the header.
    """
    # argv[1..3] are all read below, so at least 4 entries are required.
    if len(sys.argv) < 4:
        print('usage: python %s input p-value outputfilename [-padj] [-switch]' % sys.argv[0])
        print('       Note: the script relies on the DESeq2 output having been passed through the fixDESeq2Ouput.py script, i.e. tab-delimited and with a header line starting with the # sign')
        sys.exit(1)

    inputFileName = sys.argv[1]
    MaxPValue = float(sys.argv[2])
    outPrefix = sys.argv[3]
    doPAdj = '-padj' in sys.argv
    doSwitch = '-switch' in sys.argv

    pvalueColumn = 'padj' if doPAdj else 'pvalue'
    pvalueFieldID = None
    logfoldChangeFieldID = None

    # Open the input first so a missing input does not truncate old outputs;
    # 'with' guarantees all three handles are closed even on error.
    with open(inputFileName) as linelist, \
            open(outPrefix + '-up', 'w') as outfileUp, \
            open(outPrefix + '-down', 'w') as outfileDown:
        # -switch inverts the direction: positive fold changes become "down".
        if doSwitch:
            positiveOut, negativeOut = outfileDown, outfileUp
        else:
            positiveOut, negativeOut = outfileUp, outfileDown

        for line in linelist:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no data.
                continue
            fields = line.strip().split('\t')
            if line.startswith('#'):
                outfileUp.write(line)
                outfileDown.write(line)
                try:
                    pvalueFieldID = fields.index(pvalueColumn)
                    logfoldChangeFieldID = fields.index('log2FoldChange')
                except ValueError:
                    sys.exit("ERROR: header line lacks a '%s' or 'log2FoldChange' column: %s" % (pvalueColumn, line.strip()))
                continue
            if pvalueFieldID is None:
                sys.exit('ERROR: data line found before the # header line in %s' % inputFileName)

            pvalueText = fields[pvalueFieldID]
            logFCText = fields[logfoldChangeFieldID]
            if pvalueText == 'NA' or logFCText == 'NA':
                continue
            # Written as "<=" (not "> ... continue") so a NaN p-value is
            # still rejected, as in the original comparison.
            if float(pvalueText) <= MaxPValue:
                # 2**x > 1 exactly when x > 0; comparing the log value
                # directly avoids math.pow's OverflowError for huge x.
                logFC = float(logFCText)
                if logFC > 0:
                    positiveOut.write(line)
                elif logFC < 0:
                    negativeOut.write(line)
   
# Only run when executed as a script, so the module can be imported
# (e.g. for testing) without parsing sys.argv or calling sys.exit.
if __name__ == '__main__':
    run()