"""
    usage: python $ERANGEPATH/findall.py label samplebamfile regionoutfile
           [--control controlbamfile] [--minimum minHits] [--ratio minRatio]
           [--spacing maxSpacing] [--listPeak] [--shift #bp | learn] [--learnFold num]
           [--noshift] [--autoshift] [--reportshift] [--nomulti] [--minPlus fraction]
           [--maxPlus fraction] [--leftPlus fraction] [--minPeak RPM] [--raw]
           [--revbackground] [--pvalue self|back|none] [--nodirectionality]
           [--strandfilter plus/minus] [--trimvalue percent] [--notrim]
           [--cache pages] [--log altlogfile] [--flag aflag] [--append] [--RNA]

           where values in brackets are optional and label is an arbitrary string.

           Use -ratio (default 4 fold) to set the minimum fold enrichment
           over the control, -minimum (default 4) is the minimum number of reads
           (RPM) within the region, and -spacing (default readlen) to set the maximum
           distance between reads in the region. -listPeak lists the peak of the
           region. Peaks must be higher than -minPeak (default 0.5 RPM).
           Pvalues are calculated from the sample (change with -pvalue),
           unless the -revbackground flag and a control BAM file are provided.

           By default, all numbers and parameters are on a reads per
           million (RPM) basis. -raw will treat all settings, ratios and reported
           numbers as raw counts rather than RPM. Use -notrim to turn off region
           trimming and -trimvalue to control trimming (default 10% of peak signal)

           The peak finder uses minimal directionality information that can
           be turned off with -nodirectionality ; the fraction of + strand reads
           required to be to the left of the peak (default 0.3) can be set with
           -leftPlus ; -minPlus and -maxPlus change the minimum/maximum fraction
           of plus reads in a region (defaults 0.25 and 0.75, respectively).

           Use -shift to shift reads either by half the expected
           fragment length (default 0 bp) or '-shift learn ' to learn the shift
           based on the first chromosome. If you prefer to learn the shift
           manually, use -autoshift to calculate a per-region shift value, which
           can be reported using -reportshift. -strandfilter should only be used
           when explicitly calling unshifted stranded peaks from non-ChIP-seq
           data such as directional RNA-seq. regionoutfile is written over by
           default unless given the -append flag.
"""

try:
    import psyco
    psyco.full()
except:
    pass

import sys
import math
import string
import optparse
import operator
import pysam
from commoncode import writeLog, findPeak, getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption, getConfigBoolOption, isSpliceEntry
import Region


# Version banner: printed as soon as the module is loaded and reused in the
# log entries and in the description written at the top of the output file.
versionString = "findall: version 3.2.1"
print versionString

class RegionDirectionError(Exception):
    """Signal that a region failed the directionality check.

    findPeakRegions catches this around its updateRegion call and counts the
    region under the "failed" statistic instead of reporting it.
    """
    pass
            

