"""
    usage: python $ERANGEPATH/findall.py label samplerdsfile regionoutfile
           [--control controlrdsfile] [--minimum minHits] [--ratio minRatio]
           [--spacing maxSpacing] [--listPeak] [--shift #bp | learn] [--learnFold num]
           [--noshift] [--autoshift] [--reportshift] [--nomulti] [--minPlus fraction]
           [--maxPlus fraction] [--leftPlus fraction] [--minPeak RPM] [--raw]
           [--revbackground] [--pvalue self|back|none] [--nodirectionality]
           [--strandfilter plus/minus] [--trimvalue percent] [--notrim]
           [--cache pages] [--log altlogfile] [--flag aflag] [--append] [--RNA]

           where values in brackets are optional and label is an arbitrary string.

           Use --ratio (default 4 fold) to set the minimum fold enrichment
           over the control, --minimum (default 4) is the minimum number of reads
           (RPM) within the region, and --spacing (default readlen) to set the maximum
           distance between reads in the region. --listPeak lists the peak of the
           region. Peaks must be higher than --minPeak (default 0.5 RPM).
           Pvalues are calculated from the sample (change with --pvalue),
           unless the --revbackground flag and a control RDS file are provided.

           By default, all numbers and parameters are on a reads per
           million (RPM) basis. --raw will treat all settings, ratios and reported
           numbers as raw counts rather than RPM. Use --notrim to turn off region
           trimming and --trimvalue to control trimming (default 10% of peak signal).

           The peak finder uses minimal directionality information that can
           be turned off with --nodirectionality ; the fraction of + strand reads
           required to be to the left of the peak (default 0.3) can be set with
           --leftPlus ; --minPlus and --maxPlus change the minimum/maximum fraction
           of plus reads in a region (defaults 0.25 and 0.75, respectively).

           Use --shift to shift reads either by half the expected
           fragment length (default 0 bp) or '--shift learn' to learn the shift
           based on the first chromosome. If you prefer to learn the shift
           manually, use --autoshift to calculate a per-region shift value, which
           can be reported using --reportshift. --strandfilter should only be used
           when explicitly calling unshifted stranded peaks from non-ChIP-seq
           data such as directional RNA-seq. regionoutfile is written over by
           default unless given the --append flag.
"""

# psyco is optional: if importing or enabling it fails for any reason,
# carry on without it.
try:
    import psyco
    psyco.full()
except:
    pass

import sys
import math
import string
import optparse
import operator
from commoncode import writeLog, findPeak, getBestShiftForRegion, getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption, getConfigBoolOption
import ReadDataset
import Region


# Version banner: printed as soon as the module is loaded and reused in the
# log entries and the output-file description written further down.
versionString = "findall: version 3.2.1"
print versionString

class RegionDirectionError(Exception):
    """Error caught in findPeakRegions while recording a region.

    A region whose update raises this is counted under the "failed"
    statistic instead of being reported.
    """
            

