##################################
#                                #
# Last modified 11/15/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import math
import numpy as np

def _usage():
    """Print the command-line usage to stdout and exit with status 1."""
    print('usage: python %s Expression_table gene_fieldID FPKM_fieldIDs outfilename [-minExpression value]' % sys.argv[0])
    print('\tNote: if you use the -minExpression option, genes which are expressed below that threshold in all samples, will be excluded')
    print('\tNote: FPKM_fieldIDs should be a combination of commas, and from:to (including both ends) values')
    sys.exit(1)


def _parse_field_ids(spec):
    """Parse a column spec such as ``'3,5:7'`` into a sorted list of unique ints.

    Comma-separated items are either single column indices or ``start:end``
    ranges that include both ends.
    """
    ids = set()
    for item in spec.split(','):
        if ':' in item:
            parts = item.split(':')
            ids.update(range(int(parts[0]), int(parts[1]) + 1))
        else:
            ids.add(int(item))
    return sorted(ids)


def _load_expression(path, gene_field, fpkm_fields, min_cutoff=None):
    """Read a tab-separated expression table into ``{gene: numpy array}``.

    Lines starting with ``#`` or ``tracking_id`` (the header) and blank lines
    are skipped.  When *min_cutoff* is not None, genes whose maximum value
    over *fpkm_fields* is below it are dropped.  If a gene name occurs on
    several lines, the last occurrence wins.
    """
    table = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith(('#', 'tracking_id')) or not line.strip():
                continue
            # Remove only the line terminator: str.strip() would also eat
            # edge tabs and shift the columns when an edge field is empty.
            fields = line.rstrip('\r\n').split('\t')
            gene = fields[gene_field]
            values = [float(fields[i]) for i in fpkm_fields]
            if min_cutoff is not None and max(values) < min_cutoff:
                continue
            table[gene] = np.array(values)
    return table


def _format_corr(a, b):
    """Return the Pearson correlation of *a* and *b* as text cut to 4 characters.

    The value is truncated, not rounded (e.g. ``'0.87'``, ``'-0.8'``, ``'1.0'``).
    This is the established output format.
    """
    return str(np.corrcoef(a, b)[0, 1])[0:4]


def run():
    """Write the all-vs-all Pearson correlation matrix of genes to a file.

    Command line (positional arguments must come first)::

        python script.py Expression_table gene_fieldID FPKM_fieldIDs outfilename [-minExpression value]

    The output is tab-separated.  A header row ``#``, gene1, gene2, ... is
    followed by one row per gene, sorted by name.  Each row holds the gene
    name and its truncated correlations with every gene, in the same order.
    """
    # Four positional arguments are required; outfilename is sys.argv[4].
    if len(sys.argv) < 5:
        _usage()

    expression_table = sys.argv[1]
    gene_field = int(sys.argv[2])
    fpkm_fields = _parse_field_ids(sys.argv[3])
    print('FPKM fields: %s' % fpkm_fields)
    outfilename = sys.argv[4]

    min_cutoff = None
    if '-minExpression' in sys.argv:
        value_pos = sys.argv.index('-minExpression') + 1
        if value_pos >= len(sys.argv):
            _usage()
        min_cutoff = float(sys.argv[value_pos])

    gene_dict = _load_expression(expression_table, gene_field, fpkm_fields, min_cutoff)
    genes = sorted(gene_dict)
    print('retained %d genes' % len(genes))

    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(['#'] + genes) + '\n')
        for count, gene1 in enumerate(genes, 1):
            if count % 100 == 0:
                # Only mention the cutoff when one was actually supplied.
                if min_cutoff is not None:
                    print('%d genes with expression values higher than %s processed' % (count, min_cutoff))
                else:
                    print('%d genes processed' % count)
            profile = gene_dict[gene1]
            cells = [gene1] + [_format_corr(profile, gene_dict[gene2]) for gene2 in genes]
            outfile.write('\t'.join(cells) + '\n')

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