class RegionFinder():
    """Peak-calling settings plus the running statistics of one findall run.

    The instance carries the thresholds (fold ratio, peak height, strand
    fractions, spacing, trimming), the reporting switches and a ``statistics``
    dictionary that the region-finding functions update as they go.
    ``readlen`` is not set here; findall assigns it once the read length has
    been read from the sample header.
    """

    def __init__(self, label, minRatio=4.0, minPeak=0.5, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, strandfilter="",
                 minHits=4.0, trimValue=0.1, doTrim=True, doDirectionality=True, shiftValue=0, maxSpacing=50, withFlag="",
                 normalize=True, listPeak=False, reportshift=False, stringency=1.0, controlfile=None, doRevBackground=False):
        """Store the settings.

        A strandfilter of "plus" or "minus" overrides minPlusRatio and
        maxPlusRatio; minPeak is lowered to minRatio when minRatio is the
        smaller of the two; stringency is never allowed below 1.0.
        """
        # index/total count the reported sample regions, mIndex/mTotal the
        # regions accepted by updateControlStatistics, failed the regions
        # rejected by the directionality test.
        self.statistics = {"index": 0,
                           "total": 0,
                           "mIndex": 0,
                           "mTotal": 0,
                           "failed": 0,
                           "badRegionTrim": 0
        }

        self.regionLabel = label
        self.rnaSettings = False
        self.controlRDSsize = 1.0
        self.sampleRDSsize = 1.0
        self.minRatio = minRatio
        self.minPeak = minPeak
        self.leftPlusRatio = leftPlusRatio

        # A strand filter replaces whatever plus-strand bounds were passed in.
        self.stranded = "both"
        if strandfilter == "plus":
            self.stranded = "+"
            minPlusRatio = 0.9
            maxPlusRatio = 1.0
        elif strandfilter == "minus":
            self.stranded = "-"
            minPlusRatio = 0.0
            maxPlusRatio = 0.1

        if minRatio < minPeak:
            self.minPeak = minRatio

        self.minPlusRatio = minPlusRatio
        self.maxPlusRatio = maxPlusRatio
        self.strandfilter = strandfilter
        self.minHits = minHits
        self.trimValue = trimValue
        self.doTrim = doTrim
        self.doDirectionality = doDirectionality

        if self.doTrim:
            self.trimString = "%2.1f%%" % (100. * self.trimValue)
        else:
            self.trimString = "none"

        self.shiftValue = shiftValue
        self.maxSpacing = maxSpacing
        self.withFlag = withFlag
        self.normalize = normalize
        self.listPeak = listPeak
        self.reportshift = reportshift
        self.stringency = max(stringency, 1.0)
        self.controlfile = controlfile
        self.doControl = self.controlfile is not None
        self.doPvalue = False
        self.doRevBackground = doRevBackground

    def useRNASettings(self, readlen):
        """Switch to the RNA preset: no shift, no trimming, no directionality,
        and a maximum spacing equal to the read length."""
        self.rnaSettings = True
        self.shiftValue = 0
        self.doTrim = False
        # keep the reported trimming in step with doTrim
        self.trimString = "none"
        self.doDirectionality = False
        self.maxSpacing = readlen

    def getHeader(self):
        """Return the tab-separated column header line for the output file."""
        if self.normalize:
            countType = "RPM"
        else:
            countType = "COUNT"

        headerFields = ["#regionID\tchrom\tstart\tstop", countType, "fold\tmulti%"]

        if self.doDirectionality:
            headerFields.append("plus%\tleftPlus%")

        if self.listPeak:
            headerFields.append("peakPos\tpeakHeight")

        if self.reportshift:
            headerFields.append("readShift")

        if self.doPvalue:
            headerFields.append("pValue")

        return "\t".join(headerFields)

    def printSettings(self, ptype, useMulti, pValueType):
        """Print a blank line, the status messages and the options summary."""
        print("")
        self.printStatusMessages(ptype, useMulti)
        self.printOptionsSummary(useMulti, pValueType)

    def printStatusMessages(self, ptype, useMulti):
        """Print one informational line per non-default mode that is active.

        ptype is compared case-insensitively, so the lower-case spellings
        shown in the usage text are accepted as well as the upper-case ones.
        """
        if self.shiftValue == "learn":
            print("Will try to learn shift")

        if self.normalize:
            print("Normalizing to RPM")

        if self.doRevBackground:
            print("Swapping IP and background to calculate FDR")

        try:
            requestedType = ptype.upper()
        except AttributeError:
            # not a string (e.g. None): report it unchanged below
            requestedType = ptype

        if requestedType == "BACK":
            if not (self.doControl and self.doRevBackground):
                print("must have a control dataset and -revbackground for pValue type 'back'")
        elif requestedType not in ("", "NONE", "SELF"):
            print("could not use pValue type : %s" % ptype)

        if self.withFlag != "":
            print("restrict to flag = %s" % self.withFlag)

        if not useMulti:
            print("using unique reads only")

        if self.rnaSettings:
            print("using settings appropriate for RNA: -nodirectionality -notrim -noshift")

        if self.strandfilter == "plus":
            print("only analyzing reads on the plus strand")
        elif self.strandfilter == "minus":
            print("only analyzing reads on the minus strand")

    def _formatStrandAndShiftSettings(self, pValueType):
        """Return the "minPlus=... shift=... pvalue=..." summary fragment.

        shiftValue is an integer or a keyword such as "learn" or "auto"; the
        integer format is tried first and the string format is the fallback.
        """
        values = (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
        try:
            return "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % values
        except (TypeError, ValueError):
            return "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % values

    def printOptionsSummary(self, useMulti, pValueType):
        """Print the three-line summary of the effective thresholds."""
        print("\nenforceDirectionality=%s listPeak=%s nomulti=%s " % (self.doDirectionality, self.listPeak, not useMulti))
        print("spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        print(self._formatStrandAndShiftSettings(pValueType))

    def getAnalysisDescription(self, hitfile, useMulti, pValueType):
        """Return the multi-line "#..." description placed at the top of the
        output file: version, input files with their sizes, and thresholds."""
        description = ["#ERANGE %s" % versionString]
        if self.doControl:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample:\t%s (%.1f M reads)" % (hitfile, self.sampleRDSsize, self.controlfile, self.controlRDSsize))
        else:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample: none" % (hitfile, self.sampleRDSsize))

        if self.withFlag != "":
            description.append("#restrict to Flag = %s" % self.withFlag)

        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s" % (self.doDirectionality, self.listPeak, not useMulti))
        description.append("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        description.append("#" + self._formatStrandAndShiftSettings(pValueType))

        return "\n".join(description)

    def getFooter(self, bestShift):
        """Return the "#stats" footer built from the accumulated statistics.

        bestShift is only reported when the shift mode is "auto" and
        reportshift is set.
        """
        index = self.statistics["index"]
        mIndex = self.statistics["mIndex"]
        footerLines = ["#stats:\t%.1f RPM in %d regions" % (self.statistics["total"], index)]
        if self.doDirectionality:
            footerLines.append("#\t\t%d additional regions failed directionality filter" % self.statistics["failed"])

        if self.doRevBackground:
            # FDR estimate: control regions relative to sample regions, capped at 100
            try:
                percent = min(100. * (float(mIndex)/index), 100.)
            except ZeroDivisionError:
                percent = 0.

            footerLines.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (mIndex, self.statistics["mTotal"], percent))

        if self.shiftValue == "auto" and self.reportshift:
            footerLines.append("#mode of shift values: %d" % bestShift)

        if self.statistics["badRegionTrim"] > 0:
            footerLines.append("#%d regions discarded due to trimming problems" % self.statistics["badRegionTrim"])

        return "\n".join(footerLines)

    def updateControlStatistics(self, peak, sumAll, peakScore):
        """Count a control region towards mIndex/mTotal, or towards failed.

        The region counts when peakScore reaches minPeak and the plus-strand
        fraction lies within [minPlusRatio, maxPlusRatio]. With directionality
        enabled the left-plus fraction must also exceed leftPlusRatio,
        otherwise the region is counted as failed. A peak without hits (or
        without plus reads) is treated as having a fraction of 0 rather than
        raising ZeroDivisionError.
        """
        try:
            plusRatio = float(peak.numPlus) / peak.numHits
        except ZeroDivisionError:
            plusRatio = 0.

        if not (peakScore >= self.minPeak and self.minPlusRatio <= plusRatio <= self.maxPlusRatio):
            return

        if self.doDirectionality:
            # float() keeps the ratio exact even when both counts are integers
            try:
                leftPlusFraction = float(peak.numLeftPlus) / peak.numPlus
            except ZeroDivisionError:
                leftPlusFraction = 0.

            if not self.leftPlusRatio < leftPlusFraction:
                self.statistics["failed"] += 1
                return

        self.statistics["mIndex"] += 1
        self.statistics["mTotal"] += sumAll


def usage():
    """Print the module docstring, which doubles as the usage message."""
    print(__doc__)


def main(argv=None):
    """Parse the command line, build a RegionFinder and run findall.

    argv defaults to sys.argv; its first element is skipped as the program
    name. Fewer than three positional arguments prints the usage text and
    exits with status 2.

    Shift precedence, lowest to highest: default 0, --autoshift ("auto"),
    --shift (an integer or "learn"), --noshift (forces 0).
    """
    if not argv:
        argv = sys.argv

    parser = makeParser()
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        usage()
        sys.exit(2)

    factor, hitfile, outfilename = args[:3]

    shiftValue = 0
    if options.autoshift:
        shiftValue = "auto"

    if options.shift is not None:
        if options.shift == "learn":
            shiftValue = "learn"
        else:
            try:
                shiftValue = int(options.shift)
            except ValueError:
                # keep the previous value, as before, but say so instead of
                # silently running with a shift the user did not ask for
                sys.stderr.write("ignoring unrecognized --shift value '%s' (expected an integer or 'learn')\n" % options.shift)

    # Honour the flag under either attribute spelling (noshift / noShift) so
    # the command-line switch is never silently dropped.
    if options.noshift or getattr(options, "noShift", None):
        shiftValue = 0

    if options.doAppend:
        outputMode = "a"
    else:
        outputMode = "w"

    regionFinder = RegionFinder(factor, minRatio=options.minRatio, minPeak=options.minPeak, minPlusRatio=options.minPlusRatio,
                                maxPlusRatio=options.maxPlusRatio, leftPlusRatio=options.leftPlusRatio, strandfilter=options.strandfilter,
                                minHits=options.minHits, trimValue=options.trimValue, doTrim=options.doTrim,
                                doDirectionality=options.doDirectionality, shiftValue=shiftValue, maxSpacing=options.maxSpacing,
                                withFlag=options.withFlag, normalize=options.normalize, listPeak=options.listPeak,
                                reportshift=options.reportshift, stringency=options.stringency, controlfile=options.controlfile,
                                doRevBackground=options.doRevBackground)

    findall(regionFinder, hitfile, outfilename, options.logfilename, outputMode, options.rnaSettings,
            options.ptype, options.useMulti, options.combine5p)


def makeParser():
    """Build the optparse parser for findall.

    Every option's default is read from the "findall" section of the
    configuration returned by getConfigParser, falling back to the literal
    given here, and is then registered with set_defaults under the option's
    dest name.
    """
    usage = __doc__

    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--control", dest="controlfile")
    parser.add_option("--minimum", type="float", dest="minHits")
    parser.add_option("--ratio", type="float", dest="minRatio")
    parser.add_option("--spacing", type="int", dest="maxSpacing")
    parser.add_option("--listPeak", action="store_true", dest="listPeak")
    parser.add_option("--shift", dest="shift")
    parser.add_option("--learnFold", type="float", dest="stringency")
    # dest must match the name used in set_defaults and read by main
    # (options.noshift); otherwise the flag is parsed but never consulted
    parser.add_option("--noshift", action="store_true", dest="noshift")
    parser.add_option("--autoshift", action="store_true", dest="autoshift")
    parser.add_option("--reportshift", action="store_true", dest="reportshift")
    parser.add_option("--nomulti", action="store_false", dest="useMulti")
    parser.add_option("--minPlus", type="float", dest="minPlusRatio")
    parser.add_option("--maxPlus", type="float", dest="maxPlusRatio")
    parser.add_option("--leftPlus", type="float", dest="leftPlusRatio")
    parser.add_option("--minPeak", type="float", dest="minPeak")
    parser.add_option("--raw", action="store_false", dest="normalize")
    parser.add_option("--revbackground", action="store_true", dest="doRevBackground")
    parser.add_option("--pvalue", dest="ptype")
    parser.add_option("--nodirectionality", action="store_false", dest="doDirectionality")
    parser.add_option("--strandfilter", dest="strandfilter")
    parser.add_option("--trimvalue", type="float", dest="trimValue")
    parser.add_option("--notrim", action="store_false", dest="doTrim")
    parser.add_option("--cache", type="int", dest="cachePages")
    parser.add_option("--log", dest="logfilename")
    parser.add_option("--flag", dest="withFlag")
    parser.add_option("--append", action="store_true", dest="doAppend")
    parser.add_option("--RNA", action="store_true", dest="rnaSettings")
    parser.add_option("--combine5p", action="store_true", dest="combine5p")

    configParser = getConfigParser()
    section = "findall"
    minHits = getConfigFloatOption(configParser, section, "minHits", 4.0)
    minRatio = getConfigFloatOption(configParser, section, "minRatio", 4.0)
    maxSpacing = getConfigIntOption(configParser, section, "maxSpacing", 50)
    listPeak = getConfigBoolOption(configParser, section, "listPeak", False)
    shift = getConfigOption(configParser, section, "shift", None)
    stringency = getConfigFloatOption(configParser, section, "stringency", 4.0)
    noshift = getConfigBoolOption(configParser, section, "noshift", False)
    autoshift = getConfigBoolOption(configParser, section, "autoshift", False)
    reportshift = getConfigBoolOption(configParser, section, "reportshift", False)
    minPlusRatio = getConfigFloatOption(configParser, section, "minPlusRatio", 0.25)
    maxPlusRatio = getConfigFloatOption(configParser, section, "maxPlusRatio", 0.75)
    leftPlusRatio = getConfigFloatOption(configParser, section, "leftPlusRatio", 0.3)
    minPeak = getConfigFloatOption(configParser, section, "minPeak", 0.5)
    normalize = getConfigBoolOption(configParser, section, "normalize", True)
    logfilename = getConfigOption(configParser, section, "logfilename", "findall.log")
    withFlag = getConfigOption(configParser, section, "withFlag", "")
    doDirectionality = getConfigBoolOption(configParser, section, "doDirectionality", True)
    trimValue = getConfigFloatOption(configParser, section, "trimValue", 0.1)
    doTrim = getConfigBoolOption(configParser, section, "doTrim", True)
    doAppend = getConfigBoolOption(configParser, section, "doAppend", False)
    rnaSettings = getConfigBoolOption(configParser, section, "rnaSettings", False)
    cachePages = getConfigOption(configParser, section, "cachePages", None)
    ptype = getConfigOption(configParser, section, "ptype", "")
    controlfile = getConfigOption(configParser, section, "controlfile", None)
    doRevBackground = getConfigBoolOption(configParser, section, "doRevBackground", False)
    useMulti = getConfigBoolOption(configParser, section, "useMulti", True)
    strandfilter = getConfigOption(configParser, section, "strandfilter", "")
    combine5p = getConfigBoolOption(configParser, section, "combine5p", False)

    parser.set_defaults(minHits=minHits, minRatio=minRatio, maxSpacing=maxSpacing, listPeak=listPeak, shift=shift,
                        stringency=stringency, noshift=noshift, autoshift=autoshift, reportshift=reportshift,
                        minPlusRatio=minPlusRatio, maxPlusRatio=maxPlusRatio, leftPlusRatio=leftPlusRatio, minPeak=minPeak,
                        normalize=normalize, logfilename=logfilename, withFlag=withFlag, doDirectionality=doDirectionality,
                        trimValue=trimValue, doTrim=doTrim, doAppend=doAppend, rnaSettings=rnaSettings,
                        cachePages=cachePages, ptype=ptype, controlfile=controlfile, doRevBackground=doRevBackground, useMulti=useMulti,
                        strandfilter=strandfilter, combine5p=combine5p)

    return parser


def findall(regionFinder, hitfile, outfilename, logfilename="findall.log", outputMode="w", rnaSettings=False,
            ptype="", useMulti=True, combine5p=False):
    """Find regions on every chromosome of hitfile and write them to outfilename.

    Opens the sample BAM (and the control BAM when regionFinder.doControl is
    set), reads the "Total" and "ReadLength" header comments, writes the
    output header, processes each chromosome in turn and finishes with the
    footer, which is printed, written to the output file and logged.

    The output file and both BAM handles are closed even when processing
    fails part-way through.
    """
    writeLog(logfilename, versionString, " ".join(sys.argv[1:]))
    controlBAM = None
    sampleBAM = None
    outfile = None
    try:
        if regionFinder.doControl:
            print("\ncontrol:")
            controlBAM = pysam.Samfile(regionFinder.controlfile, "rb")
            regionFinder.controlRDSsize = int(getHeaderComment(controlBAM.header, "Total")) / 1000000.

        print("\nsample:")
        sampleBAM = pysam.Samfile(hitfile, "rb")
        regionFinder.sampleRDSsize = int(getHeaderComment(sampleBAM.header, "Total")) / 1000000.
        pValueType = getPValueType(ptype, regionFinder.doControl, regionFinder.doRevBackground)
        regionFinder.doPvalue = not pValueType == "none"
        regionFinder.readlen = int(getHeaderComment(sampleBAM.header, "ReadLength"))
        if rnaSettings:
            regionFinder.useRNASettings(regionFinder.readlen)

        regionFinder.printSettings(ptype, useMulti, pValueType)
        outfile = open(outfilename, outputMode)
        header = writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, pValueType)
        shiftDict = {}
        chromList = getChromosomeListToProcess(sampleBAM, controlBAM)
        for chromosome in chromList:
            #TODO: Really? Use first chr shift value for all of them
            maxSampleCoord = getMaxCoordinate(sampleBAM, chromosome, doMulti=useMulti)
            # "learn" is replaced by the learned value, so this only runs for
            # the first chromosome processed
            if regionFinder.shiftValue == "learn":
                regionFinder.shiftValue, shiftDict = learnShift(regionFinder, sampleBAM, maxSampleCoord, chromosome, logfilename, outfilename, outfile, useMulti, controlBAM, combine5p)

            allregions, outregions = findPeakRegions(regionFinder, sampleBAM, maxSampleCoord, chromosome, logfilename, outfilename, outfile, useMulti, controlBAM, combine5p)
            if regionFinder.doRevBackground:
                maxControlCoord = getMaxCoordinate(controlBAM, chromosome, doMulti=useMulti)
                backregions = findBackgroundRegions(regionFinder, sampleBAM, controlBAM, maxControlCoord, chromosome, useMulti)
            else:
                backregions = []
                # NOTE(review): this also replaces a "none" p-value type with
                # "self" for the results writer - confirm that is intended
                pValueType = "self"

            writeChromosomeResults(regionFinder, outregions, outfile, shiftDict, allregions, header, backregions=backregions, pValueType=pValueType)

        try:
            bestShift = getBestShiftInDict(shiftDict)
        except ValueError:
            bestShift = 0

        footer = regionFinder.getFooter(bestShift)
        print(footer)
        outfile.write(footer + "\n")
        outfile.close()
        # the footer has no trailing newline, so only strip one if present
        # instead of unconditionally dropping the last character
        writeLog(logfilename, versionString, outfilename + footer.replace("\n#", " | ").rstrip("\n"))
    finally:
        # closing an already-closed file object is harmless
        for handle in (outfile, sampleBAM, controlBAM):
            if handle is not None:
                handle.close()


def getHeaderComment(bamHeader, commentKey):
    """Return the value stored under commentKey in the header's comment lines.

    The "CO" entries of bamHeader are split on tabs; the first entry whose
    first field equals commentKey and that has a second field supplies the
    returned string.

    Raises KeyError (with an explanatory message) when the header has no "CO"
    entry or when no comment provides a value for commentKey.
    """
    try:
        comments = bamHeader["CO"]
    except KeyError:
        raise KeyError("BAM header has no comment (CO) lines; cannot look up '%s'" % commentKey)

    for comment in comments:
        fields = comment.split("\t")
        # a key without a value cannot answer the lookup; keep searching
        if fields[0] == commentKey and len(fields) > 1:
            return fields[1]

    raise KeyError("BAM header comment '%s' not found" % commentKey)


def getPValueType(ptype, doControl, doRevBackground):
    """Translate the requested p-value type into "none", "self" or "back".

    ptype is matched case-insensitively, so the lower-case spellings from the
    usage text ("self", "back", "none") work as well as the upper-case ones.
    "back" is only granted when both doControl and doRevBackground are set;
    otherwise it falls back to "self". Any other (or empty) ptype yields
    "back" when doRevBackground is set and "self" otherwise.
    """
    try:
        requestedType = ptype.upper()
    except AttributeError:
        # not a string (e.g. None): treat as "no explicit request"
        requestedType = ptype

    if requestedType == "NONE":
        return "none"

    if requestedType == "SELF":
        return "self"

    if requestedType == "BACK":
        if doControl and doRevBackground:
            return "back"
        return "self"

    if doRevBackground:
        return "back"

    return "self"


def writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, pValueType):
    """Write the analysis description and the column header to outfile.

    Each is written on its own newline-terminated line (the description may
    itself span several lines). Returns the column header string.
    """
    description = regionFinder.getAnalysisDescription(hitfile, useMulti, pValueType)
    outfile.write("%s\n" % description)
    header = regionFinder.getHeader()
    outfile.write("%s\n" % header)

    return header


