##################################
#                                #
# Last modified 01/17/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import math
import numpy as np

def _parse_field_ids(spec):
    """Expand a field specification into a sorted list of unique indices.

    `spec` is a comma-separated list whose items are either a single
    integer or an inclusive `start:end` range, e.g. '1,3:5' -> [1, 3, 4, 5].
    """
    ids = set()
    for token in spec.split(','):
        if ':' in token:
            bounds = token.split(':')
            # Both ends of the range are included.
            ids.update(range(int(bounds[0]), int(bounds[1]) + 1))
        else:
            ids.add(int(token))
    return sorted(ids)


def _is_skipped_line(line):
    """Return True for blank lines, '#' comments and 'tracking_id' headers."""
    return not line.strip() or line.startswith(('#', 'tracking_id'))


def _read_modules(path):
    """Read the tab-delimited `gene<TAB>module_number` file.

    Returns a dict mapping each integer module number to the list of its
    genes, in file order.
    """
    module_genes = {}
    with open(path) as handle:
        for line in handle:
            if _is_skipped_line(line):
                continue
            fields = line.strip().split('\t')
            module_genes.setdefault(int(fields[1]), []).append(fields[0])
    return module_genes


def _read_expression(path, gene_field, value_fields):
    """Read the expression table into a dict gene -> numpy array of floats.

    `gene_field` is the column index holding the gene name and
    `value_fields` the column indices holding the expression values.
    If a gene occurs on several lines, the last line wins.
    """
    expression = {}
    with open(path) as handle:
        for line in handle:
            if _is_skipped_line(line):
                continue
            fields = line.strip().split('\t')
            expression[fields[gene_field]] = np.array(
                [float(fields[column]) for column in value_fields])
    return expression


def _format_value(value):
    """Render a correlation as a string of at most six characters.

    Plain decimal representations are truncated exactly as before.  Values
    whose string form uses scientific notation (e.g. 5.7e-06) are first
    converted to fixed-point notation; truncating the exponent form would
    drop the exponent and emit a number of the wrong magnitude.
    """
    text = str(value)
    if 'e' in text:
        text = '%.6f' % value
    return text[0:6]


def _header_line(modules, labels_by_module):
    """Build the '#'-prefixed, tab-separated header of an output matrix."""
    cells = ['#']
    for module in modules:
        for label in labels_by_module[module]:
            cells.append(str(module) + '::' + label)
    return '\t'.join(cells) + '\n'


def _write_all_correlations(path, modules, module_genes, expression):
    """Write the gene-by-gene correlation matrix to `path`.

    Returns a nested dict `pooled[module1][module2]` listing the
    correlations of every gene pair drawn from the two modules, leaving
    out self-pairs and undefined (NaN) correlations.

    Raises KeyError if a gene from the modules file has no entry in
    `expression`.
    """
    pooled = {}
    for module1 in modules:
        pooled[module1] = {}
        for module2 in modules:
            pooled[module1][module2] = []

    processed = 0
    with open(path, 'w') as out:
        out.write(_header_line(modules, module_genes))
        for module1 in modules:
            for gene1 in module_genes[module1]:
                processed += 1
                if processed % 100 == 0:
                    print('%d genes processed' % processed)
                cells = [str(module1) + '::' + gene1]
                for module2 in modules:
                    bucket = pooled[module1][module2]
                    for gene2 in module_genes[module2]:
                        correlation = np.corrcoef(expression[gene1],
                                                  expression[gene2])[0, 1]
                        if gene1 != gene2 and not np.isnan(correlation):
                            bucket.append(correlation)
                        # NaN is still written to the matrix (as 'nan').
                        cells.append(_format_value(correlation))
                out.write('\t'.join(cells) + '\n')
    return pooled


