##################################
#                                #
# Last modified 03/11/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import numpy
from commoncode import *

def _read_expression(expr):
    """Return a dict mapping expression IDs to FPKM from a Cufflinks table.

    Header lines (starting with 'gene_id') and blank lines are skipped; the
    ID is whitespace-separated column 0 and the FPKM is column 5.
    """
    ExprDict = {}
    with open(expr) as lineslist:
        for line in lineslist:
            if line.startswith('gene_id'):
                continue
            fields = line.split()
            if not fields:
                continue
            ExprDict[fields[0]] = float(fields[5])
    return ExprDict


def _read_genes(genes, ExprDict, doControl):
    """Parse the cleared refFlat table into the per-gene dictionary.

    Each tab-separated line holds gene name, chromosome, left, right, strand
    and a comma-separated list of expression IDs. The gene's FPKM is the sum
    of the FPKMs of those IDs present in ExprDict (unknown IDs contribute 0).
    The per-dataset score dictionaries start empty; 'Control_Scores' is only
    created when doControl is true.
    """
    GeneDict = {}
    with open(genes) as lineslist:
        for line in lineslist:
            if line.startswith('gene_id') or not line.strip():
                continue
            fields = line.strip().split('\t')
            gene = fields[0]
            coordinates = (fields[1], int(fields[2]), int(fields[3]), fields[4])
            FPKM = sum(ExprDict.get(ID, 0) for ID in fields[5].split(','))
            GeneDict[gene] = {'coordinates': coordinates,
                              'FPKM': FPKM,
                              'ChIP_Scores': {}}
            if doControl:
                GeneDict[gene]['Control_Scores'] = {}
    return GeneDict


def _bin_profile(hitRDS, coordinates, radius, binsize, normalizeBy):
    """Return normalized per-bin read counts in a window around one gene.

    The window covers [anchor - radius, anchor + radius) in steps of binsize,
    anchored at `left` for '+' genes and at `right` for '-' genes
    (presumably the TSS in both cases -- confirm against the refFlat-cleared
    format). Each bin holds (1 + reads in bin) / normalizeBy. For '-' genes
    the bins are reversed so they are ordered along the gene's strand; genes
    with any other strand value get an empty list.
    """
    (chrom, left, right, strand) = coordinates
    if strand == '+':
        anchor = left
    elif strand == '-':
        anchor = right
    else:
        return []
    profile = []
    for start in range(anchor - radius, anchor + radius, binsize):
        count = hitRDS.getCounts(chrom=chrom, rmin=start, rmax=start + binsize,
                                 uniqs=True, multi=True, splices=False,
                                 reportCombined=True)
        profile.append((1. + count) / normalizeBy)
    if strand == '-':
        profile.reverse()
    return profile


def _load_datasets(listfile, GeneDict, scoreKey, radius, binsize, cachePages):
    """Score every gene against each dataset listed in listfile.

    listfile has one 'label<TAB>rds-file' entry per line. For each dataset
    that opens successfully, GeneDict[gene][scoreKey][label] is set to the
    gene's binned RPM profile. Datasets that cannot be read are reported and
    skipped. Returns the labels of the loaded datasets in file order.
    """
    labels = []
    with open(listfile) as lineslist:
        for line in lineslist:
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split('\t')
            print(fields)
            if len(fields) < 2:
                print('skipping malformed line: %s' % stripped)
                continue
            label = fields[0]
            hitfilename = fields[1]
            try:
                hitRDS = readDataset(hitfilename, verbose=True, cache=True)
                # The cache can only be inspected/resized once THIS dataset is open.
                if cachePages > hitRDS.getDefaultCacheSize():
                    hitRDS.setDBcache(cachePages)
                normalizeBy = len(hitRDS) / 1000000.
            except Exception as err:
                print('skipping %s: could not read %s (%s)' % (label, hitfilename, err))
                continue
            labels.append(label)
            for gene in GeneDict:
                GeneDict[gene][scoreKey][label] = _bin_profile(
                    hitRDS, GeneDict[gene]['coordinates'], radius, binsize, normalizeBy)
    return labels