class RegionFinder():
    """Parameter set and running counters for one findall peak-calling run.

    The module-level functions read the thresholds stored here and update
    ``statistics`` as regions are accepted or rejected.  ``readlen`` is not
    set by the constructor; findall() assigns it from the sample dataset
    before any region is evaluated.
    """

    def __init__(self, label, minRatio=4.0, minPeak=0.5, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, strandfilter="",
                 minHits=4.0, trimValue=0.1, doTrim=True, doDirectionality=True, shiftValue=0, maxSpacing=50, withFlag="",
                 normalize=True, listPeak=False, reportshift=False, stringency=1.0):
        """Store the thresholds and switches used while calling regions.

        strandfilter "plus" / "minus" overrides minPlusRatio and maxPlusRatio
        with fixed bounds (0.9-1.0 and 0.0-0.1 respectively).  minPeak is
        lowered to minRatio when minRatio is the smaller of the two, and
        stringency is clamped to be at least 1.0.
        """
        # index/total count accepted sample regions, mIndex/mTotal count
        # accepted control regions, failed counts directionality failures and
        # badRegionTrim counts regions whose trimming raised an IndexError.
        self.statistics = {"index": 0,
                           "total": 0,
                           "mIndex": 0,
                           "mTotal": 0,
                           "failed": 0,
                           "badRegionTrim": 0
        }

        self.regionLabel = label
        self.rnaSettings = False
        self.controlRDSsize = 1
        self.sampleRDSsize = 1
        self.minRatio = minRatio
        self.minPeak = minPeak
        self.leftPlusRatio = leftPlusRatio
        self.stranded = "both"
        if strandfilter == "plus":
            self.stranded = "+"
            minPlusRatio = 0.9
            maxPlusRatio = 1.0
        elif strandfilter == "minus":
            self.stranded = "-"
            minPlusRatio = 0.0
            maxPlusRatio = 0.1

        if minRatio < minPeak:
            self.minPeak = minRatio

        self.minPlusRatio = minPlusRatio
        self.maxPlusRatio = maxPlusRatio
        self.strandfilter = strandfilter
        self.minHits = minHits
        self.trimValue = trimValue
        self.doTrim = doTrim
        self.doDirectionality = doDirectionality

        if self.doTrim:
            self.trimString = "%2.1f%%" % (100. * self.trimValue)
        else:
            self.trimString = "none"

        self.shiftValue = shiftValue
        self.maxSpacing = maxSpacing
        self.withFlag = withFlag
        self.normalize = normalize
        self.listPeak = listPeak
        self.reportshift = reportshift
        self.stringency = max(stringency, 1.0)


    def useRNASettings(self, readlen):
        """Switch to the RNA preset: no shift, no trimming, no directionality.

        maxSpacing is set to readlen.
        """
        self.rnaSettings = True
        self.shiftValue = 0
        self.doTrim = False
        # keep the reported trim setting in step with doTrim
        self.trimString = "none"
        self.doDirectionality = False
        self.maxSpacing = readlen


    def getHeader(self, doPvalue):
        """Return the tab-separated column header line for the region table.

        Optional columns follow the doDirectionality, listPeak and reportshift
        settings; a pValue column is appended when doPvalue is true.
        """
        if self.normalize:
            countType = "RPM"
        else:
            countType = "COUNT"

        headerFields = ["#regionID\tchrom\tstart\tstop", countType, "fold\tmulti%"]

        if self.doDirectionality:
            headerFields.append("plus%\tleftPlus%")

        if self.listPeak:
            headerFields.append("peakPos\tpeakHeight")

        if self.reportshift:
            headerFields.append("readShift")

        if doPvalue:
            headerFields.append("pValue")

        return "\t".join(headerFields)


    def printSettings(self, doRevBackground, ptype, doControl, useMulti, doCache, pValueType):
        """Print a blank line, the status messages and the options summary."""
        print("")
        self.printStatusMessages(doRevBackground, ptype, doControl, useMulti)
        self.printOptionsSummary(useMulti, doCache, pValueType)


    def printStatusMessages(self, doRevBackground, ptype, doControl, useMulti):
        """Print one line for each non-default mode that is active.

        ptype is the requested p-value type; it is matched case-insensitively
        against NONE, SELF and BACK, and anything else is reported as
        unusable.
        """
        if self.shiftValue == "learn":
            print("Will try to learn shift")

        if self.normalize:
            print("Normalizing to RPM")

        if doRevBackground:
            print("Swapping IP and background to calculate FDR")

        if ptype != "":
            # the usage text documents lower-case values, so compare
            # without regard to case
            if hasattr(ptype, "upper"):
                ptypeKey = ptype.upper()
            else:
                ptypeKey = ptype

            if ptypeKey in ["NONE", "SELF"]:
                pass
            elif ptypeKey == "BACK":
                if not (doControl and doRevBackground):
                    print("must have a control dataset and -revbackground for pValue type 'back'")
            else:
                print("could not use pValue type : %s" % ptype)

        if self.withFlag != "":
            print("restrict to flag = %s" % self.withFlag)

        if not useMulti:
            print("using unique reads only")

        if self.rnaSettings:
            print("using settings appropriate for RNA: -nodirectionality -notrim -noshift")

        if self.strandfilter == "plus":
            print("only analyzing reads on the plus strand")
        elif self.strandfilter == "minus":
            print("only analyzing reads on the minus strand")


    def printOptionsSummary(self, useMulti, doCache, pValueType):
        """Print the three-line summary of the effective settings."""
        print("\nenforceDirectionality=%s listPeak=%s nomulti=%s cache=%s " % (self.doDirectionality, self.listPeak, not useMulti, doCache))
        print("spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        print(self._formatStrandAndShiftSummary(pValueType))


    def getAnalysisDescription(self, hitfile, useMulti, doCache, pValueType, controlfile, doControl):
        """Return the multi-line '#'-prefixed description written at the top of the output file."""
        description = ["#ERANGE %s" % versionString]
        if doControl:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample:\t%s (%.1f M reads)" % (hitfile, self.sampleRDSsize, controlfile, self.controlRDSsize))
        else:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample: none" % (hitfile, self.sampleRDSsize))

        if self.withFlag != "":
            description.append("#restrict to Flag = %s" % self.withFlag)

        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s cache=%s" % (self.doDirectionality, self.listPeak, not useMulti, doCache))
        description.append("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        description.append("#" + self._formatStrandAndShiftSummary(pValueType))

        return "\n".join(description)


    def _formatStrandAndShiftSummary(self, pValueType):
        """Format the minPlus/maxPlus/leftPlus/shift/pvalue summary line."""
        values = (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
        try:
            return "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % values
        except TypeError:
            # shiftValue can be a keyword such as "learn" or "auto" rather than a number
            return "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % values


    def updateControlStatistics(self, peak, sumAll, peakScore):
        """Count a control (background) region if its peak passes the filters.

        The peak must reach minPeak and its plus-strand fraction must lie in
        [minPlusRatio, maxPlusRatio].  With directionality enabled, the
        fraction of plus reads left of the peak must also exceed
        leftPlusRatio; otherwise the region is counted as failed.
        """
        numHits = float(peak.numHits)
        if numHits == 0:
            # nothing to score: a peak without hits cannot pass any filter
            return

        numPlus = float(peak.numPlus)
        plusRatio = numPlus / numHits
        if peakScore >= self.minPeak and self.minPlusRatio <= plusRatio <= self.maxPlusRatio:
            if self.doDirectionality:
                # true division, and no plus reads at all means nothing lies left of the peak
                if numPlus:
                    leftPlusFraction = peak.numLeftPlus / numPlus
                else:
                    leftPlusFraction = 0.0

                if self.leftPlusRatio < leftPlusFraction:
                    self.statistics["mIndex"] += 1
                    self.statistics["mTotal"] += sumAll
                else:
                    self.statistics["failed"] += 1
            else:
                # we have a region, but didn't check for directionality
                self.statistics["mIndex"] += 1
                self.statistics["mTotal"] += sumAll


def usage():
    """Print the module docstring, which serves as the command-line help."""
    helpText = __doc__
    print(helpText)


def main(argv=None):
    """Parse the command line, build a RegionFinder and run findall().

    argv defaults to sys.argv.  Three positional arguments are required
    (label, sample RDS file, output file); with fewer, the usage text is
    printed and the process exits with status 2.
    """
    if not argv:
        argv = sys.argv

    parser = makeParser()
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        usage()
        sys.exit(2)

    factor, hitfile, outfilename = args[:3]

    shiftValue = 0
    if options.autoshift:
        shiftValue = "auto"

    if options.shift is not None:
        if options.shift == "learn":
            shiftValue = "learn"
        else:
            try:
                shiftValue = int(options.shift)
            except ValueError:
                # keep the previous fallback, but tell the user it happened
                sys.stderr.write("ignoring unrecognized --shift value: %s\n" % options.shift)

    # --noshift is registered by makeParser() under dest "noShift" while its
    # default lives under "noshift"; honour either spelling so that the
    # command-line flag is not silently ignored.
    if options.noshift or getattr(options, "noShift", False):
        shiftValue = 0

    if options.doAppend:
        outputMode = "a"
    else:
        outputMode = "w"

    # the usage text documents --pvalue self|back|none, whereas the checks
    # downstream compare against upper-case names
    ptype = options.ptype
    if hasattr(ptype, "upper"):
        ptype = ptype.upper()

    regionFinder = RegionFinder(factor, minRatio=options.minRatio, minPeak=options.minPeak, minPlusRatio=options.minPlusRatio,
                                maxPlusRatio=options.maxPlusRatio, leftPlusRatio=options.leftPlusRatio, strandfilter=options.strandfilter,
                                minHits=options.minHits, trimValue=options.trimValue, doTrim=options.doTrim,
                                doDirectionality=options.doDirectionality, shiftValue=shiftValue, maxSpacing=options.maxSpacing,
                                withFlag=options.withFlag, normalize=options.normalize, listPeak=options.listPeak,
                                reportshift=options.reportshift, stringency=options.stringency)

    findall(regionFinder, hitfile, outfilename, options.logfilename, outputMode, options.rnaSettings,
            options.cachePages, ptype, options.controlfile, options.doRevBackground,
            options.useMulti, options.combine5p)


def makeParser():
    """Build the optparse parser for findall.

    Every option's default is looked up in the "findall" section of the
    configuration returned by getConfigParser(), falling back to the
    hard-coded value given here.
    """
    usageText = __doc__

    parser = optparse.OptionParser(usage=usageText)
    parser.add_option("--control", dest="controlfile")
    parser.add_option("--minimum", type="float", dest="minHits")
    parser.add_option("--ratio", type="float", dest="minRatio")
    parser.add_option("--spacing", type="int", dest="maxSpacing")
    parser.add_option("--listPeak", action="store_true", dest="listPeak")
    parser.add_option("--shift", dest="shift")
    parser.add_option("--learnFold", type="float", dest="stringency")
    # dest must be "noshift": that is the name given a default below and the
    # attribute main() reads
    parser.add_option("--noshift", action="store_true", dest="noshift")
    parser.add_option("--autoshift", action="store_true", dest="autoshift")
    parser.add_option("--reportshift", action="store_true", dest="reportshift")
    parser.add_option("--nomulti", action="store_false", dest="useMulti")
    parser.add_option("--minPlus", type="float", dest="minPlusRatio")
    parser.add_option("--maxPlus", type="float", dest="maxPlusRatio")
    parser.add_option("--leftPlus", type="float", dest="leftPlusRatio")
    parser.add_option("--minPeak", type="float", dest="minPeak")
    parser.add_option("--raw", action="store_false", dest="normalize")
    parser.add_option("--revbackground", action="store_true", dest="doRevBackground")
    parser.add_option("--pvalue", dest="ptype")
    parser.add_option("--nodirectionality", action="store_false", dest="doDirectionality")
    parser.add_option("--strandfilter", dest="strandfilter")
    parser.add_option("--trimvalue", type="float", dest="trimValue")
    parser.add_option("--notrim", action="store_false", dest="doTrim")
    parser.add_option("--cache", type="int", dest="cachePages")
    parser.add_option("--log", dest="logfilename")
    parser.add_option("--flag", dest="withFlag")
    parser.add_option("--append", action="store_true", dest="doAppend")
    parser.add_option("--RNA", action="store_true", dest="rnaSettings")
    parser.add_option("--combine5p", action="store_true", dest="combine5p")

    configParser = getConfigParser()
    section = "findall"
    minHits = getConfigFloatOption(configParser, section, "minHits", 4.0)
    minRatio = getConfigFloatOption(configParser, section, "minRatio", 4.0)
    maxSpacing = getConfigIntOption(configParser, section, "maxSpacing", 50)
    listPeak = getConfigBoolOption(configParser, section, "listPeak", False)
    shift = getConfigOption(configParser, section, "shift", None)
    stringency = getConfigFloatOption(configParser, section, "stringency", 4.0)
    noshift = getConfigBoolOption(configParser, section, "noshift", False)
    autoshift = getConfigBoolOption(configParser, section, "autoshift", False)
    reportshift = getConfigBoolOption(configParser, section, "reportshift", False)
    minPlusRatio = getConfigFloatOption(configParser, section, "minPlusRatio", 0.25)
    maxPlusRatio = getConfigFloatOption(configParser, section, "maxPlusRatio", 0.75)
    leftPlusRatio = getConfigFloatOption(configParser, section, "leftPlusRatio", 0.3)
    minPeak = getConfigFloatOption(configParser, section, "minPeak", 0.5)
    normalize = getConfigBoolOption(configParser, section, "normalize", True)
    logfilename = getConfigOption(configParser, section, "logfilename", "findall.log")
    withFlag = getConfigOption(configParser, section, "withFlag", "")
    doDirectionality = getConfigBoolOption(configParser, section, "doDirectionality", True)
    trimValue = getConfigFloatOption(configParser, section, "trimValue", 0.1)
    doTrim = getConfigBoolOption(configParser, section, "doTrim", True)
    doAppend = getConfigBoolOption(configParser, section, "doAppend", False)
    rnaSettings = getConfigBoolOption(configParser, section, "rnaSettings", False)
    # NOTE(review): read with getConfigOption, whereas --cache is declared as
    # an int option; confirm what type a configured value comes back as.
    cachePages = getConfigOption(configParser, section, "cachePages", None)
    ptype = getConfigOption(configParser, section, "ptype", "")
    controlfile = getConfigOption(configParser, section, "controlfile", None)
    doRevBackground = getConfigBoolOption(configParser, section, "doRevBackground", False)
    useMulti = getConfigBoolOption(configParser, section, "useMulti", True)
    strandfilter = getConfigOption(configParser, section, "strandfilter", "")
    combine5p = getConfigBoolOption(configParser, section, "combine5p", False)

    parser.set_defaults(minHits=minHits, minRatio=minRatio, maxSpacing=maxSpacing, listPeak=listPeak, shift=shift,
                        stringency=stringency, noshift=noshift, autoshift=autoshift, reportshift=reportshift,
                        minPlusRatio=minPlusRatio, maxPlusRatio=maxPlusRatio, leftPlusRatio=leftPlusRatio, minPeak=minPeak,
                        normalize=normalize, logfilename=logfilename, withFlag=withFlag, doDirectionality=doDirectionality,
                        trimValue=trimValue, doTrim=doTrim, doAppend=doAppend, rnaSettings=rnaSettings,
                        cachePages=cachePages, ptype=ptype, controlfile=controlfile, doRevBackground=doRevBackground, useMulti=useMulti,
                        strandfilter=strandfilter, combine5p=combine5p)

    return parser


def findall(regionFinder, hitfile, outfilename, logfilename="findall.log", outputMode="w", rnaSettings=False, cachePages=None,
            ptype="", controlfile=None, doRevBackground=False, useMulti=True, combine5p=False):
    """Run the region finder over every chromosome and write the results.

    Opens the sample dataset (and the control dataset when controlfile is
    given), records their sizes in millions of reads on regionFinder, writes
    the description and column header to outfilename, processes each
    chromosome, then writes the footer and logs it.  The output file is
    closed even if processing raises.
    """
    writeLog(logfilename, versionString, " ".join(sys.argv[1:]))
    doCache = cachePages is not None
    controlRDS = None
    doControl = controlfile is not None
    if doControl:
        print("\ncontrol:")
        controlRDS = openRDSFile(controlfile, cachePages=cachePages, doCache=doCache)
        regionFinder.controlRDSsize = len(controlRDS) / 1000000.

    print("\nsample:")
    hitRDS = openRDSFile(hitfile, cachePages=cachePages, doCache=doCache)
    regionFinder.sampleRDSsize = len(hitRDS) / 1000000.
    pValueType = getPValueType(ptype, doControl, doRevBackground)
    doPvalue = pValueType != "none"
    regionFinder.readlen = hitRDS.getReadSize()
    if rnaSettings:
        regionFinder.useRNASettings(regionFinder.readlen)

    regionFinder.printSettings(doRevBackground, ptype, doControl, useMulti, doCache, pValueType)
    outfile = open(outfilename, outputMode)
    try:
        header = writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl)
        shiftDict = {}
        chromosomeList = getChromosomeListToProcess(hitRDS, controlRDS, doControl)
        for chromosome in chromosomeList:
            # learnShift replaces shiftValue with a number, so this only
            # triggers until a shift has been learned
            if regionFinder.shiftValue == "learn":
                learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)

            allregions, outregions = findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)
            if doRevBackground:
                backregions = findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti)
                writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header, backregions=backregions, pValueType=pValueType)
            else:
                writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header)

        footer = getFooter(regionFinder, shiftDict, doRevBackground)
        print(footer)
        outfile.write(footer + "\n")
    finally:
        outfile.close()

    writeLog(logfilename, versionString, outfilename + footer.replace("\n#"," | ")[:-1])