def getChromosomeListToProcess(sampleBAM, controlBAM=None):
    """Return the sample's reference names to process, in sample order.

    "chrM" is always left out; when a control is supplied, names that are
    absent from the control's references are left out as well.
    """
    def isWanted(name):
        if name == "chrM":
            return False

        return controlBAM is None or name in controlBAM.references

    return [name for name in sampleBAM.references if isWanted(name)]


def findPeakRegions(regionFinder, sampleBAM, maxCoord, chromosome, logfilename, outfilename,
                    outfile, useMulti, controlBAM, combine5p):
    """Walk the sample reads of one chromosome and collect the regions that pass.

    Reads accepted by doNotProcessRead are accumulated until
    previousRegionIsDone reports that the current run has ended; the finished
    run is then turned into a Region (or DirectionalRegion), filtered by read
    count, fold ratio, peak and strand criteria, and appended to the result
    when it passes. regionFinder.statistics is updated along the way.

    logfilename, outfilename and outfile are accepted but not used here.

    Returns (allregions, outregions): the integer read count of every finished
    run, and the list of Region objects that passed every filter.
    """

    outregions = []
    allregions = []
    print "chromosome %s" % (chromosome)
    # Start "maxSpacing before zero" with a placeholder start position.
    # NOTE(review): presumably this makes the first read close the placeholder
    # run - depends on previousRegionIsDone, which is defined elsewhere.
    previousHit = - 1 * regionFinder.maxSpacing
    readStartPositions = [-1]
    totalWeight = 0.0
    uniqueReadCount = 0
    reads = []
    numStartsInRegion = 0

    for alignedread in sampleBAM.fetch(chromosome):
        if doNotProcessRead(alignedread, doMulti=useMulti, strand=regionFinder.stranded, combine5p=combine5p):
            continue

        pos = alignedread.pos
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            # The accumulated run is complete: build a region spanning from the
            # first read start to the last base of the last read.
            lastReadPos = readStartPositions[-1]
            lastBasePosition = lastReadPos + regionFinder.readlen - 1
            newRegionIndex = regionFinder.statistics["index"] + 1
            if regionFinder.doDirectionality:
                region = Region.DirectionalRegion(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel,
                                                  numReads=totalWeight)
            else:
                region = Region.Region(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel, numReads=totalWeight)

            if regionFinder.normalize:
                region.numReads /= regionFinder.sampleRDSsize

            # every finished run is recorded here, whether or not it passes below
            allregions.append(int(region.numReads))
            regionLength = lastReadPos - region.start
            if regionPassesCriteria(regionFinder, region.numReads, numStartsInRegion, regionLength):
                region.foldRatio = getFoldRatio(regionFinder, controlBAM, region.numReads, chromosome, region.start, lastReadPos, useMulti)

                if region.foldRatio >= regionFinder.minRatio:
                    # first pass, with absolute numbers
                    peak = findPeak(reads, region.start, regionLength, regionFinder.readlen, doWeight=True, leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)
                    if regionFinder.doTrim:
                        try:
                            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.sampleRDSsize)
                        except IndexError:
                            regionFinder.statistics["badRegionTrim"] += 1
                            # NOTE(review): this continue moves on to the next
                            # read without resetting the accumulators below and
                            # without recording the current read - confirm
                            # that is intended.
                            continue

                        # the fold ratio is recomputed against the trimmed extent
                        region.foldRatio = getFoldRatio(regionFinder, controlBAM, region.numReads, chromosome, region.start, lastReadPos, useMulti)

                    try:
                        bestPos = peak.topPos[0]
                        peakScore = peak.smoothArray[bestPos]
                        if regionFinder.normalize:
                            peakScore /= regionFinder.sampleRDSsize
                    except (IndexError, AttributeError, ZeroDivisionError):
                        # NOTE(review): same as above - the accumulators are not
                        # reset and the current read is dropped.
                        continue

                    if regionFinder.listPeak:
                        region.peakDescription = "%d\t%.1f" % (region.start + bestPos, peakScore)

                    if useMulti:
                        setMultireadPercentage(region, sampleBAM, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos,
                                               regionFinder.normalize, regionFinder.doTrim)

                    region.shift = peak.shift
                    # check that we still pass threshold
                    regionLength = lastReadPos - region.start
                    try:
                        plusRatio = float(peak.numPlus)/peak.numHits
                    except ZeroDivisionError:    # peak.numHits can be 0.0 if shift is larger than the region length
                        plusRatio = 0

                    if regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
                        try:
                            updateRegion(region, regionFinder.doDirectionality, regionFinder.leftPlusRatio, peak.numLeftPlus, peak.numPlus, plusRatio)
                            regionFinder.statistics["index"] += 1
                            outregions.append(region)
                            regionFinder.statistics["total"] += region.numReads
                        except RegionDirectionError:
                            regionFinder.statistics["failed"] += 1

            # start a fresh run; the current read is added to it below
            readStartPositions = []
            totalWeight = 0.0
            uniqueReadCount = 0
            reads = []
            numStartsInRegion = 0

        # count distinct start positions within the current run
        if pos not in readStartPositions:
            numStartsInRegion += 1

        readStartPositions.append(pos)
        # each read contributes 1/NH, so a read with NH == 1 has weight 1.0
        weight = 1.0/alignedread.opt('NH')
        totalWeight += weight
        if weight == 1.0:
            uniqueReadCount += 1

        reads.append({"start": pos, "sense": getReadSense(alignedread), "weight": weight})
        previousHit = pos

    return allregions, outregions