def _correlation_line(label, profiles, ExpressionVector, numBins):
    """Return 'label<TAB>r_0<TAB>...<TAB>r_{numBins-1}'.

    r_i is the Pearson coefficient (numpy.corrcoef) between bin i of every
    gene's profile and ExpressionVector. Coefficients outside [-1, 1] are
    reported and written as 0; the rest are truncated to 9 characters. A NaN
    coefficient fails both comparisons and is written as 'nan'.
    """
    cells = [label]
    for i in range(numBins):
        ChIPvector = [profile[i] for profile in profiles]
        corr = numpy.corrcoef(ChIPvector, ExpressionVector)
        r = corr[0, 1]
        if r > 1 or r < -1:
            print('correlation greater than 1:  %s setting to zero' % corr)
            cells.append('0')
        else:
            cells.append(str(r)[0:9])
    return '\t'.join(cells)


def run():
    """Correlate binned ChIP signal around genes with gene expression.

    Command line: refFlat-cleared file, list of 'label<TAB>rds' ChIP datasets,
    Cufflinks genes.expr, radius, binsize, output file, then optionally
    '-cache size' and '-control list-of-control-rds'.

    For every ChIP dataset, writes one line holding, for each bin, the
    correlation across genes between the bin's RPM and the gene FPKM. With
    -control, an extra section per control dataset correlates
    max(ChIP - control, 0) instead. Exits with status 1 on bad arguments,
    before the output file is created.
    """
    # argv[6] (the output file) is required, so at least 7 entries are needed.
    if len(sys.argv) < 7:
        print('usage: python %s refFlat-cleared list-of-rdsfiles genes.expr radius binsize outputfilename [-cache size] [-control list-of-control-rds]' % sys.argv[0])
        print("        refFlat-cleared is the output of the refFlat-geneClearance.py script, genes.expr from Cufflinks")
        print("        list of rds files format: label <tab> rds")
        sys.exit(1)

    genes = sys.argv[1]
    ChIP_files = sys.argv[2]
    expr = sys.argv[3]
    radius = int(sys.argv[4])
    binsize = int(sys.argv[5])
    outfilename = sys.argv[6]

    if binsize <= 0:
        print('the binsize has to be a positive integer')
        sys.exit(1)
    if radius % binsize != 0:
        print('the radius has to be divisible by the binsize')
        sys.exit(1)

    # Compared against each dataset's default cache size; -1 never triggers a resize.
    cachePages = -1
    if '-cache' in sys.argv:
        cachePages = int(sys.argv[sys.argv.index('-cache') + 1])

    doControl = '-control' in sys.argv
    if doControl:
        Control_ChIP_files = sys.argv[sys.argv.index('-control') + 1]

    # Opened before the (slow) scoring so a bad output path fails early.
    with open(outfilename, 'w') as outfile:
        ExprDict = _read_expression(expr)
        GeneDict = _read_genes(genes, ExprDict, doControl)

        ChIPDatasets = _load_datasets(ChIP_files, GeneDict, 'ChIP_Scores',
                                      radius, binsize, cachePages)
        ControlDatasets = []
        if doControl:
            ControlDatasets = _load_datasets(Control_ChIP_files, GeneDict, 'Control_Scores',
                                             radius, binsize, cachePages)

        Genes = sorted(GeneDict)
        ExpressionVector = [GeneDict[gene]['FPKM'] for gene in Genes]
        # Integer division: range() needs an int bin count.
        numBins = 2 * radius // binsize

        outfile.write('#No Control:\n')
        for ChIP in ChIPDatasets:
            print(ChIP)
            profiles = [GeneDict[gene]['ChIP_Scores'][ChIP] for gene in Genes]
            outline = _correlation_line(ChIP, profiles, ExpressionVector, numBins)
            print(outline)
            outfile.write(outline + '\n')

        for Control in ControlDatasets:
            outfile.write('#Using Control Dataset ' + Control + '\n')
            for ChIP in ChIPDatasets:
                profiles = []
                for gene in Genes:
                    scores = GeneDict[gene]
                    # Background-subtracted signal, floored at zero.
                    profiles.append([max(chip - ctrl, 0) for chip, ctrl in
                                     zip(scores['ChIP_Scores'][ChIP],
                                         scores['Control_Scores'][Control])])
                outline = _correlation_line(ChIP, profiles, ExpressionVector, numBins)
                print(outline)
                outfile.write(outline + '\n')
   
# Only run the analysis when executed as a script, not when imported.
if __name__ == '__main__':
    run()
