##################################
#                                #
# Last modified 04/18/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import bisect
import math
import string
import sys
from sets import Set

def _parse_field_ids(spec):
    """Expand a column spec such as '3,5:7' into a sorted list of ints.

    Items are comma-separated; each is either a single 0-based column index
    or an inclusive 'from:to' range ('5:7' -> 5, 6, 7).
    """
    ids = []
    for item in spec.split(','):
        if ':' in item:
            bounds = item.split(':')
            ids.extend(range(int(bounds[0]), int(bounds[1]) + 1))
        else:
            ids.append(int(item))
    return sorted(ids)


def _parse_cutoffs(spec):
    """Return the comma-separated cutoffs in *spec* as sorted, unique floats.

    0.0 is always added so that every positive reference FPKM falls into a bin.
    """
    cutoffs = set(float(item) for item in spec.split(','))
    cutoffs.add(0.0)
    return sorted(cutoffs)


def _cutoff_bin(value, cutoffs):
    """Return the largest cutoff <= *value*, or None if no cutoff qualifies.

    *cutoffs* must be sorted ascending. This is the half-open bin
    [cutoffs[i], cutoffs[i+1]) containing *value*, with the top bin open-ended.
    """
    index = bisect.bisect_right(cutoffs, value) - 1
    if index < 0:
        return None
    return cutoffs[index]


def _read_table(path, gene_field, ref_field, field_ids, cutoffs):
    """Parse the tab-delimited FPKM table at *path*.

    Returns (gene_dict, ref_dict, labels):
      gene_dict -- gene -> {column label: FPKM} for the columns in *field_ids*
      ref_dict  -- gene -> (reference FPKM, cutoff bin of that FPKM)
      labels    -- sorted column labels taken from the '#' header line
    Blank lines are ignored. Rows whose reference FPKM is zero, or lies below
    the smallest cutoff, are skipped.

    Raises ValueError if a data row appears before the '#' header line.
    """
    field_to_label = {}
    gene_dict = {}
    ref_dict = {}
    with open(path) as handle:
        for line in handle:
            # Strip only the line terminator: strip() would also remove a
            # leading tab and shift every column when the first field is empty.
            line = line.rstrip('\r\n')
            if not line.strip():
                continue
            fields = line.split('\t')
            if line.startswith('#'):
                for field_id in field_ids:
                    field_to_label[field_id] = fields[field_id]
                continue
            if not field_to_label:
                raise ValueError('data line found before the "#" header line')
            ref_fpkm = float(fields[ref_field])
            if ref_fpkm == 0:
                continue
            cutoff = _cutoff_bin(ref_fpkm, cutoffs)
            if cutoff is None:
                # Below every cutoff (e.g. a negative value): cannot be binned.
                continue
            gene = fields[gene_field]
            ref_dict[gene] = (ref_fpkm, cutoff)
            gene_dict[gene] = dict(
                (field_to_label[i], float(fields[i])) for i in field_ids)
    labels = sorted(set(field_to_label.values()))
    return gene_dict, ref_dict, labels


def _fraction_matrix(label, gene_dict, ref_dict, cutoffs, min_frac_dev):
    """Tally ordered gene pairs for column *label*, keyed by reference bins.

    Returns {bin1: {bin2: [passing, total]}} for every pair of cutoffs, where
    bin1/bin2 are the reference cutoff bins of gene1/gene2. A pair is counted
    when its expression sum in *label* is nonzero. It passes when the share of
    the lower-reference gene deviates from the reference share by less than
    *min_frac_dev*, measured relative to the reference share.
    """
    counts = dict((c1, dict((c2, [0, 0]) for c2 in cutoffs)) for c1 in cutoffs)
    genes = list(gene_dict)
    for gene1 in genes:
        fpkm1 = gene_dict[gene1][label]
        ref1, bin1 = ref_dict[gene1]
        for gene2 in genes:
            if gene1 == gene2:
                continue
            fpkm2 = gene_dict[gene2][label]
            ref2, bin2 = ref_dict[gene2]
            if ref1 + ref2 == 0 or fpkm1 + fpkm2 == 0:
                continue
            ref_ratio = min(ref1, ref2) / (ref1 + ref2)
            # Use the share of the gene that is lower in the reference, so
            # both ratios describe the same gene.
            if ref1 <= ref2:
                ratio = fpkm1 / (fpkm1 + fpkm2)
            else:
                ratio = fpkm2 / (fpkm1 + fpkm2)
            cell = counts[bin1][bin2]
            cell[1] += 1
            if math.fabs(ratio - ref_ratio) / ref_ratio < min_frac_dev:
                cell[0] += 1
    return counts


def _write_matrix(outfile, label, counts, cutoffs):
    """Write one column's section to *outfile*.

    The section is a banner around *label*, a '#'-prefixed row of cutoffs,
    then one row per cutoff holding passing/total fractions. A cell is blank
    when its total is 0.
    """
    banner = '##################################'
    outfile.write(banner + '\n')
    outfile.write(label + '\n')
    outfile.write(banner + '\n')
    outfile.write('\t'.join(['#'] + [str(c) for c in cutoffs]) + '\n')
    for c1 in cutoffs:
        cells = [str(c1)]
        for c2 in cutoffs:
            passing, total = counts[c1][c2]
            cells.append(str(passing / (total + 0.0)) if total else '')
        outfile.write('\t'.join(cells) + '\n')


def run():
    """Command-line entry point; the usage text lists the arguments.

    Each gene's reference FPKM is assigned to a cutoff bin. For every
    expression column, all ordered pairs of distinct genes are compared. With
    a the smaller and b the larger reference FPKM of the pair, the reference
    share a/(a+b) is compared with that same gene's share of the pair's
    expression in the column. Pass/total counts are tallied per pair of bins,
    and one fraction matrix per column is written to outfile.
    """
    # Seven arguments are required, so sys.argv must have 8 entries.
    if len(sys.argv) < 8:
        print('usage: python %s FPKM_table geneFieldID referenceFPKMFieldID FPKMfieldIDs FPKM_cutoffs minFractionDeviation outfile' % sys.argv[0])
        print('\t minFractionDeviation - cutoff for the deviation from the a/(a+b) ratio, where a is the smaller expression value')
        print('\t FPKMfieldIDs a combination of comma-separated and from:to (both included) values')
        print('\t FPKM_cutoffs - comma separated')
        print('\t A header line with column labels is assumed')
        sys.exit(1)

    table_path = sys.argv[1]
    gene_field = int(sys.argv[2])
    ref_field = int(sys.argv[3])
    field_ids = _parse_field_ids(sys.argv[4])
    cutoffs = _parse_cutoffs(sys.argv[5])
    min_frac_dev = float(sys.argv[6])
    out_path = sys.argv[7]

    gene_dict, ref_dict, labels = _read_table(
        table_path, gene_field, ref_field, field_ids, cutoffs)
    print('finished parsing FPKMs, found %d genes with expression values greater than zero' % len(gene_dict))

    with open(out_path, 'w') as outfile:
        for index, label in enumerate(labels, 1):
            print(index)  # progress: one line per column processed
            counts = _fraction_matrix(
                label, gene_dict, ref_dict, cutoffs, min_frac_dev)
            _write_matrix(outfile, label, counts, cutoffs)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()