def getPValueType(ptype, doControl, doRevBackground):
    """Resolve the requested p-value type to "none", "self" or "back".

    ptype is matched case-insensitively against NONE, SELF and BACK (the
    usage text documents the lower-case spellings).  "back" is only granted
    when both a control dataset and the reversed-background pass are in use;
    otherwise that request falls back to "self".  Any other ptype yields
    "back" when doRevBackground is set and "self" when it is not.
    """
    if hasattr(ptype, "upper"):
        requested = ptype.upper()
    else:
        requested = ptype

    if requested == "NONE":
        return "none"

    if requested == "SELF":
        return "self"

    if requested == "BACK":
        if doControl and doRevBackground:
            return "back"

        return "self"

    if doRevBackground:
        return "back"

    return "self"


def openRDSFile(filename, cachePages=None, doCache=False):
    """Open filename as a ReadDataset and optionally enlarge its cache.

    The cache is only changed when cachePages is given and is larger than
    the dataset's default cache size.
    """
    rds = ReadDataset.ReadDataset(filename, verbose=True, cache=doCache)
    # explicit None test instead of relying on how None orders against ints
    if cachePages is not None and cachePages > rds.getDefaultCacheSize():
        rds.setDBcache(cachePages)

    return rds


def writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl):
    """Write the analysis description and the column header to outfile.

    Returns the column header line so that callers can reuse it.
    """
    summary = regionFinder.getAnalysisDescription(hitfile, useMulti, doCache, pValueType, controlfile, doControl)
    outfile.write("%s\n" % summary)
    columnLine = regionFinder.getHeader(doPvalue)
    outfile.write("%s\n" % columnLine)

    return columnLine


