##################################
#                                #
# Last modified 03/16/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy
from sets import Set

def _parse_field_ids(spec, allow_ranges=False):
    """Parse a comma-separated list of 0-based column indices.

    When allow_ranges is true, a term of the form 'from:to' expands to every
    index from 'from' through 'to', both ends included.

    Returns the distinct indices as a list sorted in ascending order.
    Raises ValueError if a term is not an integer.
    """
    ids = set()
    for term in spec.split(','):
        if allow_ranges and ':' in term:
            bounds = term.split(':')
            ids.update(range(int(bounds[0]), int(bounds[1]) + 1))
        else:
            ids.add(int(term))
    return sorted(ids)


def _data_rows(path):
    """Yield the tab-separated fields of every data line in the file at path.

    Lines starting with '#' and lines containing only whitespace are skipped.
    The file is closed once the rows are exhausted.
    """
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            stripped = line.strip()
            if not stripped:
                continue
            yield stripped.split('\t')


def _bin_for(cutoffs, value):
    """Return the largest cut-off that is <= value, or None if there is none.

    cutoffs must be sorted in ascending order.  None is returned when value
    lies below the lowest cut-off or is NaN (every comparison is then false).
    """
    for cutoff in reversed(cutoffs):
        if value >= cutoff:
            return cutoff
    return None


def run():
    """Tabulate, per expression bin, how many samples have zero expression.

    Command-line arguments (sys.argv[1:8]):
        1. gene list file (tab-separated, '#' lines ignored)
        2. comma-separated key column(s) in the gene list
        3. expression table file (tab-separated, '#' lines ignored)
        4. comma-separated key column(s) in the expression table
        5. expression value columns: comma-separated indices and/or
           from:to (inclusive) ranges
        6. comma-separated expression cut-offs (0 is always added)
        7. output file name

    All column lists are de-duplicated and sorted, so key columns are
    compared in ascending column order in both files.

    For every expression-table row whose key occurs in the gene list, the
    zero-valued expression columns are counted and the mean of the non-zero
    values selects a bin (the largest cut-off not exceeding the mean).  Rows
    with no non-zero value, or whose mean falls into no bin, are skipped.
    The output has one row per cut-off: the cut-off, the number of genes in
    the bin, and the fraction of those genes having 0, 1, ..., N zero-valued
    columns.

    Prints usage and exits with status 1 when too few arguments are given.
    """
    # sys.argv[0] plus the seven required arguments; sys.argv[7] is read below.
    if len(sys.argv) < 8:
        print('usage: python %s list_of_genes fieldID(s) expression_table fieldID(s) expression_field_IDs expression_values outfile' % sys.argv[0])
        print('\tfieldID(s) and expression values  should be comma-separated')
        print('\texpression_field_IDs can be a combination of comma separated and from:to (included) terms')
        sys.exit(1)

    genes = sys.argv[1]
    gene_field_ids = _parse_field_ids(sys.argv[2])

    expression = sys.argv[3]
    expression_gene_field_ids = _parse_field_ids(sys.argv[4])

    fpkm_field_ids = _parse_field_ids(sys.argv[5], allow_ranges=True)
    print('FPKM fields: %s' % (fpkm_field_ids,))

    # The integer 0 is always a bin boundary; the set removes duplicates
    # (including a user-supplied 0, which then stays a float).
    cutoffs = [float(term) for term in sys.argv[6].split(',')]
    cutoffs.append(0)
    cutoffs = sorted(set(cutoffs))

    # counts[cutoff][k] = number of genes in that bin with exactly k zeros.
    n_slots = len(fpkm_field_ids) + 1
    counts = {}
    for cutoff in cutoffs:
        counts[cutoff] = [0] * n_slots

    outfilename = sys.argv[7]

    wanted = set()
    for fields in _data_rows(genes):
        wanted.add(tuple([fields[i] for i in gene_field_ids]))

    for fields in _data_rows(expression):
        key = tuple([fields[i] for i in expression_gene_field_ids])
        if key not in wanted:
            continue
        values = [float(fields[i]) for i in fpkm_field_ids]
        nonzero = [value for value in values if value != 0]
        if not nonzero:
            continue
        zeros = len(values) - len(nonzero)
        cutoff = _bin_for(cutoffs, numpy.mean(nonzero))
        if cutoff is None:
            # Mean is below every cut-off (or NaN): the gene belongs to no
            # bin, so it must not be added to an unrelated one.
            continue
        counts[cutoff][zeros] += 1

    with open(outfilename, 'w') as outfile:
        header = ['#expression_level', 'number_genes']
        header.extend([str(i) for i in range(n_slots)])
        outfile.write('\t'.join(header) + '\n')

        for cutoff in cutoffs:
            bin_counts = counts[cutoff]
            # Float total keeps the division below a true division.
            total_genes = 0.0 + sum(bin_counts)
            row = [str(cutoff), str(total_genes)]
            for count in bin_counts:
                if total_genes == 0:
                    row.append(str(0))
                else:
                    row.append(str(count / total_genes))
            outfile.write('\t'.join(row) + '\n')

# Run only when executed as a script, so that importing this module does not
# parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()