##################################
#                                #
# Last modified 5/6/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math

# Optional speed-up: enable the psyco JIT when it is installed.  This is
# strictly best-effort, so any ordinary failure (psyco missing, or psyco
# unable to initialise) is ignored.  ``except Exception`` is used instead of a
# bare ``except`` so that KeyboardInterrupt and SystemExit still propagate.
try:
    import psyco
    psyco.full()
except Exception:
    pass

# Progress message printed for each supported group of genes.  The keys double
# as the set of valid values for the ``desired_group_of_genes`` argument.
_GROUP_MESSAGES = {
    'upregulated': 'extracting upregulated genes',
    'downregulated': 'extracting downregulated genes',
    'flatexpressed': 'extracting flat expressed genes',
    'flatunexpressed': 'extracting flat unexpressed genes',
}


def _print_usage():
    """Print the command-line help text."""
    print('\nusage: python %s file_with_gene-sites_pairs file_with_expression_values desired_group_of_genes minRPKM minFoldChange outfilename [-selectedsubset selectedsubsetfilename]' % sys.argv[0])
    print('\n\tfile with gene-sites pairs is the output of getBindingSitesRelativeToTSS.py')
    print('\tfile with expression values should be in the format: geneID geneName\texp1 exp2 exp2-exp1_Net_Difference exp2-exp1_FoldChange')
    print('\tdesired group of genes: "upregulated", "downregulated", "flatexpressed" or "flatunexpressed"\n')


def _terminated(line):
    """Return *line* with a trailing newline added if it is non-empty and lacks one."""
    if line and not line.endswith('\n'):
        return line + '\n'
    return line


def _read_selected_subset(path):
    """Return the set of first-column values of the tab-delimited file *path*."""
    with open(path) as handle:
        # A set gives O(1) membership tests when filtering the genes.
        return set(line.rstrip('\r\n').split('\t')[0] for line in handle)


def _read_expression(path):
    """Parse the expression table *path* into a dict keyed by gene ID.

    Each row must have at least six tab-separated columns: geneID, geneName,
    exp1, exp2, net difference and fold change.  Blank lines are skipped.  If
    a gene ID occurs more than once, the last row wins.
    """
    genes = {}
    with open(path) as handle:
        for line in handle:
            if not line.strip():
                continue
            fields = line.rstrip('\r\n').split('\t')
            genes[fields[0]] = {
                'sites': [],
                'geneID': fields[0],
                'geneName': fields[1],
                'exp1': float(fields[2]),
                'exp2': float(fields[3]),
                'netdifference': float(fields[4]),
                'foldchange': float(fields[5]),
            }
    return genes


def _gene_matches(gene, parameter, minRPKM, minFoldChange):
    """Return True if *gene* belongs to the group named by *parameter*.

    upregulated:     exp1 or exp2 above minRPKM, fold change above minFoldChange
    downregulated:   exp1 or exp2 above minRPKM, fold change below minFoldChange
    flatexpressed:   exp1 and exp2 above minRPKM,
                     1/minFoldChange < fold change < minFoldChange
    flatunexpressed: exp1 and exp2 below minRPKM,
                     1/minFoldChange < fold change < minFoldChange
    """
    exp1 = gene['exp1']
    exp2 = gene['exp2']
    foldchange = gene['foldchange']

    if parameter == 'upregulated':
        return (exp1 > minRPKM or exp2 > minRPKM) and foldchange > minFoldChange
    if parameter == 'downregulated':
        # NOTE(review): the threshold is used directly (not inverted), so for
        # minFoldChange > 1 this overlaps the flat groups -- presumably the
        # caller passes a fractional threshold here; confirm.
        return (exp1 > minRPKM or exp2 > minRPKM) and foldchange < minFoldChange
    if parameter == 'flatexpressed':
        # 1 / minFoldChange is evaluated last so a zero threshold only raises
        # when every preceding condition already holds.
        return (exp1 > minRPKM and exp2 > minRPKM
                and foldchange < minFoldChange
                and foldchange > 1.0 / minFoldChange)
    if parameter == 'flatunexpressed':
        return (exp1 < minRPKM and exp2 < minRPKM
                and foldchange < minFoldChange
                and foldchange > 1.0 / minFoldChange)
    return False


def run():
    """Write the binding-site lines of one expression group of genes.

    Command-line arguments (read from ``sys.argv``):
        1  gene-sites file: a header line followed by tab-delimited lines
           whose first column is a gene ID
        2  expression file (see ``_read_expression`` for the columns)
        3  group of genes: one of the keys of ``_GROUP_MESSAGES``
        4  minRPKM threshold (float)
        5  minFoldChange threshold (float)
        6  output file name
        -selectedsubset FILE  (optional) keep only genes whose geneName is in
           the first column of FILE

    The output file receives the header of the gene-sites file followed by
    every site line of every gene that passes the filter.

    Exits with status 1 after printing the usage text when arguments are
    missing or the group name is not recognised.  Raises ValueError for
    non-numeric thresholds or expression values, and KeyError when the
    gene-sites file mentions a gene ID absent from the expression file.
    """
    # Six positional arguments plus the script name are required; the last
    # one read is sys.argv[6].
    if len(sys.argv) < 7:
        _print_usage()
        sys.exit(1)

    listofgenesitepairs = sys.argv[1]
    expressionfile = sys.argv[2]
    parameter = sys.argv[3]
    minRPKM = float(sys.argv[4])
    minFoldChange = float(sys.argv[5])
    outfilename = sys.argv[6]

    # Reject unknown groups up front instead of silently producing an output
    # file that contains only the header.
    if parameter not in _GROUP_MESSAGES:
        print('\nunknown group of genes: %r' % parameter)
        _print_usage()
        sys.exit(1)

    selectedsubset = None
    if '-selectedsubset' in sys.argv:
        position = sys.argv.index('-selectedsubset') + 1
        if position >= len(sys.argv):
            print('\n-selectedsubset requires a file name')
            _print_usage()
            sys.exit(1)
        selectedsubset = _read_selected_subset(sys.argv[position])
        print('\n\tonly the selected subset is considered')

    genes = _read_expression(expressionfile)
    print('len(genes.keys()) %d' % len(genes))

    with open(listofgenesitepairs) as handle:
        # An empty file simply yields an empty header and no sites.
        header = _terminated(handle.readline())
        for line in handle:
            if not line.strip():
                continue
            # Guarantee a newline so a final unterminated line cannot be glued
            # to whatever record is written after it.
            line = _terminated(line)
            geneID = line.rstrip('\r\n').split('\t')[0]
            genes[geneID]['sites'].append(line)

    print(_GROUP_MESSAGES[parameter])
    with open(outfilename, 'w') as outfile:
        outfile.write(header)
        for gene in genes.values():
            if not _gene_matches(gene, parameter, minRPKM, minFoldChange):
                continue
            if selectedsubset is not None and gene['geneName'] not in selectedsubset:
                continue
            outfile.writelines(gene['sites'])
# Execute only when run as a script, so that importing this module does not
# parse sys.argv or touch any files.
if __name__ == '__main__':
    run()