##################################
#                                #
# Last modified 5/23/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _arg_after(flag, offset=1):
    """Return the command-line token `offset` positions after `flag`."""
    return sys.argv[sys.argv.index(flag) + offset]


def _read_lines(filename):
    """Return every line of `filename` (newlines kept), closing the file."""
    handle = open(filename)
    try:
        return handle.readlines()
    finally:
        handle.close()


def _nonzero(value):
    """Parse `value` as a float, substituting 0.01 for an exact zero.

    The substitution keeps the fold-change ratios computed by the
    up/down/flat filters free of divisions by zero.
    """
    number = float(value)
    if number == 0.0:
        return 0.01
    return number


def run():
    """Filter an expression table down to a list of wanted genes.

    Everything is read from sys.argv:

        expressionfilename listofwantedgenes outfilename [options]

    The first line of the expression file is copied to the output as a
    header.  Every later line is split on tabs; the gene name is taken from
    column -ExprStartField (default 1) and the line is kept when that name
    occurs in the wanted-gene list (column -startField, default 0, first
    space-separated token) and passes the active filter.  Names starting
    with 'Igk' are always skipped.  All column numbers are 0-based.

    Options:
        -skipLOC         drop wanted names whose first two characters are
                         upper case or that end in 'Rik'
        -capAll          compare names case-insensitively (both lower-cased)
        -expressedOnly threshold startfield numfields
                         keep genes whose maximum over the given columns is
                         at least threshold
        -upregulated / -downregulated minRPKM FoldChange fieldID1 fieldID2
                         keep genes where either column reaches minRPKM and
                         fieldID2/fieldID1 (up) or fieldID1/fieldID2 (down)
                         is at least FoldChange
        -flat_on minRPKM FoldChange fieldID1 fieldID2
                         keep genes where either column reaches minRPKM and
                         the ratio in both directions is at most FoldChange
        -flat_off maxRPKM FoldChange fieldID1 fieldID2
                         keep genes where both columns are at most maxRPKM
                         and the ratio in both directions is at most
                         FoldChange
        -autonameoutfile prefix
                         derive the output name from prefix and the gene
                         list file name instead of using outfilename
        -concatenate     prefix each output line with the matching line of
                         the gene list file

    When several filters are given, the first of -expressedOnly,
    -upregulated, -downregulated, -flat_on, -flat_off is the one applied.
    Exits with status 1 on a usage error or an invalid flag combination.
    """
    # Three positional arguments are always read, so argv needs >= 4 entries.
    if len(sys.argv) < 4:
        print('usage: python %s  expressionfilename listofwantedgenes outfilename [-skipLOC] [-capAll] [-startField fieldID] [-ExprStartField fieldID] [-expressedOnly threshold startfield numfields] [-upregulated/-downregulated/-flat_on/-flat_off minRPKM FoldChange fieldID1 fieldID2] [-autonameoutfile prefix] [-concatenate]' % sys.argv[0])
        sys.exit(1)

    expressiondatafilename = sys.argv[1]
    listofgenesfilename = sys.argv[2]
    outfilename = sys.argv[3]
    if '-autonameoutfile' in sys.argv:
        prefix = _arg_after('-autonameoutfile')
        if '-expressedOnly' in sys.argv:
            # NOTE: only the first two dot-separated pieces of the gene list
            # file name are used; a name without a dot raises IndexError.
            pieces = listofgenesfilename.split('.')
            outfilename = (prefix + pieces[0] + '-minRPKM='
                           + _arg_after('-expressedOnly') + '.' + pieces[1])
        else:
            outfilename = prefix + listofgenesfilename

    doCat = '-concatenate' in sys.argv
    concat = {}
    doCapitalize = '-capAll' in sys.argv
    if doCapitalize:
        print('doCapitalize')
    doSkipLOC = '-skipLOC' in sys.argv
    if doSkipLOC:
        print('Skipping LOCs, Riks etc.')

    LWGstartField = 0
    if '-startField' in sys.argv:
        LWGstartField = int(_arg_after('-startField'))
    EstartField = 1
    if '-ExprStartField' in sys.argv:
        EstartField = int(_arg_after('-ExprStartField'))

    doExpressedOnly = '-expressedOnly' in sys.argv
    if doExpressedOnly:
        threshold = float(_arg_after('-expressedOnly', 1))
        startfield = int(_arg_after('-expressedOnly', 2))
        endfield = startfield + int(_arg_after('-expressedOnly', 3))
        print('Doing "ExpressedOnly"')

    # The messages below reproduce the spacing of the former multi-argument
    # print statements (hence the double space after each '=').
    doUpregulated = '-upregulated' in sys.argv
    if doUpregulated:
        minRPKM = float(_arg_after('-upregulated', 1))
        minFoldChange = float(_arg_after('-upregulated', 2))
        fieldID1 = int(_arg_after('-upregulated', 3))
        fieldID2 = int(_arg_after('-upregulated', 4))
        print('Extracting upregulated genes with minRPKM =  %s and minimal fold change =  %s' % (minRPKM, minFoldChange))
    doDownregulated = '-downregulated' in sys.argv
    if doDownregulated:
        minRPKM = float(_arg_after('-downregulated', 1))
        minFoldChange = float(_arg_after('-downregulated', 2))
        fieldID1 = int(_arg_after('-downregulated', 3))
        fieldID2 = int(_arg_after('-downregulated', 4))
        print('Extracting downregulated genes with minRPKM =  %s and minimal fold change =  %s' % (minRPKM, minFoldChange))
    doFlatOn = '-flat_on' in sys.argv
    if doFlatOn:
        minRPKM = float(_arg_after('-flat_on', 1))
        maxFoldChange = float(_arg_after('-flat_on', 2))
        fieldID1 = int(_arg_after('-flat_on', 3))
        fieldID2 = int(_arg_after('-flat_on', 4))
        print('Extracting flat-on genes with minRPKM =  %s and maximal fold change =  %s' % (minRPKM, maxFoldChange))
    doFlatOff = '-flat_off' in sys.argv
    if doFlatOff:
        maxRPKM = float(_arg_after('-flat_off', 1))
        maxFoldChange = float(_arg_after('-flat_off', 2))
        fieldID1 = int(_arg_after('-flat_off', 3))
        fieldID2 = int(_arg_after('-flat_off', 4))
        print('Extracting flat-off genes with maxRPKM =  %s and maximal fold change =  %s' % (maxRPKM, maxFoldChange))

    if doExpressedOnly and (doUpregulated or doDownregulated):
        print('Warning: it is better not to use both the -expressedOnly and -upregulated filter')
        sys.exit(1)
    if doDownregulated and doUpregulated:
        print('Can not extract both up- and down-regulated genes')
        sys.exit(1)

    # Collect the wanted gene names (and, for -concatenate, their full lines).
    wantedgenes = {}
    for line in _read_lines(listofgenesfilename):
        record = line.split('\n')[0]
        columns = record.split('\t')
        if len(columns) <= LWGstartField:
            # Blank or short line: the requested name column does not exist.
            continue
        name = columns[LWGstartField].split(' ')[0]
        if doSkipLOC and (name[0:2].isupper() or name.endswith('Rik')):
            continue
        if doCapitalize:
            name = name.lower()
        wantedgenes[name] = {}
        if doCat:
            concat[name] = record

    expressiondatalist = _read_lines(expressiondatafilename)
    expressionline1 = expressiondatalist[0]

    # Map each selected gene name to its (last seen) expression line.
    selectedgenes = {}
    for line in expressiondatalist[1:]:
        record = line.split('\n')[0]
        fields = record.split('\t')
        if len(fields) <= EstartField:
            # Blank or short line: no gene-name column to look at.
            continue
        name = fields[EstartField]
        if name[0:3] == 'Igk':
            continue
        if doCapitalize:
            name = name.lower()
        # Direct dict membership is a hash lookup; unwanted genes are
        # rejected before any numeric parsing is attempted.
        if name not in wantedgenes:
            continue

        if doExpressedOnly:
            values = [float(fields[ff]) for ff in range(startfield, endfield)]
            keep = max(values) >= threshold
        elif doUpregulated:
            first = _nonzero(fields[fieldID1])
            second = _nonzero(fields[fieldID2])
            keep = ((first >= minRPKM or second >= minRPKM)
                    and second / first >= minFoldChange)
        elif doDownregulated:
            first = _nonzero(fields[fieldID1])
            second = _nonzero(fields[fieldID2])
            keep = ((first >= minRPKM or second >= minRPKM)
                    and first / second >= minFoldChange)
        elif doFlatOn:
            first = _nonzero(fields[fieldID1])
            second = _nonzero(fields[fieldID2])
            keep = ((first >= minRPKM or second >= minRPKM)
                    and first / second <= maxFoldChange
                    and second / first <= maxFoldChange)
        elif doFlatOff:
            first = _nonzero(fields[fieldID1])
            second = _nonzero(fields[fieldID2])
            # Both conditions must be "off": check fieldID1 AND fieldID2.
            keep = (first <= maxRPKM and second <= maxRPKM
                    and first / second <= maxFoldChange
                    and second / first <= maxFoldChange)
        else:
            keep = True

        if keep:
            selectedgenes[name] = record

    # The output is opened only after all input has been read and filtered,
    # so a failure above does not leave a truncated output file behind.
    outfile = open(outfilename, 'w')
    try:
        outfile.write(expressionline1.split('\n')[0])
        outfile.write('\n')
        for gene in selectedgenes:
            if doCat:
                outfile.write(concat[gene] + '\t' + selectedgenes[gene])
            else:
                outfile.write(selectedgenes[gene])
            outfile.write('\n')
    finally:
        outfile.close()

# Script entry point: run() is called unconditionally as soon as this file is
# executed (there is no __main__ guard, so importing the module runs it too).
run()