def getReadSense(read):
    """Return "-" for a read flagged as reverse, "+" otherwise."""
    return "-" if read.is_reverse else "+"


def doNotProcessRead(read, doMulti=False, strand="both", combine5p=False):
    """Return True when the read should be skipped.

    A read is skipped when its NH tag is above 1 while doMulti is off, or
    when strand is "+" or "-" and the read lies on the other strand.
    combine5p is accepted but does not influence the result.
    """
    isMultiread = read.opt('NH') > 1
    if isMultiread and not doMulti:
        return True

    if strand == "+":
        onWrongStrand = read.is_reverse
    elif strand == "-":
        onWrongStrand = not read.is_reverse
    else:
        onWrongStrand = False

    return True if onWrongStrand else False


def getMaxCoordinate(samfile, chr, doMulti=False):
    """Return the largest read start position fetched for chr (0 if none).

    Reads whose NH tag is above 1 only take part when doMulti is set.
    """
    highest = 0
    for read in samfile.fetch(chr):
        isMultiread = read.opt('NH') > 1
        if isMultiread and not doMulti:
            continue

        highest = max(highest, read.pos)

    return highest


def findBackgroundRegions(regionFinder, sampleBAM, controlBAM, maxCoord, chromosome, useMulti):
    """Run the region search on the control reads, with the sample as background.

    Control reads of one chromosome are grouped into runs the same way as in
    findPeakRegions; each finished run is compared against the sample read
    count over the same span, and runs that pass are handed to
    regionFinder.updateControlStatistics.

    Returns backregions: the integer read count of every finished control run
    (divided by controlRDSsize when normalizing).
    """
    #TODO: this is *almost* the same calculation - there are small yet important differences
    print "calculating background..."
    previousHit = - 1 * regionFinder.maxSpacing
    currentHitList = [-1]
    currentTotalWeight = 0.0
    currentReadList = []
    backregions = []
    numStarts = 0
    badRegion = False
    for alignedread in controlBAM.fetch(chromosome):
        # no strand filter is passed here, unlike the sample scan
        if doNotProcessRead(alignedread, doMulti=useMulti):
            continue

        pos = alignedread.pos
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            lastReadPos = currentHitList[-1]
            lastBasePosition = lastReadPos + regionFinder.readlen - 1
            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
            if regionFinder.normalize:
                region.numReads /= regionFinder.controlRDSsize

            backregions.append(int(region.numReads))
            # The region is rebuilt from the raw weight, so the normalized value
            # above only feeds backregions; everything below sees the
            # un-normalized read count.
            # NOTE(review): confirm this is one of the intended differences.
            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
            regionLength = lastReadPos - region.start
            if regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
                # sample reads over the same span act as the "mock"; the added 1
                # keeps the divisor non-zero
                numMock = 1. + countReadsInRegion(sampleBAM, chromosome, region.start, lastReadPos, countMulti=useMulti)
                if regionFinder.normalize:
                    numMock /= regionFinder.sampleRDSsize

                foldRatio = region.numReads / numMock
                if foldRatio >= regionFinder.minRatio:
                    # first pass, with absolute numbers
                    peak = findPeak(currentReadList, region.start, lastReadPos - region.start, regionFinder.readlen, doWeight=True,
                                    leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)

                    if regionFinder.doTrim:
                        try:
                            # NOTE(review): a fixed 20. is passed where the sample
                            # scan passes regionFinder.trimValue - confirm.
                            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, 20., currentReadList, regionFinder.controlRDSsize)
                        except IndexError:
                            badRegion = True
                            # NOTE(review): this continue skips the reset block
                            # below and drops the current read; badRegionTrim is
                            # only counted when a later run reaches the reset.
                            continue

                        numMock = 1. + countReadsInRegion(sampleBAM, chromosome, region.start, lastReadPos, countMulti=useMulti)
                        if regionFinder.normalize:
                            numMock /= regionFinder.sampleRDSsize

                        foldRatio = region.numReads / numMock

                    try:
                        bestPos = peak.topPos[0]
                        peakScore = peak.smoothArray[bestPos]
                    except IndexError:
                        # NOTE(review): also skips the reset block below
                        continue

                    # normalize to RPM
                    if regionFinder.normalize:
                        peakScore /= regionFinder.controlRDSsize

                    # check that we still pass threshold
                    regionLength = lastReadPos - region.start
                    # NOTE(review): foldRatio is passed in the position where the
                    # earlier call passes numStarts - confirm this is intended.
                    if regionPassesCriteria(regionFinder, region.numReads, foldRatio, regionLength):
                        regionFinder.updateControlStatistics(peak, region.numReads, peakScore)

            # start a fresh run; the current read is added to it below
            currentHitList = []
            currentTotalWeight = 0.0
            currentReadList = []
            numStarts = 0
            if badRegion:
                badRegion = False
                regionFinder.statistics["badRegionTrim"] += 1

        # count distinct start positions within the current run
        if pos not in currentHitList:
            numStarts += 1

        currentHitList.append(pos)
        # each read contributes 1/NH
        weight = 1.0/alignedread.opt('NH')
        currentTotalWeight += weight
        currentReadList.append({"start": pos, "sense": getReadSense(alignedread), "weight": weight})
        previousHit = pos

    return backregions