def getChromosomeListToProcess(hitRDS, controlRDS=None, doControl=False):
    """Return the sample chromosomes to analyse, in the sample's own order.

    "chrM" is always left out.  When doControl is set, chromosomes missing
    from the control dataset are left out as well.
    """
    controlChromList = None
    sampleChromosomes = hitRDS.getChromosomes()
    if doControl:
        controlChromList = controlRDS.getChromosomes()

    selected = []
    for chrom in sampleChromosomes:
        if controlChromList is not None and chrom not in controlChromList:
            continue

        if chrom != "chrM":
            selected.append(chrom)

    return selected


def findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename,
                    outfile, useMulti, doControl, controlRDS, combine5p):
    """Scan one chromosome of the sample dataset and collect enriched regions.

    Reads are accumulated into a candidate region until
    previousRegionIsDone() reports a break; the candidate is then scored,
    optionally trimmed, and kept if it still passes the thresholds held by
    regionFinder.  logfilename, outfilename and outfile are accepted but not
    used here.

    Returns:
        (allregions, outregions): allregions holds int(numReads) for every
        candidate region that was closed, outregions the Region objects that
        were accepted.
    """
    outregions = []
    allregions = []
    print("chromosome %s" % (chromosome))
    previousHit = - 1 * regionFinder.maxSpacing
    readStartPositions = [-1]
    totalWeight = 0
    uniqueReadCount = 0
    reads = []
    numStarts = 0
    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
                                  strand=regionFinder.stranded, combine5p=combine5p)

    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
    for read in hitDict[chromosome]:
        pos = read["start"]
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            badRegion = False
            lastReadPos = readStartPositions[-1]
            lastBasePosition = lastReadPos + regionFinder.readlen - 1
            newRegionIndex = regionFinder.statistics["index"] + 1
            if regionFinder.doDirectionality:
                region = Region.DirectionalRegion(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel,
                                                  numReads=totalWeight)
            else:
                region = Region.Region(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel, numReads=totalWeight)

            if regionFinder.normalize:
                region.numReads /= regionFinder.sampleRDSsize

            allregions.append(int(region.numReads))
            regionLength = lastReadPos - region.start
            if regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
                region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)

                if region.foldRatio >= regionFinder.minRatio:
                    # first pass, with absolute numbers
                    peak = findPeak(reads, region.start, regionLength, regionFinder.readlen, doWeight=True, leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)
                    if regionFinder.doTrim:
                        try:
                            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.sampleRDSsize)
                        except IndexError:
                            # Give up on this region only.  Falling through
                            # (instead of jumping to the next read) lets the
                            # accumulators below be reset and keeps the
                            # current read, which starts the next region.
                            badRegion = True
                        else:
                            region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)

                    # just in case it changed, use latest data
                    havePeak = not badRegion
                    if havePeak:
                        try:
                            bestPos = peak.topPos[0]
                            peakScore = peak.smoothArray[bestPos]
                        except IndexError:
                            # no usable peak position: drop the region
                            havePeak = False

                    if havePeak:
                        if regionFinder.normalize:
                            peakScore /= regionFinder.sampleRDSsize

                        if regionFinder.listPeak:
                            region.peakDescription= "%d\t%.1f" % (region.start + bestPos, peakScore)

                        if useMulti:
                            setMultireadPercentage(region, hitRDS, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos,
                                                   regionFinder.normalize, regionFinder.doTrim)

                        region.shift = peak.shift
                        # check that we still pass threshold
                        regionLength = lastReadPos - region.start
                        plusRatio = float(peak.numPlus)/peak.numHits
                        if regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
                            try:
                                updateRegion(region, regionFinder.doDirectionality, regionFinder.leftPlusRatio, peak.numLeftPlus, peak.numPlus, plusRatio)
                                regionFinder.statistics["index"] += 1
                                outregions.append(region)
                                regionFinder.statistics["total"] += region.numReads
                            except RegionDirectionError:
                                regionFinder.statistics["failed"] += 1

            # start a fresh candidate region, whatever happened to the old one
            readStartPositions = []
            totalWeight = 0
            uniqueReadCount = 0
            reads = []
            numStarts = 0
            if badRegion:
                regionFinder.statistics["badRegionTrim"] += 1

        if pos not in readStartPositions:
            numStarts += 1

        readStartPositions.append(pos)
        weight = read["weight"]
        totalWeight += weight
        if weight == 1.0:
            uniqueReadCount += 1

        reads.append({"start": pos, "sense": read["sense"], "weight": weight})
        previousHit = pos

    return allregions, outregions


