##################################
#                                #
# Last modified 09/24/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
from sets import Set

try:
    # psyco is an optional accelerator: use it when it can be imported,
    # otherwise run unaccelerated.
    import psyco
    psyco.full()
except ImportError:
    # Only a missing/unsupported psyco is ignored; Ctrl-C and SystemExit
    # are no longer swallowed as they were by the previous bare except.
    pass

def _option_value(flag):
    """Return the float following *flag* in sys.argv, or None when *flag* is absent."""
    if flag not in sys.argv:
        return None
    return float(sys.argv[sys.argv.index(flag) + 1])


def _read_sample_pairs(path):
    """Read tab-separated (sample1, sample2) pairs from *path* into a dict keyed by the pair."""
    pairs = {}
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            # Blank or single-column lines cannot name a pair; skip them
            # instead of failing with IndexError.
            if len(fields) < 2:
                continue
            pairs[(fields[0], fields[1])] = ''
    return pairs


def run():
    """Extract differentially expressed genes from a cuffdiff table.

    Command line (read from sys.argv):
        <cuffdiff file> <list of sample pairs> <fold change threshold>
        outputfilename [-pvalue number] [-minFPKM number] [-outputCuffdiffLine]

    A data row is kept when all of the following hold:
      * its (sample_1, sample_2) pair is listed in the sample pair file;
      * its status column is 'OK';
      * it is significant: p_value <= the -pvalue number when that option
        is given, otherwise the significant column equals 'yes';
      * with -minFPKM, value_1 or value_2 is at least the given number;
      * 2**log2(fold_change) is above the threshold or below 1/threshold.

    The header line of the input is always copied to the output file.  With
    -outputCuffdiffLine the kept rows are written verbatim after it;
    otherwise the sorted, de-duplicated 'gene_id <tab> gene' pairs are
    written.

    Exits with status 1 after printing usage when fewer than the four
    positional arguments are supplied.  Raises ValueError if a required
    column is missing from the header line.
    """
    # outputfilename is sys.argv[4], so five argv entries are required.
    if len(sys.argv) < 5:
        print('usage: python %s <cuffdiff file> <list of sample pairs> <fold change threshold> outputfilename [-pvalue number] [-minFPKM number] [-outputCuffdiffLine]' % sys.argv[0])
        print('       list of sample pairs format: sample1 <tab> sample2')
        print('       minFPKM in either of the two samples')
        print('       if you use the -pvalue option the significance test yes or no value will be ignored and the p-value will be used instead')
        sys.exit(1)

    inputfilename = sys.argv[1]
    listofpairs = sys.argv[2]
    threshold = float(sys.argv[3])
    outputfilename = sys.argv[4]

    # None means "option not given".
    min_p = _option_value('-pvalue')
    min_fpkm = _option_value('-minFPKM')
    output_lines = '-outputCuffdiffLine' in sys.argv

    pair_dict = _read_sample_pairs(listofpairs)
    print(pair_dict)

    genes = set()

    with open(outputfilename, 'w') as outfile:
        with open(inputfilename) as infile:
            for line_number, line in enumerate(infile, 1):
                fields = line.strip().split('\t')

                if line_number == 1:
                    # Header row: copy it through and resolve column positions.
                    outfile.write(line)
                    fold_change_col = fields.index('log2(fold_change)')
                    sample1_col = fields.index('sample_1')
                    sample2_col = fields.index('sample_2')
                    value1_col = fields.index('value_1')
                    value2_col = fields.index('value_2')
                    status_col = fields.index('status')
                    significant_col = fields.index('significant')
                    pvalue_col = fields.index('p_value')
                    gene_id_col = fields.index('gene_id')
                    gene_col = fields.index('gene')
                    # The header carries no data; never run it through the filters.
                    continue

                if line_number % 1000000 == 0:
                    print('%d lines processed' % line_number)

                if (fields[sample1_col], fields[sample2_col]) not in pair_dict:
                    continue
                if fields[status_col] != 'OK':
                    continue

                if min_p is not None:
                    # -pvalue overrides the yes/no significance column.
                    if float(fields[pvalue_col]) > min_p:
                        continue
                elif fields[significant_col] != 'yes':
                    continue

                try:
                    foldchange = math.pow(2, float(fields[fold_change_col]))
                except (ValueError, OverflowError):
                    # Unparseable or overflowing log2 fold changes are treated
                    # as differentially expressed.
                    print('math error, assuming differential expression %s %s'
                          % (fields[fold_change_col], fields))
                    foldchange = threshold + 1

                if min_fpkm is not None:
                    fpkm1 = float(fields[value1_col])
                    fpkm2 = float(fields[value2_col])
                    # Keep the row when either sample reaches the minimum.
                    if fpkm1 < min_fpkm and fpkm2 < min_fpkm:
                        continue

                if foldchange > threshold or foldchange < 1 / threshold:
                    genes.add((fields[gene_id_col], fields[gene_col]))
                    if output_lines:
                        outfile.write(line)

        if not output_lines:
            for gene_id, gene in sorted(genes):
                outfile.write(gene_id + '\t' + gene + '\n')

# Script entry point: run() is called unconditionally when this file is
# executed or imported (there is no __main__ guard).
run()