def learnShift(regionFinder, sampleBAM, maxCoord, chromosome, logfilename, outfilename,
               outfile, useMulti, controlBAM, combine5p):

    print "learning shift.... will need at least 30 training sites"
    stringency = regionFinder.stringency
    previousHit = -1 * regionFinder.maxSpacing
    positionList = [-1]
    totalWeight = 0.0
    readList = []
    shiftDict = {}
    count = 0
    numStarts = 0
    for alignedread in sampleBAM.fetch(chromosome):
        if doNotProcessRead(alignedread, doMulti=useMulti, strand=regionFinder.stranded, combine5p=combine5p):
            continue

        pos = alignedread.pos
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            if regionFinder.normalize:
                totalWeight /= regionFinder.sampleRDSsize

            regionStart = positionList[0]
            regionStop = positionList[-1]
            regionLength = regionStop - regionStart
            if regionPassesCriteria(regionFinder, totalWeight, numStarts, regionLength, stringency=stringency):
                foldRatio = getFoldRatio(regionFinder, controlBAM, totalWeight, chromosome, regionStart, regionStop, useMulti)
                if foldRatio >= regionFinder.minRatio:
                    updateShiftDict(shiftDict, readList, regionStart, regionLength, regionFinder.readlen)
                    count += 1

            positionList = []
            totalWeight = 0.0
            readList = []

        if pos not in positionList:
            numStarts += 1

        positionList.append(pos)
        weight = 1.0/alignedread.opt('NH')
        totalWeight += weight
        readList.append({"start": pos, "sense": getReadSense(alignedread), "weight": weight})
        previousHit = pos

    outline = "#learn: stringency=%.2f min_signal=%2.f min_ratio=%.2f min_region_size=%d\n#number of training examples: %d" % (stringency,
                                                                                                                               stringency * regionFinder.minHits,
                                                                                                                               stringency * regionFinder.minRatio,
                                                                                                                               stringency * regionFinder.readlen,
                                                                                                                               count)

    print outline
    writeLog(logfilename, versionString, outfilename + outline)
    shiftValue = getShiftValue(shiftDict, count, logfilename, outfilename)
    outline = "#picked shiftValue to be %d" % shiftValue
    print outline
    print >> outfile, outline
    writeLog(logfilename, versionString, outfilename + outline)

    return shiftValue, shiftDict


