##################################
#                                #
# Last modified 03/17/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import numpy
from sets import Set

def _data_rows(path):
    """Yield ``(raw_line, fields)`` for each data line of a tab-delimited file.

    Lines starting with '#' (comments/headers) and blank lines are skipped.
    ``fields`` is the stripped line split on tabs. The file is closed as soon
    as iteration finishes.
    """
    with open(path) as handle:
        for raw_line in handle:
            if raw_line.startswith('#') or not raw_line.strip():
                continue
            yield raw_line, raw_line.strip().split('\t')


def run():
    """Correlate junction fragment counts with the expression of the joined genes.

    Positional command-line arguments:
      1. junctions       tab-delimited; chr, left, right, strand in fields 0-3,
                         the two gene names in fields 6 and 7
      2. junction counts tab-delimited; chr, left, right, strand in fields 0-3,
                         integer counts in the columns listed in (5)
      3. expression      tab-delimited; gene name in the column given by (4),
                         float values in the columns listed in (5)
      4. gene name field 0-based column index of the gene name in (3)
      5. correspondence  tab-delimited pairs: counts column <tab> expression column
      6. outfilename

    For every junction line (in junctions-file order) that has counts in (2)
    and whose two genes both have expression values in (3), numpy.corrcoef is
    computed between the per-sample counts and the per-sample minimum of the
    two genes' values. The original line plus a tab and that coefficient is
    printed to stdout and written to outfilename. Other junctions are skipped.

    Exits with status 1 and a usage message when fewer than six arguments
    are given.
    """
    # Six arguments are required, so sys.argv must have at least seven entries.
    if len(sys.argv) < 7:
        print('usage: python %s junctions <junctions counts> <expression table> <gene name field ID> <junctions counts to expression values correspondence table> outfilename' % sys.argv[0])
        print('   junctions format:')
        print('   chrY\t9175621\t9196544\t+\tknown exon to known exon, different genes\t55.0\tTSPY4\tTSPY8\tnovel\tGT|AG')
        print('   junctions_counts format: chr, left, right, strand in first 4 fields')
        print('   junctions counts to expression values correspondence table format: replicated field ID in junctions table <tab> replicated field ID in expression table <tab>')
        sys.exit(1)

    junctions = sys.argv[1]
    junctions_counts = sys.argv[2]
    expression = sys.argv[3]
    geneNameID = int(sys.argv[4])
    junctions_to_expression = sys.argv[5]
    outfilename = sys.argv[6]

    # Counts-file column index -> expression-table column index.
    junctions_to_expr = {}
    for _, fields in _data_rows(junctions_to_expression):
        junctions_to_expr[int(fields[0])] = int(fields[1])
    # One fixed ordering of the samples, so counts and FPKMs stay paired below.
    sample_ids = list(junctions_to_expr)
    expr_ids = list(set(junctions_to_expr.values()))

    # Single pass over the junctions file. The lines are kept for the output
    # pass, so the file is not re-read after outfilename has been opened.
    junction_rows = []   # (raw_line, (chr, left, right, strand), gene1, gene2)
    junction_counts = {}  # junction key -> {counts column: int count}
    gene_fpkms = {}       # gene name -> {expression column: float value}
    for raw_line, fields in _data_rows(junctions):
        key = tuple(fields[0:4])
        gene1, gene2 = fields[6], fields[7]
        junction_rows.append((raw_line, key, gene1, gene2))
        junction_counts[key] = {}
        gene_fpkms[gene1] = {}
        gene_fpkms[gene2] = {}

    # Fill counts only for junctions listed in the junctions file.
    for _, fields in _data_rows(junctions_counts):
        key = tuple(fields[0:4])
        if key not in junction_counts:
            continue
        junction_counts[key] = {i: int(fields[i]) for i in sample_ids}

    # Fill expression values only for genes named by some junction.
    for _, fields in _data_rows(expression):
        gene_name = fields[geneNameID]
        if gene_name not in gene_fpkms:
            continue
        gene_fpkms[gene_name] = {i: float(fields[i]) for i in expr_ids}

    with open(outfilename, 'w') as outfile:
        for raw_line, key, gene1, gene2 in junction_rows:
            counts = junction_counts[key]
            fpkms1 = gene_fpkms[gene1]
            fpkms2 = gene_fpkms[gene2]
            # Nothing to correlate if either gene lacks expression values or
            # the junction never appeared in the counts file.
            if not fpkms1 or not fpkms2 or not counts:
                continue
            fragment_counts = numpy.array([counts[i] for i in sample_ids])
            min_fpkms = numpy.array([
                min(fpkms1[junctions_to_expr[i]], fpkms2[junctions_to_expr[i]])
                for i in sample_ids
            ])
            correlation = numpy.corrcoef(fragment_counts, min_fpkms)
            outline = raw_line.strip() + '\t' + str(correlation[0, 1])
            print(outline)
            outfile.write(outline + '\n')

# Run only when executed as a script, so importing this module does not
# parse sys.argv or call sys.exit.
if __name__ == '__main__':
    run()