def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti):
    """Call regions on the control dataset, using the sample as background.

    Mirrors findPeakRegions() with the two datasets swapped.  Accepted
    regions are not returned; they are tallied through
    regionFinder.updateControlStatistics().

    Returns:
        backregions: int(numReads) for every candidate control region that
        was closed.
    """
    #TODO: this is *almost* the same calculation - there are small yet important differences
    print("calculating background...")
    previousHit = - 1 * regionFinder.maxSpacing
    currentHitList = [-1]
    currentTotalWeight = 0
    currentReadList = []
    backregions = []
    numStarts = 0
    hitDict = controlRDS.getReadsDict(fullChrom=True, chrom=chromosome, withWeight=True, doMulti=useMulti, findallOptimize=True)
    maxCoord = controlRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
    for read in hitDict[chromosome]:
        pos = read["start"]
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            badRegion = False
            lastReadPos = currentHitList[-1]
            lastBasePosition = lastReadPos + regionFinder.readlen - 1
            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
            if regionFinder.normalize:
                region.numReads /= regionFinder.controlRDSsize

            backregions.append(int(region.numReads))
            # NOTE(review): the region is rebuilt from the raw weight, so the
            # normalisation applied just above is discarded from here on -
            # confirm this is intended.
            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
            regionLength = lastReadPos - region.start
            if regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
                numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
                if regionFinder.normalize:
                    numMock /= regionFinder.sampleRDSsize

                foldRatio = region.numReads / numMock
                if foldRatio >= regionFinder.minRatio:
                    # first pass, with absolute numbers
                    peak = findPeak(currentReadList, region.start, lastReadPos - region.start, regionFinder.readlen, doWeight=True,
                                    leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)

                    if regionFinder.doTrim:
                        try:
                            # NOTE(review): the trim value is hard-coded to
                            # 20. here, whereas the sample pass uses
                            # regionFinder.trimValue - confirm.
                            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, 20., currentReadList, regionFinder.controlRDSsize)
                        except IndexError:
                            # Give up on this region only.  Falling through
                            # (instead of jumping to the next read) lets the
                            # accumulators below be reset and keeps the
                            # current read, which starts the next region.
                            badRegion = True
                        else:
                            numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
                            if regionFinder.normalize:
                                numMock /= regionFinder.sampleRDSsize

                            foldRatio = region.numReads / numMock

                    # just in case it changed, use latest data
                    havePeak = not badRegion
                    if havePeak:
                        try:
                            bestPos = peak.topPos[0]
                            peakScore = peak.smoothArray[bestPos]
                        except IndexError:
                            # no usable peak position: drop the region
                            havePeak = False

                    if havePeak:
                        # normalize to RPM
                        if regionFinder.normalize:
                            peakScore /= regionFinder.controlRDSsize

                        # check that we still pass threshold
                        regionLength = lastReadPos - region.start
                        # NOTE(review): foldRatio is passed in the numStarts
                        # slot of regionPassesCriteria - confirm.
                        if regionPassesCriteria(regionFinder, region.numReads, foldRatio, regionLength):
                            regionFinder.updateControlStatistics(peak, region.numReads, peakScore)

            # start a fresh candidate region, whatever happened to the old one
            currentHitList = []
            currentTotalWeight = 0
            currentReadList = []
            numStarts = 0
            if badRegion:
                regionFinder.statistics["badRegionTrim"] += 1

        if pos not in currentHitList:
            numStarts += 1

        currentHitList.append(pos)
        weight = read["weight"]
        currentTotalWeight += weight
        currentReadList.append({"start": pos, "sense": read["sense"], "weight": weight})
        previousHit = pos

    return backregions


def learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename,
               outfile, useMulti, doControl, controlRDS, combine5p):
    """Learn the read shift from the regions of one chromosome.

    Candidate regions are built as in findPeakRegions(), but must pass the
    thresholds scaled by regionFinder.stringency.  Each qualifying region is
    fed to updateShiftDict(); the value chosen by getShiftValue() is stored
    in regionFinder.shiftValue, printed, written to outfile and logged.
    """
    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
                                  strand=regionFinder.stranded, combine5p=combine5p)

    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
    print("learning shift.... will need at least 30 training sites")
    stringency = regionFinder.stringency
    previousHit = -1 * regionFinder.maxSpacing
    positionList = [-1]
    totalWeight = 0
    readList = []
    shiftDict = {}
    count = 0
    numStarts = 0
    for read in hitDict[chromosome]:
        pos = read["start"]
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            if regionFinder.normalize:
                totalWeight /= regionFinder.sampleRDSsize

            regionStart = positionList[0]
            regionStop = positionList[-1]
            regionLength = regionStop - regionStart
            if regionPassesCriteria(regionFinder, totalWeight, numStarts, regionLength, stringency=stringency):
                foldRatio = getFoldRatio(regionFinder, controlRDS, totalWeight, chromosome, regionStart, regionStop, useMulti, doControl)
                if foldRatio >= regionFinder.minRatio:
                    updateShiftDict(shiftDict, readList, regionStart, regionLength, regionFinder.readlen)
                    count += 1

            positionList = []
            totalWeight = 0
            readList = []
            # the distinct-start count belongs to the region just closed;
            # without this reset it keeps growing across regions
            numStarts = 0

        if pos not in positionList:
            numStarts += 1

        positionList.append(pos)
        weight = read["weight"]
        totalWeight += weight
        readList.append({"start": pos, "sense": read["sense"], "weight": weight})
        previousHit = pos

    outline = "#learn: stringency=%.2f min_signal=%.2f min_ratio=%.2f min_region_size=%d\n#number of training examples: %d" % (stringency,
                                                                                                                               stringency * regionFinder.minHits,
                                                                                                                               stringency * regionFinder.minRatio,
                                                                                                                               stringency * regionFinder.readlen,
                                                                                                                               count)

    print(outline)
    writeLog(logfilename, versionString, outfilename + outline)
    regionFinder.shiftValue = getShiftValue(shiftDict, count, logfilename, outfilename)
    outline = "#picked shiftValue to be %d" % regionFinder.shiftValue
    print(outline)
    outfile.write(outline + "\n")
    writeLog(logfilename, versionString, outfilename + outline)


def previousRegionIsDone(pos, previousHit, maxSpacing, maxCoord):
    """ Return True when the read at pos ends the region being accumulated:
        either it lies more than maxSpacing away from previousHit (in either
        direction) or pos equals maxCoord.
    """
    gap = abs(pos - previousHit)
    if gap > maxSpacing:
        return True

    return pos == maxCoord


def regionPassesCriteria(regionFinder, sumAll, numStarts, regionLength, stringency=1):
    """ Return whether a candidate region clears all three thresholds, each
        scaled by stringency:
            sumAll       >= stringency * regionFinder.minHits
            numStarts    >  stringency * regionFinder.minRatio
            regionLength >  stringency * regionFinder.readlen
    """
    weightFloor, startsFloor, lengthFloor = [stringency * limit for limit in (regionFinder.minHits,
                                                                              regionFinder.minRatio,
                                                                              regionFinder.readlen)]

    if not (sumAll >= weightFloor):
        return False

    if not (numStarts > startsFloor):
        return False

    return regionLength > lengthFloor