def previousRegionIsDone(pos, previousHit, maxSpacing, maxCoord):
    """ Tell whether the read at pos closes the region that is being built.

    The region is finished when pos equals maxCoord or when pos lies more
    than maxSpacing away from previousHit (in either direction).
    """
    if pos == maxCoord:
        return True

    gap = abs(pos - previousHit)

    return gap > maxSpacing


def regionPassesCriteria(regionFinder, sumAll, numStarts, regionLength, stringency=1):
    """ Check a region against the (stringency-scaled) regionFinder thresholds.

    sumAll must reach stringency * minHits, numStarts must reach
    stringency * minRatio and regionLength must be strictly larger than
    stringency * readlen.
    """
    readFloor, startFloor, lengthFloor = [
        stringency * limit
        for limit in (regionFinder.minHits, regionFinder.minRatio, regionFinder.readlen)
    ]

    return sumAll >= readFloor and numStarts >= startFloor and regionLength > lengthFloor


def trimRegion(region, regionFinder, peak, regionStop, trimValue, currentReadList, totalReadCount):
    """ Shrink region to the part of the peak whose signal reaches the trim threshold.

    The threshold is trimValue times the smoothed signal at the first top
    position of peak (divided by totalReadCount when regionFinder.normalize is
    set).  region.start, region.stop and region.numReads are updated in place
    and the peak statistics are replaced by those of a peak recomputed with
    findPeak() over the trimmed span.

    Returns the new last read position of the region.
    """
    summit = peak.topPos[0]
    summitSignal = peak.smoothArray[summit]
    if regionFinder.normalize:
        summitSignal /= totalReadCount

    cutoff = trimValue * summitSignal

    # both offsets are relative to the untrimmed region start
    leftOffset = findStartEdgePosition(peak, cutoff)
    lastOffset = regionStop - region.start - 1
    rightOffset = findStopEdgePosition(peak, lastOffset, cutoff)

    # the stop has to be computed before the start is moved
    trimmedStop = region.start + rightOffset
    region.start += leftOffset
    trimmedLength = trimmedStop - region.start

    trimmedPeak = findPeak(currentReadList, region.start, trimmedLength, regionFinder.readlen,
                           doWeight=True, leftPlus=regionFinder.doDirectionality, shift=peak.shift)

    # carry the recomputed statistics over onto the caller's peak object
    for attribute in ("numPlus", "numLeftPlus", "topPos", "smoothArray"):
        setattr(peak, attribute, getattr(trimmedPeak, attribute))

    region.numReads = trimmedPeak.numHits
    if regionFinder.normalize:
        region.numReads /= totalReadCount

    region.stop = trimmedStop + regionFinder.readlen - 1

    return trimmedStop


