##################################
#                                #
# Last modified 11/08/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from commoncode import *

def run():
    """Count reads from an RDS dataset over the regions listed in a BED-like file.

    Command line (read from sys.argv):
        bedfilename rdsfilename chrFieldID outputfilename [options]

    chrFieldID is the 0-based index of the chromosome column; the region start
    and end are taken from the two columns that follow it. For every line that
    is neither a '#' comment nor blank, the stripped original line is written
    to outputfilename followed by tab-separated values:

      * -rawreadNumber : the raw value returned by hitRDS.getCounts();
      * -RPM           : count / (total reads in the dataset / 1e6);
      * default        : RPKM, count / (millions of reads * region length / 1000).

    With -control, the RPM/RPKM output also contains the control value
    (computed from control count + 1, so the ratio never divides by zero) and
    the sample/control ratio. Control values are not reported together with
    -rawreadNumber.

    Other options:
      -nomulti / -nounique   exclude multi / unique reads from the counts
      -cache N               open datasets with caching and, if N exceeds the
                             dataset's default cache size, enlarge it to N
      -singlecoordinate R    use the window [start - R, start + R] instead of
                             [start, end]

    Regions whose start equals their end are skipped.

    Exits with status 1 after printing a usage message when fewer than four
    positional arguments are supplied.
    """
    # Four positional arguments plus the program name are required because
    # sys.argv[4] (the output file name) is read below.
    if len(sys.argv) < 5:
        print('usage: python %s bedfilename rdsfilename chrFieldID outputfilename [-RPM] [-nomulti] [-nounique] [-rawreadNumber] [-control controlrdsfilename] [-cache size] [-singlecoordinate radius]' % sys.argv[0])
        sys.exit(1)

    bedfile = sys.argv[1]
    hitfilename = sys.argv[2]
    fieldID = int(sys.argv[3])
    outfilename = sys.argv[4]

    def optionValue(flag):
        # The value of an option is the token immediately following its flag.
        return sys.argv[sys.argv.index(flag) + 1]

    doMulti = '-nomulti' not in sys.argv
    if not doMulti:
        print('will not count multi reads')

    doUnique = '-nounique' not in sys.argv
    if not doUnique:
        print('will not count unique reads')

    doRaw = '-rawreadNumber' in sys.argv
    doRPM = '-RPM' in sys.argv

    cachePages = -1
    doCache = '-cache' in sys.argv
    if doCache:
        cachePages = int(optionValue('-cache'))
    hitRDS = readDataset(hitfilename, verbose=True, cache=doCache)

    doSC = '-singlecoordinate' in sys.argv
    radius = 0
    if doSC:
        radius = int(optionValue('-singlecoordinate'))

    doControl = '-control' in sys.argv
    if doControl:
        controlrdsfilename = optionValue('-control')
        hitctrlRDS = readDataset(controlrdsfilename, verbose=True, cache=doCache)
        ctrlnormalizeBy = len(hitctrlRDS) / 1000000.

    # sqlite default_cache_size is 2000 pages; the cache is only ever enlarged.
    if cachePages > hitRDS.getDefaultCacheSize():
        hitRDS.setDBcache(cachePages)
        if doControl and cachePages > hitctrlRDS.getDefaultCacheSize():
            hitctrlRDS.setDBcache(cachePages)

    # The read length is not used for counting; it is only checked so the
    # user is warned when the RDS metadata is incomplete or malformed.
    metadata = hitRDS.getMetadata()
    try:
        int(metadata['readsize'])
    except (KeyError, TypeError, ValueError):
        print('warning: read length not specified in RDS file, assuming it is 36')

    normalizeBy = len(hitRDS) / 1000000.
    print(normalizeBy)

    # Counting options shared by the sample and control queries.
    countOptions = dict(uniqs=doUnique, multi=doMulti, splices=True, reportCombined=True)

    # The output file is opened only after all inputs have loaded, so a failure
    # above no longer leaves an empty output file behind; `with` closes both
    # files even if a malformed line raises.
    with open(bedfile) as bed, open(outfilename, 'w') as outfile:
        for line in bed:
            # Skip comment lines and blank lines (a blank line would make
            # int('') raise ValueError below).
            if line.startswith('#') or not line.strip():
                continue
            fields = line.rstrip('\r\n').split('\t')
            start = int(fields[fieldID + 1])
            if doSC:
                rmin = start - radius
                rmax = start + radius
            else:
                rmin = start
                rmax = int(fields[fieldID + 2])
            chrom = fields[fieldID]
            if rmax == rmin:
                continue
            value = hitRDS.getCounts(chrom=chrom, rmin=rmin, rmax=rmax, **countOptions)
            stripped = line.strip()
            if doRaw:
                outline = '%s\t%s\t' % (stripped, value)
            else:
                if doControl:
                    ctrlvalue = hitctrlRDS.getCounts(chrom=chrom, rmin=rmin, rmax=rmax, **countOptions)
                if doRPM:
                    RPM = value / normalizeBy
                    outline = '%s\t%s\t' % (stripped, RPM)
                    if doControl:
                        # Pseudocount of 1 keeps the ratio finite when the
                        # control has no reads in this region.
                        ctrlRPM = (1. + ctrlvalue) / ctrlnormalizeBy
                        outline = '%s\t%s\t%s\t%s\t' % (stripped, RPM, ctrlRPM, RPM / ctrlRPM)
                else:
                    lengthKb = (rmax - rmin) / 1000.0
                    RPKM = value / (normalizeBy * lengthKb)
                    outline = '%s\t%s\t' % (stripped, RPKM)
                    if doControl:
                        ctrlRPKM = (1. + ctrlvalue) / (ctrlnormalizeBy * lengthKb)
                        outline = '%s\t%s\t%s\t%s\t' % (stripped, RPKM, ctrlRPKM, RPKM / ctrlRPKM)
            outfile.write(outline + '\n')
   
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