def trimRegion(region, regionFinder, peak, regionStop, trimValue, currentReadList, totalReadCount):
    """ Trim both ends of region back to where the smoothed signal reaches
        trimValue times the score at the top peak position, then recompute
        the peak over the trimmed span.

        region and peak are updated in place: region.start, region.numReads
        and region.stop are overwritten, as are peak.numPlus,
        peak.numLeftPlus, peak.topPos and peak.smoothArray.

        Returns the trimmed regionStop, i.e. the value before the
        readlen - 1 extension that is stored in region.stop.
    """
    bestPos = peak.topPos[0]
    peakScore = peak.smoothArray[bestPos]
    if regionFinder.normalize:
        peakScore /= totalReadCount

    # NOTE(review): when normalizing, this threshold is scaled by totalReadCount
    # while peakEdgeLocated compares it against peak.smoothArray unscaled -
    # confirm that this is intended
    minSignalThresh = trimValue * peakScore
    # both edges are offsets into peak.smoothArray: the start is searched upward
    # from offset 0, the stop downward from the offset of regionStop - 1
    start = findStartEdgePosition(peak, minSignalThresh)
    regionEndPoint = regionStop - region.start - 1
    stop = findStopEdgePosition(peak, regionEndPoint, minSignalThresh)

    # the new stop is relative to the untrimmed region.start, so it has to be
    # computed before region.start is advanced
    regionStop = region.start + stop
    region.start += start

    # recompute the peak over the trimmed span, reusing the shift of the original peak
    trimmedPeak = findPeak(currentReadList, region.start, regionStop - region.start, regionFinder.readlen, doWeight=True,
                           leftPlus=regionFinder.doDirectionality, shift=peak.shift)

    # only these four fields of the caller's peak object are refreshed
    peak.numPlus = trimmedPeak.numPlus
    peak.numLeftPlus = trimmedPeak.numLeftPlus
    peak.topPos = trimmedPeak.topPos
    peak.smoothArray = trimmedPeak.smoothArray

    region.numReads = trimmedPeak.numHits
    if regionFinder.normalize:
        region.numReads /= totalReadCount

    region.stop = regionStop + regionFinder.readlen - 1

    return regionStop


def findStartEdgePosition(peak, minSignalThresh):
    """ Walk upward from offset 0 and return the first offset that
        peakEdgeLocated accepts for minSignalThresh.
    """
    offset = 0
    while True:
        if peakEdgeLocated(peak, offset, minSignalThresh):
            return offset

        offset += 1


def findStopEdgePosition(peak, stop, minSignalThresh):
    """ Walk downward from offset stop and return the first offset that
        peakEdgeLocated accepts for minSignalThresh.
    """
    offset = stop
    while True:
        if peakEdgeLocated(peak, offset, minSignalThresh):
            return offset

        offset -= 1


def peakEdgeLocated(peak, position, minSignalThresh):
    """ Return whether position qualifies as a trimmed edge: the smoothed
        signal there is at least minSignalThresh, or position is the top
        peak position itself.
    """
    reachesThreshold = peak.smoothArray[position] >= minSignalThresh
    if reachesThreshold:
        return reachesThreshold

    return position == peak.topPos[0]


def getFoldRatio(regionFinder, controlRDS, sumAll, chromosome, regionStart, regionStop, useMulti, doControl):
    """ Return sumAll divided by one plus the control read count over
        [regionStart, regionStop] on chromosome; that denominator is divided
        by regionFinder.controlRDSsize when regionFinder.normalize is set.
        Without a control (doControl false) regionFinder.minRatio is returned.
    """
    #TODO: this needs to be generalized as there is a point at which we want to use the sampleRDS instead of controlRDS
    if not doControl:
        return regionFinder.minRatio

    controlCount = controlRDS.getCounts(chromosome, regionStart, regionStop, uniqs=True, multi=useMulti,
                                        splices=False, reportCombined=True)
    # the extra 1. keeps the denominator non-zero when the control has no reads here
    numMock = 1. + controlCount
    if regionFinder.normalize:
        numMock /= regionFinder.controlRDSsize

    return sumAll / numMock


def updateShiftDict(shiftDict, readList, regionStart, regionLength, readlen):
    """ Find the peak of the region with shift="auto" and add one to the
        tally kept in shiftDict for the shift that peak reports.
    """
    learnedShift = findPeak(readList, regionStart, regionLength, readlen, doWeight=True, shift="auto").shift
    shiftDict[learnedShift] = shiftDict.get(learnedShift, 0) + 1


def getShiftValue(shiftDict, count, logfilename, outfilename):
    if count < 30:
        outline = "#too few training examples to pick a shiftValue - defaulting to 0\n#consider picking a lower minimum or threshold"
        print outline
        writeLog(logfilename, versionString, outfilename + outline)
        shiftValue = 0
    else:
        shiftValue = getBestShiftInDict(shiftDict)
        print shiftDict

    return shiftValue


def getRegion(regionStart, regionStop, factor, index, chromosome, sumAll, foldRatio, multiP,
              peakDescription, shift, doDirectionality, leftPlusRatio, numLeft,
              numPlus, plusRatio):
    """ Build the region object for a candidate.

        Without directionality checking a Region.Region is returned. With it,
        numLeft / numPlus must exceed leftPlusRatio; a
        Region.DirectionalRegion carrying the plus and left-plus percentages
        is then returned, otherwise RegionDirectionError is raised.
    """
    if not doDirectionality:
        # we have a region, but didn't check for directionality
        return Region.Region(regionStart, regionStop, factor, index, chromosome,
                             sumAll, foldRatio, multiP, peakDescription, shift)

    if not (leftPlusRatio < numLeft / numPlus):
        raise RegionDirectionError

    # we have a region that passes all criteria
    plusP = plusRatio * 100.
    leftP = 100. * numLeft / numPlus

    return Region.DirectionalRegion(regionStart, regionStop,
                                    factor, index, chromosome, sumAll,
                                    foldRatio, multiP, plusP, leftP,
                                    peakDescription, shift)