def findStartEdgePosition(peak, minSignalThresh):
    """ Scan rightwards from offset 0 and return the first offset that
    peakEdgeLocated() accepts as a peak edge.

    An IndexError raised while probing the peak propagates to the caller.
    """
    offset = 0
    while True:
        if peakEdgeLocated(peak, offset, minSignalThresh):
            return offset

        offset += 1


def findStopEdgePosition(peak, stop, minSignalThresh):
    """ Scan leftwards from offset stop and return the first offset that
    peakEdgeLocated() accepts as a peak edge.

    An IndexError raised while probing the peak propagates to the caller.
    """
    while True:
        if peakEdgeLocated(peak, stop, minSignalThresh):
            return stop

        stop -= 1


def peakEdgeLocated(peak, position, minSignalThresh):
    """ A position is a peak edge when its smoothed signal reaches
    minSignalThresh or when it is the first top position of the peak.
    """
    signalReached = peak.smoothArray[position] >= minSignalThresh

    return signalReached or position == peak.topPos[0]


def getFoldRatio(regionFinder, controlBAM, sumAll, chromosome, regionStart, regionStop, useMulti):
    """ Fold ratio is the total read weight (sumAll) over the control reads
        counted in the same span plus a pseudocount of one; without a control
        the ratio is simply regionFinder.minRatio.
    """
    #TODO: this needs to be generalized as there is a point at which we want to use the sampleRDS instead of controlRDS
    if not regionFinder.doControl:
        return regionFinder.minRatio

    controlReads = 1. + countReadsInRegion(controlBAM, chromosome, regionStart, regionStop, countMulti=useMulti)
    if regionFinder.normalize:
        controlReads /= regionFinder.controlRDSsize

    return sumAll / controlReads


def countReadsInRegion(bamfile, chr, start, end, uniqs=True, countMulti=False, countSplices=False):
    """ Sum the weights of the reads that bamfile.fetch() returns for the span.

    Reads whose NH tag is greater than one weigh 1/NH and are only considered
    when countMulti is set; all other reads weigh one and are only considered
    when uniqs is set.  Reads flagged by isSpliceEntry() are added only when
    countSplices is set.
    """
    total = 0.0
    for read in bamfile.fetch(chr, start, end):
        numHits = read.opt('NH')
        if numHits > 1:
            if not countMulti:
                continue

            readWeight = 1.0/numHits
        elif uniqs:
            readWeight = 1.0
        else:
            continue

        if not isSpliceEntry(read.cigar) or countSplices:
            total += readWeight

    return total

def updateShiftDict(shiftDict, readList, regionStart, regionLength, readlen):
    """ Add one vote to shiftDict for the shift that findPeak() reports for
    the region when it is called with shift="auto".
    """
    regionPeak = findPeak(readList, regionStart, regionLength, readlen, doWeight=True, shift="auto")
    bestShift = regionPeak.shift
    shiftDict[bestShift] = shiftDict.get(bestShift, 0) + 1


