##################################
#                                #
# Last modified 11/20/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import math

def getFPKMCutoff(FPKMcutoffs,FPKM):
    """Return the bin (lower-edge cutoff) that an FPKM value falls into.

    The cutoffs are treated as lower edges of half-open bins
    [cutoff_i, cutoff_i+1); the highest cutoff is open-ended, so any value
    at or above it maps to that highest cutoff.

    Args:
        FPKMcutoffs: non-empty sequence of numeric cutoffs. The order does
            not matter; the largest cutoff not exceeding FPKM is selected.
        FPKM: the numeric value to assign to a bin.

    Returns:
        The element of FPKMcutoffs that is the largest cutoff <= FPKM.

    Raises:
        ValueError: if no cutoff is <= FPKM (FPKM lies below the lowest
            cutoff, FPKM is NaN, or FPKMcutoffs is empty).
    """
    # Every cutoff at or below the value is a candidate lower edge; the
    # largest of them is the bin the value belongs to.
    eligible = [cutoff for cutoff in FPKMcutoffs if cutoff <= FPKM]
    if not eligible:
        # Previously this case fell through with the result variable never
        # assigned, surfacing as an opaque UnboundLocalError.
        raise ValueError('FPKM value %r is below every cutoff in %r'
                         % (FPKM, list(FPKMcutoffs)))
    return max(eligible)

def run():
    """Command-line entry point.

    Reads a tab-delimited FPKM table, assigns every gene to a bin according
    to the maximum of its reference-column FPKM values, and writes, for each
    sample column, the fraction of genes in each bin whose FPKM in that
    sample is at or above the detection cutoff.

    Positional command-line arguments (sys.argv[1:8]):
        FPKM_table        path of the tab-delimited input table
        gene_fieldIDs     comma-separated column indexes forming the gene key
        reference_fields  comma-separated column indexes of reference FPKMs
        FPKM_fieldIDs     comma-separated column indexes and/or from:to
                          ranges (inclusive) of the sample FPKM columns
        FPKM_cutoffs      comma-separated bin lower edges
        detection_cutoff  minimum FPKM for a gene to count as detected
        outfilename       path of the output table

    Output: one header line ('#' followed by the sorted cutoffs) and one line
    per sample column with the detected fraction for each bin. A bin that
    contains no genes is reported as 'nan'. Genes whose reference FPKM lies
    below the lowest cutoff belong to no bin and are skipped.

    Exits with status 1 when too few arguments are given or when no header
    line precedes the data.
    """
    # outfilename is read from sys.argv[7], so eight entries are required.
    if len(sys.argv) < 8:
        print('usage: python %s FPKM_table gene_fieldIDs reference_fields FPKM_fieldIDs FPKM_cutoffs detection_cutoff outfilename' % sys.argv[0])
        print('\tNote: geneID fields, FPKM_fieldID(s) and FPKM_cutoffs should be comma-separated; FPKM_fieldID(s) can also take from:to input')
        print('\tNote: a header line beginning with "#" or with "tracking_id" and containing library names is expected')
        print('\tNote: the maximum of the reference FPKM values will be used for each gene')
        sys.exit(1)

    FPKM_table = sys.argv[1]

    # Gene key columns keep the order given on the command line.
    geneIDs = [int(ID) for ID in sys.argv[2].split(',')]

    # Reference, sample and cutoff lists are de-duplicated and sorted.
    referenceIDs = sorted(set(int(ID) for ID in sys.argv[3].split(',')))

    uniqueFPKMIDs = set()
    for ID in sys.argv[4].split(','):
        if ':' in ID:
            # from:to notation, both ends inclusive
            bounds = ID.split(':')
            uniqueFPKMIDs.update(range(int(bounds[0]), int(bounds[1]) + 1))
        else:
            uniqueFPKMIDs.add(int(ID))
    FPKMIDs = sorted(uniqueFPKMIDs)

    FPKMcutoffs = [float(FPKM) for FPKM in sys.argv[5].split(',')]

    # Echo the cutoffs exactly as supplied, before de-duplication.
    print(FPKMcutoffs)

    detectionCutoff = float(sys.argv[6])
    outfilename = sys.argv[7]

    FPKMcutoffs = sorted(set(FPKMcutoffs))

    print(referenceIDs)
    print(FPKMIDs)
    print(FPKMcutoffs)

    # cutoff -> list of genes whose maximum reference FPKM falls in that bin
    ReferenceCutoffDict = dict((FPKM, []) for FPKM in FPKMcutoffs)
    # sample name -> set of genes detected in that sample
    SampleDetectionDict = {}
    # column index -> library name; stays None until a header line is read
    IDtoNameDict = None

    with open(FPKM_table) as lineslist:
        for line in lineslist:
            fields = line.strip().split('\t')
            if line.startswith(('#', 'tracking_id')):
                IDtoNameDict = {}
                for ID in referenceIDs:
                    IDtoNameDict[ID] = fields[ID]
                for ID in FPKMIDs:
                    IDtoNameDict[ID] = fields[ID]
                    SampleDetectionDict[fields[ID]] = set()
                continue
            if IDtoNameDict is None:
                sys.stderr.write('error: data line found before a header line beginning with "#" or "tracking_id" in %s\n' % FPKM_table)
                sys.exit(1)
            gene = tuple(fields[ID] for ID in geneIDs)
            if len(gene) == 1:
                gene = gene[0]
            maxReference = max(float(fields[ID]) for ID in referenceIDs)
            # Values below the lowest cutoff (or NaN) fall into no bin, so
            # the gene cannot contribute to any column of the output.
            if not maxReference >= FPKMcutoffs[0]:
                continue
            FPKMcutoff = getFPKMCutoff(FPKMcutoffs, maxReference)
            ReferenceCutoffDict[FPKMcutoff].append(gene)
            for ID in FPKMIDs:
                if float(fields[ID]) >= detectionCutoff:
                    SampleDetectionDict[IDtoNameDict[ID]].add(gene)

    if IDtoNameDict is None:
        sys.stderr.write('error: no header line beginning with "#" or "tracking_id" found in %s\n' % FPKM_table)
        sys.exit(1)

    # Report the bin sizes in ascending cutoff order.
    for FPKM in FPKMcutoffs:
        print('%s %s' % (FPKM, len(ReferenceCutoffDict[FPKM])))

    with open(outfilename, 'w') as outfile:
        outfile.write('#' + ''.join('\t' + str(FPKM) for FPKM in FPKMcutoffs) + '\n')

        for ID in FPKMIDs:
            sample = IDtoNameDict[ID]
            detected = SampleDetectionDict[sample]
            columns = [sample]
            for FPKM in FPKMcutoffs:
                refGenes = ReferenceCutoffDict[FPKM]
                if refGenes:
                    sampleGenes = sum(1 for gene in refGenes if gene in detected)
                    fraction = float(sampleGenes) / len(refGenes)
                else:
                    # An empty bin has no defined detection fraction.
                    fraction = float('nan')
                columns.append(str(fraction))
            outfile.write('\t'.join(columns) + '\n')

# Run the command-line tool only when this file is executed as a script, so
# that importing the module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()