def setMultireadPercentage(region, hitRDS, hitRDSsize, currentTotalWeight, currentUniqueCount, chromosome, lastReadPos, normalize, doTrim):
    """ Store in region.multiP the multiread share of region.numReads as a
        percentage, capped at 100.

        The multiread amount is fetched from hitRDS for
        [region.start, lastReadPos] when doTrim is set, and is otherwise
        currentTotalWeight - currentUniqueCount; it is divided by hitRDSsize
        when normalize is set. If the division by region.numReads raises
        ZeroDivisionError, region.multiP is left untouched.
    """
    if not doTrim:
        multireadWeight = currentTotalWeight - currentUniqueCount
    else:
        multireadWeight = hitRDS.getMultiCount(chromosome, region.start, lastReadPos)

    if normalize:
        # normalize to RPM
        multireadWeight /= hitRDSsize

    try:
        percentage = 100. * (multireadWeight / region.numReads)
    except ZeroDivisionError:
        return

    region.multiP = min(percentage, 100.)


def regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
    """ Return True only if the region passes regionPassesCriteria (called
        with region.numReads and region.foldRatio), peakScore is at least
        regionFinder.minPeak and plusRatio lies within
        [regionFinder.minPlusRatio, regionFinder.maxPlusRatio].
    """
    if not regionPassesCriteria(regionFinder, region.numReads, region.foldRatio, regionLength):
        return False

    if not (peakScore >= regionFinder.minPeak):
        return False

    if not (regionFinder.minPlusRatio <= plusRatio <= regionFinder.maxPlusRatio):
        return False

    return True


def updateRegion(region, doDirectionality, leftPlusRatio, numLeft, numPlus, plusRatio):
    """ When directionality checking is on, store the plus and left-plus
        percentages on region, or raise RegionDirectionError if
        numLeft / numPlus does not exceed leftPlusRatio. Does nothing when
        doDirectionality is false.
    """
    if not doDirectionality:
        return

    if not (leftPlusRatio < numLeft / numPlus):
        raise RegionDirectionError

    region.plusP = plusRatio * 100.
    region.leftP = 100. * numLeft / numPlus


def writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                                allregions, header):
    """ Write the results for one chromosome when there are no background
        regions: delegates to writeChromosomeResults with an empty background
        list and the "self" p-value type.
    """
    noBackground = []
    writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                           allregions, header, noBackground, "self")


def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                           allregions, header, backregions=[], pValueType="none"):

    print regionFinder.statistics["mIndex"], regionFinder.statistics["mTotal"]
    if doPvalue:
        if pValueType == "self":
            poissonmean = calculatePoissonMean(allregions)
        else:
            poissonmean = calculatePoissonMean(backregions)

    print header
    writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=regionFinder.shiftValue, reportshift=regionFinder.reportshift, shiftDict=shiftDict)


def calculatePoissonMean(dataList):
    dataList.sort()
    listSize = float(len(dataList))
    try:
        poissonmean = sum(dataList) / listSize
    except ZeroDivisionError:
        poissonmean = 0

    print "Poisson n=%d, p=%f" % (listSize, poissonmean)

    return poissonmean


def writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=0, reportshift=False, shiftDict={}):
    for region in outregions:
        if shiftValue == "auto" and reportshift:
            try:
                shiftDict[region.shift] += 1
            except KeyError:
                shiftDict[region.shift] = 1

        outline = getRegionString(region, reportshift)

        # iterative poisson from http://stackoverflow.com/questions/280797?sort=newest
        if doPvalue:
            sumAll = int(region.numReads)
            pValue = calculatePValue(sumAll, poissonmean)
            outline += "\t%1.2g" % pValue

        print outline
        print >> outfile, outline


def calculatePValue(sum, poissonmean):
    pValue = math.exp(-poissonmean)
    #TODO: 798: DeprecationWarning: integer argument expected, got float - for i in xrange(sum)
    for i in xrange(sum):
        pValue *= poissonmean
        pValue /= i+1

    return pValue


def getRegionString(region, reportShift):
    """ Return the output line for region: the result of
        region.printRegionWithShift() when reportShift is set, otherwise of
        region.printRegion().
    """
    return region.printRegionWithShift() if reportShift else region.printRegion()


def getFooter(regionFinder, shiftDict, doRevBackground):
    """ Return the multi-line summary footer built from
        regionFinder.statistics.

        Lines, in order: the total and number of regions; the count of
        regions failing the directionality filter (only when
        regionFinder.doDirectionality); the background region count, total
        and FDR percentage capped at 100 (only when doRevBackground); the
        most frequent shift in shiftDict (only when the shift value is "auto",
        regionFinder.reportshift is set and at least one shift was tallied);
        and the count of regions discarded by trimming (only when non-zero).
    """
    statistics = regionFinder.statistics
    index = statistics["index"]
    mIndex = statistics["mIndex"]
    footerLines = ["#stats:\t%.1f RPM in %d regions" % (statistics["total"], index)]
    if regionFinder.doDirectionality:
        footerLines.append("#\t\t%d additional regions failed directionality filter" % statistics["failed"])

    if doRevBackground:
        try:
            percent = min(100. * (float(mIndex)/index), 100.)
        except ZeroDivisionError:
            # no regions at all: report an FDR of zero
            percent = 0.

        footerLines.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (mIndex, statistics["mTotal"], percent))

    # an empty shiftDict has no mode; asking for one raised ValueError from
    # max(), so the line is skipped when nothing was tallied
    if regionFinder.shiftValue == "auto" and regionFinder.reportshift and shiftDict:
        bestShift = getBestShiftInDict(shiftDict)
        footerLines.append("#mode of shift values: %d" % bestShift)

    if statistics["badRegionTrim"] > 0:
        footerLines.append("#%d regions discarded due to trimming problems" % statistics["badRegionTrim"])

    # str.join replaces the deprecated string.join module function
    return "\n".join(footerLines)


def getBestShiftInDict(shiftDict):
    """ Return the key of shiftDict that has the largest value. """
    return max(shiftDict, key=shiftDict.get)


# call main with the command-line arguments only when this file is run as a
# script, not when it is imported
if __name__ == "__main__":
    main(sys.argv)