def getShiftValue(shiftDict, count, logfilename, outfilename):
    if count < 30:
        outline = "#too few training examples to pick a shiftValue - defaulting to 0\n#consider picking a lower minimum or threshold"
        print outline
        writeLog(logfilename, versionString, outfilename + outline)
        shiftValue = 0
    else:
        shiftValue = getBestShiftInDict(shiftDict)
        print shiftDict

    return shiftValue


def getRegion(regionStart, regionStop, factor, index, chromosome, sumAll, foldRatio, multiP,
              peakDescription, shift, doDirectionality, leftPlusRatio, numLeft,
              numPlus, plusRatio):
    """ Build the region object for a region that passed the other criteria.

    Without directionality checking a plain Region.Region is returned.  With
    it, numLeft / numPlus has to exceed leftPlusRatio: a
    Region.DirectionalRegion carrying the plus and left-plus percentages is
    returned on success and RegionDirectionError is raised otherwise.
    """
    if not doDirectionality:
        # we have a region, but didn't check for directionality
        return Region.Region(regionStart, regionStop, factor, index, chromosome,
                             sumAll, foldRatio, multiP, peakDescription, shift)

    if not leftPlusRatio < numLeft / numPlus:
        raise RegionDirectionError

    # we have a region that passes all criteria
    plusP = plusRatio * 100.
    leftP = 100. * numLeft / numPlus

    return Region.DirectionalRegion(regionStart, regionStop,
                                    factor, index, chromosome, sumAll,
                                    foldRatio, multiP, plusP, leftP,
                                    peakDescription, shift)


def setMultireadPercentage(region, sampleBAM, hitRDSsize, currentTotalWeight, currentUniqueCount, chromosome, lastReadPos, normalize, doTrim):
    """ Store on region.multiP the percentage of region.numReads that comes
    from multireads, capped at 100.

    For trimmed regions the multiread weight is recounted from sampleBAM over
    region.start..lastReadPos (1/NH for every read with NH > 1); otherwise it
    is currentTotalWeight - currentUniqueCount.  region.multiP is left
    untouched when the division by region.numReads raises ZeroDivisionError.
    """
    if not doTrim:
        multireadWeight = currentTotalWeight - currentUniqueCount
    else:
        multireadWeight = 0.0
        for read in sampleBAM.fetch(chromosome, region.start, lastReadPos):
            numHits = read.opt('NH')
            if numHits > 1:
                multireadWeight += 1.0/numHits

    # normalize to RPM
    if normalize:
        multireadWeight /= hitRDSsize

    try:
        percentage = 100. * (multireadWeight / region.numReads)
    except ZeroDivisionError:
        return

    region.multiP = min(percentage, 100.)


def regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
    """ Return True when the region passes regionPassesCriteria() and, in
    addition, peakScore reaches regionFinder.minPeak and plusRatio lies within
    [regionFinder.minPlusRatio, regionFinder.maxPlusRatio]; False otherwise.
    """
    if not regionPassesCriteria(regionFinder, region.numReads, region.foldRatio, regionLength):
        return False

    #TODO: here is where the test dataset is failing
    if peakScore >= regionFinder.minPeak and regionFinder.minPlusRatio <= plusRatio <= regionFinder.maxPlusRatio:
        return True

    return False


def updateRegion(region, doDirectionality, leftPlusRatio, numLeft, numPlus, plusRatio):
    """ Record the strand percentages on region when directionality is checked.

    Nothing happens when doDirectionality is off.  Otherwise
    numLeft / numPlus has to exceed leftPlusRatio: region.plusP and
    region.leftP are then set, else RegionDirectionError is raised.
    """
    if not doDirectionality:
        return

    if not leftPlusRatio < numLeft / numPlus:
        raise RegionDirectionError

    region.plusP = plusRatio * 100.
    region.leftP = 100. * numLeft / numPlus


def writeChromosomeResults(regionFinder, outregions, outfile, shiftDict,
                           allregions, header, backregions=[], pValueType="none"):

    print regionFinder.statistics["mIndex"], regionFinder.statistics["mTotal"]
    if regionFinder.doPvalue:
        if pValueType == "self":
            poissonmean = calculatePoissonMean(allregions)
        else:
            poissonmean = calculatePoissonMean(backregions)

    print header
    for region in outregions:
        if regionFinder.shiftValue == "auto" and regionFinder.reportshift:
            try:
                shiftDict[region.shift] += 1
            except KeyError:
                shiftDict[region.shift] = 1

        outline = getRegionString(region, regionFinder.reportshift)

        # iterative poisson from http://stackoverflow.com/questions/280797?sort=newest
        if regionFinder.doPvalue:
            pValue = calculatePValue(int(region.numReads), poissonmean)
            outline += "\t%1.2g" % pValue

        print outline
        print >> outfile, outline


def calculatePoissonMean(dataList):
    """ Return the arithmetic mean of dataList (0 for an empty list) and
    print it together with the number of entries.

    dataList is sorted in place before it is summed.
    """
    dataList.sort()
    numEntries = float(len(dataList))
    if numEntries:
        poissonmean = sum(dataList) / numEntries
    else:
        poissonmean = 0

    print("Poisson n=%d, p=%f" % (numEntries, poissonmean))

    return poissonmean


def calculatePValue(sum, poissonmean):
    """ Evaluate exp(-poissonmean) * poissonmean**sum / sum! iteratively,
    one factor at a time.

    sum is the integer event count (the caller passes int(region.numReads)).
    """
    probability = math.exp(-poissonmean)
    factor = 0
    while factor < sum:
        factor += 1
        probability = probability * poissonmean / factor

    return probability


def getRegionString(region, reportShift):
    """ Return the output line for region, using the variant that includes
    the shift when reportShift is set.
    """
    return region.printRegionWithShift() if reportShift else region.printRegion()


def getBestShiftInDict(shiftDict):
    """ Return the shift (key) that collected the highest count (value). """
    return max(shiftDict, key=lambda shift: shiftDict[shift])


# run the region finder when the file is executed as a script, handing the
# full argument vector (program name included) to main()
if __name__ == "__main__":
    main(sys.argv)