def _write_average_correlations(path, modules, module_genes, averages, resize):
    """Write the module-average correlation matrix to `path`.

    Without rescaling (`resize` is None) the matrix has one row and one
    column per gene.  With rescaling each module contributes
    `len(genes) // resize` rows and columns, labelled by their index.
    Every cell holds the average correlation of its row and column modules.
    """
    labels_by_module = {}
    for module in modules:
        genes = module_genes[module]
        if resize is None:
            labels_by_module[module] = list(genes)
        else:
            labels_by_module[module] = [
                str(index) for index in range(len(genes) // resize)]

    written = 0
    with open(path, 'w') as out:
        out.write(_header_line(modules, labels_by_module))
        for module1 in modules:
            # All rows of one module carry the same values.
            cells = []
            for module2 in modules:
                cell = _format_value(averages[module1][module2])
                cells.extend([cell] * len(labels_by_module[module2]))
            for label in labels_by_module[module1]:
                if resize is None:
                    written += 1
                    if written % 100 == 0:
                        print('%d genes processed in average correlation output' % written)
                out.write('\t'.join([str(module1) + '::' + label] + cells) + '\n')


def run():
    """Command-line entry point: correlate genes within and between modules.

    Reads from sys.argv, in order: the expression table, the modules file,
    the gene-name column index, the expression column specification and
    the output prefix, optionally followed by `-rescale factor`.

    Writes `<prefix>.all` (correlation of every gene pair) and
    `<prefix>.average` (average correlation of every module pair, repeated
    per gene, or per `module size // factor` when rescaling).

    Prints the usage text and exits with status 1 when fewer than the five
    required arguments are given, or when the rescale factor is missing or
    smaller than 1.
    """
    # print is always called with one parenthesised argument so that it
    # behaves the same as a Python 2 statement and as a function.
    # Five positional arguments are required, hence six argv entries.
    if len(sys.argv) < 6:
        print('usage: python %s Expression_table modules_file gene_fieldID FPKM_fieldIDs outfile_prefix [-rescale factor]' % sys.argv[0])
        print('\tmodules_file format: gene\tmodule_number')
        print('\tNote: FPKM_fieldIDs should be a combination of commas, and from:to (including both ends) values')
        print('\tNote: use the rescale option if the output file is too big to be conviniently handled; the number of genes in each module will be divided by that number and this will determine the size of the new data matrix')
        sys.exit(1)

    expression_table = sys.argv[1]
    modules_file = sys.argv[2]
    geneFieldID = int(sys.argv[3])
    FPKMFieldIDs = _parse_field_ids(sys.argv[4])
    print('FPKM fields: %s' % (FPKMFieldIDs,))
    outfile_prefix = sys.argv[5]

    # resize stays None when no rescaling was requested.
    resize = None
    if '-rescale' in sys.argv:
        factor_position = sys.argv.index('-rescale') + 1
        if factor_position >= len(sys.argv):
            print('-rescale requires a factor')
            sys.exit(1)
        resize = int(sys.argv[factor_position])
        if resize < 1:
            print('the -rescale factor must be a positive integer')
            sys.exit(1)

    ModuleDict = _read_modules(modules_file)
    GeneDict = _read_expression(expression_table, geneFieldID, FPKMFieldIDs)
    modules = sorted(ModuleDict)

    ModuleModuleDict = _write_all_correlations(
        outfile_prefix + '.all', modules, ModuleDict, GeneDict)

    ModuleModuleAverageCorrelationDict = {}
    for module1 in modules:
        ModuleModuleAverageCorrelationDict[module1] = {}
        for module2 in modules:
            values = ModuleModuleDict[module1][module2]
            total = sum(values)
            # An empty or zero-sum list is reported as the integer 0.
            if total == 0:
                average = 0
            else:
                average = total / float(len(values))
            ModuleModuleAverageCorrelationDict[module1][module2] = average

    _write_average_correlations(
        outfile_prefix + '.average', modules, ModuleDict,
        ModuleModuleAverageCorrelationDict, resize)

# Run the command-line pipeline only when this file is executed as a
# script, so that importing it does not parse sys.argv or write files.
if __name__ == '__main__':
    run()
