##################################
#                                #
# Last modified 11/17/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB
from commoncode import *

def run():
    """Report strand-specific RPKM values for each region of a BED-like file.

    Command line (positional arguments are read from sys.argv[1:5]):
        rdsfilename     read dataset, opened with readDataset()
        bed             tab-delimited region file; lines starting with '#'
                        and blank lines are skipped
        chrfieldID      0-based column index of the chromosome field; the
                        left, right and strand fields are read from the
                        three columns that follow it
        outputfilename  destination for the tab-delimited report
    Options:
        -withmulti      pass multi=True to getCounts (otherwise False)
        -cache size     set the dataset's DB cache to `size` pages when that
                        is larger than its default cache size

    For every region, counts on the '+' and '-' strands are fetched and
    converted to RPKM. The two output columns are relative to the region's
    strand: for a '-' region the '-' strand value is the sense column. Any
    other strand value (e.g. '.') is reported in '+' orientation.

    Regions whose right coordinate is not greater than the left are skipped
    with a warning on stderr (their RPKM would divide by zero or be negative).

    Exits with status 1 and a usage message if fewer than four positional
    arguments are given, or if the dataset is empty.
    """
    if len(sys.argv) < 5:
        # sys.argv[4] (the output file) is required, so argv needs 5 entries.
        sys.stdout.write('usage: python %s rdsfilename bed chrfieldID outputfilename [-withmulti] [-cache size]\n' % sys.argv[0])
        sys.exit(1)

    hitfilename = sys.argv[1]
    bed = sys.argv[2]
    chrfieldID = int(sys.argv[3])
    outfilename = sys.argv[4]

    cachePages = -1
    if '-cache' in sys.argv:
        cachePages = int(sys.argv[sys.argv.index('-cache') + 1])

    doMulti = '-withmulti' in sys.argv

    hitRDS = readDataset(hitfilename, verbose=True, cache=True)
    if cachePages > hitRDS.getDefaultCacheSize():
        hitRDS.setDBcache(cachePages)

    # NOTE(review): the total is len(hitRDS) counted twice, which halves every
    # RPKM relative to the usual reads / (kb * million reads) definition.
    # Kept unchanged to preserve existing output -- confirm this is intended.
    normalizeBy = (len(hitRDS) + len(hitRDS)) / 1000000.
    sys.stdout.write('normalizing factor: %s\n' % normalizeBy)
    if normalizeBy == 0:
        # Every RPKM would divide by zero; fail with a clear message instead.
        sys.stderr.write('error: dataset %s contains no reads\n' % hitfilename)
        sys.exit(1)

    header = 'chr\tleft\tright\tstrand\tsenseRPKM\tanti-senseRPKM'
    with open(bed) as regions:
        with open(outfilename, 'w') as outfile:
            outfile.write(header + '\n')
            for line in regions:
                if line.startswith('#') or not line.strip():
                    continue
                fields = line.strip().split('\t')
                chrom = fields[chrfieldID]
                left = int(fields[chrfieldID + 1])
                right = int(fields[chrfieldID + 2])
                strand = fields[chrfieldID + 3]

                length = right - left
                if length <= 0:
                    sys.stderr.write('skipping region with non-positive length: %s\n' % line.rstrip('\n'))
                    continue

                forwardCounts = hitRDS.getCounts(chrom=chrom, rmin=left, rmax=right, uniqs=True, multi=doMulti, splices=True, reportCombined=True, sense='+')
                reverseCounts = hitRDS.getCounts(chrom=chrom, rmin=left, rmax=right, uniqs=True, multi=doMulti, splices=True, reportCombined=True, sense='-')

                # RPKM = count / (region length in kb * normalizing factor)
                scale = (length / 1000.) * normalizeBy
                forwardRPKM = forwardCounts / scale
                reverseRPKM = reverseCounts / scale

                # Orient the columns to the region's strand; anything other
                # than '-' (including '.') is reported in '+' orientation so
                # that a stale row is never written.
                if strand == '-':
                    senseRPKM, antisenseRPKM = reverseRPKM, forwardRPKM
                else:
                    senseRPKM, antisenseRPKM = forwardRPKM, reverseRPKM

                # Six columns, matching the header (no trailing tab).
                outfile.write('%s\t%s\t%s\t%s\t%s\t%s\n' % (chrom, left, right, strand, senseRPKM, antisenseRPKM))
   
# Run only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    run()

