#
#  commoncode.py
#  ERANGE
#

import ConfigParser
import os
import string
from time import strftime
from array import array
from collections import defaultdict
import Peak
from cistematic.core.geneinfo import geneinfoDB
from cistematic.genomes import Genome
import Region

# version number of this commoncode module
commoncodeVersion = 5.6
# NOTE(review): presumably the current RDS (read dataset) file-format version used by ERANGE tools — confirm
currentRDSversion = 2.0

class ErangeError(Exception):
    """ Exception type defined for ERANGE-specific error conditions. """


def getReverseComplement(base):
    """ return the complement of a single uppercase nucleotide.

    Despite the name, only one base is handled (no reversal is involved).
    A<->T, G<->C and N->N; any other value raises KeyError.
    """
    complementOf = dict(zip("ATGCN", "TACGN"))
    return complementOf[base]


def countDuplicatesInList(listToCheck):
    """ count how often each distinct item occurs in listToCheck.

    Returns the (item, occurrences) pairs of the tally dictionary via .items().
    """
    occurrences = {}
    for entry in listToCheck:
        occurrences[entry] = occurrences.get(entry, 0) + 1

    return occurrences.items()


def writeLog(logFile, messenger, message):
    """ append a timestamped message from messenger to logFile, creating the file if needed.

    Each call writes one line of the form
    "YYYY-MM-DD HH:MM:SS: [messenger] message".
    """
    # Append mode already creates a missing file, so no existence probe is needed.
    # (A read-mode probe would leak its handle, and would truncate an existing
    # file that merely isn't readable.)
    with open(logFile, "a") as logfile:
        logfile.write("%s: [%s] %s\n" % (strftime("%Y-%m-%d %H:%M:%S"), messenger, message))


def getGeneInfoDict(genome, cache=False):
    """ return the gene-info dictionary for genome from cistematic's geneinfoDB.

    For "dmelanogaster" the lookup is made with infoKey="locus"; every other
    genome uses getallGeneInfo's default key.
    """
    infoDB = geneinfoDB(cache=cache)
    lookupOptions = {}
    if genome == "dmelanogaster":
        lookupOptions["infoKey"] = "locus"

    return infoDB.getallGeneInfo(genome, **lookupOptions)


def getGeneAnnotDict(genome, inRAM=False):
    """ return the annotation dictionary of genome without any feature extension. """
    return getExtendedGeneAnnotDict(genome, extendGenome="", inRAM=inRAM)


def getExtendedGeneAnnotDict(genomeName, extendGenome, replaceModels=False, inRAM=False):
    """ return genome.allAnnotInfo() for the cistematic Genome named genomeName.

    Unless extendGenome is the empty string, the genome's features are first
    extended with extendGenome (replace=replaceModels).
    """
    genomeObject = Genome(genomeName, inRAM=inRAM)
    if extendGenome != "":
        genomeObject.extendFeatures(extendGenome, replace=replaceModels)

    return genomeObject.allAnnotInfo()


def getConfigParser(fileList=[]):
    """ return a SafeConfigParser loaded from the ERANGE configuration files.

    The files read are "erange.config" in the current directory, then
    ~/.erange.config, then every entry of fileList, in that order (values from
    later files override earlier ones; unreadable files are skipped by
    ConfigParser.read).
    """
    searchPaths = ["erange.config", os.path.expanduser("~/.erange.config")] + list(fileList)
    parser = ConfigParser.SafeConfigParser()
    parser.read(searchPaths)
    return parser


def getConfigOption(parser, section, option, default=None):
    """ return the string value of option in section, or default when the
    section or the option is missing.
    """
    try:
        return parser.get(section, option)
    except (ConfigParser.NoSectionError, ConfigParser.NoOptionError):
        return default


def getConfigIntOption(parser, section, option, default=None):
    """ return option in section parsed as an int, or default when the section
    or the option is missing. A value that is not an integer raises ValueError.
    """
    try:
        return parser.getint(section, option)
    except (ConfigParser.NoSectionError, ConfigParser.NoOptionError):
        return default


def getConfigFloatOption(parser, section, option, default=None):
    """ return option in section parsed as a float, or default when the section
    or the option is missing. A value that is not a number raises ValueError.
    """
    try:
        return parser.getfloat(section, option)
    except (ConfigParser.NoSectionError, ConfigParser.NoOptionError):
        return default


def getConfigBoolOption(parser, section, option, default=None):
    """ return option in section parsed as a boolean, or default when the
    section or option is missing or the value is not a recognized boolean.
    """
    try:
        return parser.getboolean(section, option)
    except (ConfigParser.NoSectionError, ConfigParser.NoOptionError, ValueError):
        return default


def getAllConfigSectionOptions(parser, section):
    """ return the (option, value) pairs of section, or an empty list when the
    section does not exist.
    """
    try:
        return parser.items(section)
    except ConfigParser.NoSectionError:
        return []


def getMergedRegions(regionfilename, maxDist=1000, minHits=0, verbose=False, keepLabel=False,
                     fullChrom=False, chromField=1, scoreField=4, pad=0, compact=False,
                     doMerge=True, keepPeak=False, returnTop=0):

    """ returns a dictionary containing a list of merged overlapping regions by chromosome; 
    can optionally filter regions that have a scoreField fewer than minHits.
    Can also optionally return the label of each region, as well as the
    peak, if supplied (peakPos and peakHeight should be the last 2 fields).
    Can return the top regions based on score if higher than minHits.

    The whole of regionfilename is read and its lines are passed, with all the
    other options unchanged, to getMergedRegionsFromList.
    """
    # the with-block closes the file even if reading fails, and before the
    # (possibly long) merge work starts
    with open(regionfilename) as infile:
        lines = infile.readlines()

    return getMergedRegionsFromList(lines, maxDist, minHits, verbose, keepLabel,
                                    fullChrom, chromField, scoreField, pad, compact,
                                    doMerge, keepPeak, returnTop)


def getMergedRegionsFromList(regionList, maxDist=1000, minHits=0, verbose=False, keepLabel=False,
                     fullChrom=False, chromField=1, scoreField=4, pad=0, compact=False,
                     doMerge=True, keepPeak=False, returnTop=0):
    """ returns a dictionary containing a list of merged overlapping regions by chromosome; 
    can optionally filter regions that have a scoreField fewer than minHits.
    Can also optionally return the label of each region, as well as the
    peak, if supplied (peakPos and peakHeight should be the last 2 fields).
    Can return the top regions based on score if higher than minHits.

    regionList is a list of tab-delimited lines. Lines starting with "#" are
    headers; a header mentioning "pvalue" or "readShift" means those columns
    follow peakPos/peakHeight, so the peak fields are looked up further from
    the end of the line. Lines whose score field is missing or not a number
    are skipped (when minHits >= 0).

    Coordinates come from fields chromField, chromField + 1 and chromField + 2
    (widened by pad on each side). With compact=True, field chromField is
    instead "chrom:start-stop" and pad is not applied. The label is the
    first chromField fields joined by tabs when chromField > 1, else field 0.
    Unless fullChrom is True, the first 3 characters of the chromosome name
    are dropped (e.g. "chr1" -> "1").

    Each kept line is merged into the first existing region on its chromosome
    that overlaps it or is within maxDist of it (when doMerge is True);
    otherwise it starts a new Region.Region. Returns a dict mapping chromosome
    to its list of regions sorted by start.
    """
    regions = {}
    hasPvalue = 0
    hasShift = 0
    if 0 < returnTop < len(regionList):
        scores = []
        for regionEntry in regionList:
            if regionEntry.startswith("#"):
                if "pvalue" in regionEntry:
                    hasPvalue = 1

                if "readShift" in regionEntry:
                    hasShift = 1

                continue

            fields = regionEntry.strip().split("\t")
            try:
                hits = float(fields[scoreField].strip())
            except (IndexError, ValueError):
                # unscorable lines (e.g. blank lines) are skipped by the main pass too
                continue

            scores.append(hits)

        if len(scores) > 0:
            scores.sort()
            # header and unscorable lines count toward len(regionList) but not toward
            # scores, so clamp the rank to stay inside the list
            minScore = scores[-min(returnTop, len(scores))]
            if minScore > minHits:
                minHits = minScore

    mergeCount = 0
    chromField = int(chromField)
    count = 0
    for regionEntry in regionList:
        if regionEntry.startswith("#"):
            if "pvalue" in regionEntry:
                hasPvalue = 1

            if "readShift" in regionEntry:
                hasShift = 1

            continue

        fields = regionEntry.strip().split("\t")
        if minHits >= 0:
            try:
                hits = float(fields[scoreField].strip())
            except (IndexError, ValueError):
                continue

            if hits < minHits:
                continue

        # computed for every format so that compact mode with keepLabel also has a label
        if chromField > 1:
            label = "\t".join(fields[:chromField])
        else:
            label = fields[0]

        if compact:
            (chrom, pos) = fields[chromField].split(":")
            (front, back) = pos.split("-")
            start = int(front)
            stop = int(back)
        elif chromField > 1:
            chrom = fields[chromField]
            start = int(fields[chromField + 1]) - pad
            stop = int(fields[chromField + 2]) + pad
        else:
            chrom = fields[1]
            start = int(fields[2]) - pad
            stop = int(fields[3]) + pad

        if not fullChrom:
            chrom = chrom[3:]

        if keepPeak:
            peakPos = int(fields[-2 - hasPvalue - hasShift])
            peakHeight = float(fields[-1 - hasPvalue - hasShift])

        if chrom not in regions:
            regions[chrom] = []

        merged = False

        if doMerge:
            for region in regions[chrom]:
                rstart = region.start
                rstop = region.stop
                if regionsOverlap(start, stop, rstart, rstop) or regionsAreWithinDistance(start, stop, rstart, rstop, maxDist):
                    if start < rstart:
                        rstart = start

                    if rstop < stop:
                        rstop = stop

                    if keepPeak:
                        rpeakPos = region.peakPos
                        rpeakHeight = region.peakHeight
                        if peakHeight > rpeakHeight:
                            rpeakHeight = peakHeight
                            rpeakPos = peakPos

                    region.start = rstart
                    region.stop = rstop
                    region.length = abs(rstop - rstart)
                    if keepLabel:
                        region.label = label

                    if keepPeak:
                        region.peakPos = rpeakPos
                        region.peakHeight = rpeakHeight

                    mergeCount += 1
                    merged = True
                    break

        if not merged:
            region = Region.Region(start, stop)
            if keepLabel:
                region.label = label

            if keepPeak:
                region.peakPos = peakPos
                region.peakHeight = peakHeight

            regions[chrom].append(region)
            count += 1

        if verbose and (count % 100000 == 0):
            print(count)

    regionCount = 0
    for chrom in regions:
        regionCount += len(regions[chrom])
        regions[chrom].sort(key=lambda entry: entry.start)

    if verbose:
        print("merged %d times" % mergeCount)
        print("returning %d regions" % regionCount)

    return regions


def regionsOverlap(start, stop, rstart, rstop):
    """ return True when the closed intervals [start, stop] and [rstart, rstop]
    share at least one point.

    Each interval's endpoints may be given in either order, and intervals that
    merely touch (e.g. 1-5 and 5-9) count as overlapping.
    """
    low, high = (stop, start) if start > stop else (start, stop)
    rlow, rhigh = (rstop, rstart) if rstart > rstop else (rstart, rstop)
    return low <= rhigh and rlow <= high


def regionsAreWithinDistance(start, stop, rstart, rstop, maxDist):
    """ return True when the gap between the facing ends of two regions is at
    most maxDist.

    Endpoints may be given in either order. Only two gaps are checked: from this
    region's stop to the other's start, and from the other's stop to this
    region's start (absolute differences). A region that contains the other with
    both ends more than maxDist away therefore returns False, which is why
    callers pair this test with regionsOverlap.
    """
    low, high = (stop, start) if start > stop else (start, stop)
    rlow, rhigh = (rstop, rstart) if rstart > rstop else (rstart, rstop)
    gapAfter = abs(rlow - high)
    gapBefore = abs(rhigh - low)
    return gapAfter <= maxDist or gapBefore <= maxDist


def findPeak(hitList, start, length, readlen=25, doWeight=False, leftPlus=False,
             shift=0, maxshift=75):
    """ find the peak in a list of reads (hitList) in a region of a given
    length and absolute start point, and return it as a Peak.Peak.

    The Peak is built from the list of top positions (relative to start), the
    number of hits, a triangular-smoothed coverage array, the number of
    forward (plus) sense reads, and the shift used.
    If shift is "auto", it is chosen by getBestShiftForRegion over 0..maxshift.
    If doWeight is True, each read counts with read["weight"] instead of 1.
    If leftPlus is True, peak.numLeftPlus is set to the number of plus reads
    whose unshifted start lies at or left of the first top position.
    """
    if shift == "auto":
        shift = getBestShiftForRegion(hitList, start, length, useWeight=doWeight, maxShift=maxshift)

    coverage, keptReads, numHits, numPlus = findPeakSequenceArray(hitList, start, shift, length, readlen, doWeight, leftPlus)

    # triangular smoothing with weights 1-2-3-2-1 (sum 9); the two positions
    # at each edge of the region stay at zero
    smoothArray = array("f", [0.] * length)
    for center in range(2, length - 2):
        smoothArray[center] = (coverage[center - 2] + 2 * coverage[center - 1] + 3 * coverage[center]
                               + 2 * coverage[center + 1] + coverage[center + 2]) / 9.0

    topPos = getPeakPositionList(smoothArray, length)
    peak = Peak.Peak(topPos, numHits, smoothArray, numPlus, shift=shift)
    if not leftPlus:
        return peak

    firstTopPos = topPos[0]
    leftPlusCount = 0
    for read in keptReads:
        readOffset = read["start"] - start
        if readOffset > firstTopPos or read["sense"] != "+":
            continue

        leftPlusCount += read["weight"] if doWeight else 1.0

    peak.numLeftPlus = leftPlusCount
    return peak


def getBestShiftForRegion(readList, start, length, useWeight=False, maxShift=75):
    """ return the read shift in 0..maxShift that best superimposes plus- and
    minus-strand reads in a region of the given length starting at start.

    For each candidate shift, plus-strand reads move right and other reads move
    left by that amount. Plus reads add and other reads subtract their weight
    (read["weight"] if useWeight, else 1.0) at the shifted position. The score is
    the sum of absolute values over the region, and the lowest score wins (the
    smallest shift on ties). Nothing is printed.
    """
    bestShift = 0
    lowestScore = float("inf")
    for testShift in range(maxShift + 1):
        shiftArray = array("f", [0.] * length)
        for read in readList:
            currentpos = read["start"] - start
            if read["sense"] == "+":
                currentpos += testShift
            else:
                currentpos -= testShift

            # NOTE(review): position 0 is excluded here although
            # findPeakSequenceArray accepts it - confirm "< 1" is intended
            if (currentpos < 1) or (currentpos >= length):
                continue

            if useWeight:
                weight = read["weight"]
            else:
                weight = 1.0

            if read["sense"] == "+":
                shiftArray[currentpos] += weight
            else:
                shiftArray[currentpos] -= weight

        currentScore = sum(abs(score) for score in shiftArray)
        if currentScore < lowestScore:
            bestShift = testShift
            lowestScore = currentScore

    return bestShift


def findPeakSequenceArray(hitList, start, shift, length, readlen, doWeight, leftPlus):
    """ build the read-coverage array for a region of the given length.

    Each read is placed at read["start"] - start, moved by +shift for plus-sense
    reads and -shift otherwise, and adds its weight (read["weight"] if doWeight,
    else 1.0) to each of the readlen positions it covers inside [0, length).
    Reads lying entirely outside the region are ignored.

    Returns (seqArray, regionArray, numHits, numPlus): the coverage as an
    array("f"), the kept reads (collected only when leftPlus is True), the summed
    weight of kept reads, and the summed weight of kept plus-sense reads.
    """
    coverage = array("f", [0.] * length)
    keptReads = []
    totalWeight = 0.
    plusWeight = 0.
    firstAllowed = 1 - readlen
    for read in hitList:
        isPlus = read["sense"] == "+"
        if isPlus:
            offset = read["start"] - start + shift
        else:
            offset = read["start"] - start - shift

        if offset < firstAllowed or offset >= length:
            continue

        weight = read["weight"] if doWeight else 1.0
        totalWeight += weight
        if leftPlus:
            keptReads.append(read)

        # clip the read's span [offset, offset + readlen) to the region
        for pos in range(max(offset, 0), min(offset + readlen, length)):
            coverage[pos] += weight

        if isPlus:
            plusWeight += weight

    return coverage, keptReads, totalWeight, plusWeight


def getPeakPositionList(smoothArray, length):
    """ return every position in smoothArray[0:length] holding the maximum value.

    The running maximum starts at 0, so when no value is positive the
    positions holding exactly 0 are returned, and all-negative input yields [].
    """
    bestValue = 0
    positions = []
    for position in range(length):
        value = smoothArray[position]
        if value > bestValue:
            bestValue = value
            positions = [position]
        elif value == bestValue:
            positions.append(position)

    return positions


def getFeaturesByChromDict(genomeObject, additionalRegionsDict={}, ignorePseudo=False,
                           restrictList=[], regionComplement=False, maxStop=250000000):
    """ return a dictionary of cistematic gene features. Requires
    cistematic, obviously. Can filter-out pseudogenes. Will use
    additional regions dict to supplement gene models, if available.
    Can restrict output to a list of GIDs.
    If regionComplement is set to true, returns the regions *outside* of the
    calculated boundaries, which is useful for retrieving intronic and
    intergenic regions. maxStop is simply used to define the uppermost
    boundary of the complement region.

    genomeObject.getallGeneFeatures() must return a dict mapping gene ID to a
    list of (ftype, chrom, start, stop, sense) tuples. additionalRegionsDict
    maps chromosome to objects with label/start/stop attributes, which are
    added as "custom" features of the gene named by label. A feature lying
    inside another feature of the same gene is dropped. Each gene is placed on
    the chromosome and sense of its last feature; genes with no features are
    skipped. A gene with any "PSEUDO" feature is a pseudogene.

    Returns a dict mapping chromosome to a sorted list of
    (start, stop, gid, sense, ftype) tuples. With regionComplement, returns
    (that dict, complementDict), where complementDict maps chromosome to
    (start, stop, "nonExon<N>", "F", "nonExon") tuples covering the gaps
    between features from 0 up to maxStop, each with a distinct N.
    """ 
    featuresDict = genomeObject.getallGeneFeatures()
    restrictGID = len(restrictList) > 0
    restrictSet = set(restrictList)

    if len(additionalRegionsDict) > 0:
        # NOTE(review): custom features are appended to the dict returned by
        # getallGeneFeatures(); if the genome object caches it, they persist
        customLabels = set()
        for chrom in additionalRegionsDict:
            for region in additionalRegionsDict[chrom]:
                label = region.label
                customLabels.add(label)
                if label not in featuresDict:
                    featuresDict[label] = []
                    # NOTE(review): new genes get sense "+", while other code here
                    # tests for "F" - confirm the intended convention
                    sense = "+"
                else:
                    sense = featuresDict[label][0][-1]

                featuresDict[label].append(("custom", chrom, region.start, region.stop, sense))

        for gid in customLabels:
            featuresDict[gid].sort(key=lambda feature: feature[2])

    featuresByChromDict = {}
    for gid in featuresDict:
        if restrictGID and gid not in restrictSet:
            continue

        featureList = featuresDict[gid]
        if len(featureList) == 0:
            # no chromosome or sense to use; otherwise a stale chrom/sense from the
            # previous gene (or an unbound name for the first gene) would be used
            continue

        newFeatureList = []
        isPseudo = False
        for (ftype, chrom, start, stop, sense) in featureList:
            if ftype == "PSEUDO":
                isPseudo = True

            if (start, stop, ftype) not in newFeatureList:
                notContained = True
                containedList = []
                for (fstart, fstop, ftype2) in newFeatureList:
                    if start >= fstart and stop <= fstop:
                        notContained = False

                    if start < fstart and stop > fstop:
                        containedList.append((fstart, fstop))

                if len(containedList) > 0:
                    # the new feature swallows earlier ones: drop them and re-check
                    newFList = []
                    notContained = True
                    for (fstart, fstop, ftype2) in newFeatureList:
                        if (fstart, fstop) not in containedList:
                            newFList.append((fstart, fstop, ftype2))
                            if start >= fstart and stop <= fstop:
                                notContained = False

                    newFeatureList = newFList

                if notContained:
                    newFeatureList.append((start, stop, ftype))

        if ignorePseudo and isPseudo:
            continue

        if chrom not in featuresByChromDict:
            featuresByChromDict[chrom] = []

        for (start, stop, ftype) in newFeatureList:
            featuresByChromDict[chrom].append((start, stop, gid, sense, ftype))

    for chrom in featuresByChromDict:
        featuresByChromDict[chrom].sort()

    if not regionComplement:
        return featuresByChromDict

    complementByChromDict = {}
    complementIndex = 0
    for chrom in featuresByChromDict:
        complementByChromDict[chrom] = []
        chromFeatures = featuresByChromDict[chrom]
        if len(chromFeatures) == 0:
            continue

        currentStart = 0
        for feature in chromFeatures:
            currentStop = feature[0]
            complementIndex += 1
            if currentStart < currentStop:
                complementByChromDict[chrom].append((currentStart, currentStop, "nonExon%d" % complementIndex, "F", "nonExon"))

            # features are sorted by start but may overlap or nest, so the next
            # gap must never start inside an earlier, longer feature
            currentStart = max(currentStart, feature[1])

        # advance the index so the trailing region gets its own ID
        complementIndex += 1
        if currentStart < maxStop:
            complementByChromDict[chrom].append((currentStart, maxStop, "nonExon%d" % complementIndex, "F", "nonExon"))

    return (featuresByChromDict, complementByChromDict)


def getLocusByChromDict(genome, upstream=0, downstream=0, useCDS=True,
                        additionalRegionsDict={}, ignorePseudo=False, upstreamSpanTSS=False,
                        lengthCDS=0, keepSense=False, adjustToNeighbor=True):
    """ return a dictionary of gene loci. Can be used to retrieve additional
    sequence upstream or downstream of gene, up to the next gene. Requires
    cistematic, obviously.
    Can filter-out pseudogenes and use additional regions outside of existing
    gene models. Use upstreamSpanTSS to overlap half of the upstream region
    over the TSS.
    If lengthCDS > 0 bp, e.g. X, return only the starting X bp from CDS. If
    lengthCDS < 0bp, return only the last X bp from CDS.
    """ 
    locusByChromDict = {}
    if upstream == 0 and downstream == 0 and not useCDS:
        print "getLocusByChromDict: asked for no sequence - returning empty dict"
        return locusByChromDict
    elif upstream > 0 and downstream > 0 and not useCDS:
        print "getLocusByChromDict: asked for only upstream and downstream - returning empty dict"
        return locusByChromDict
    elif lengthCDS != 0 and not useCDS:
        print "getLocusByChromDict: asked for partial CDS but not useCDS - returning empty dict"
        return locusByChromDict
    elif upstreamSpanTSS and lengthCDS != 0:
        print "getLocusByChromDict: asked for TSS spanning and partial CDS - returning empty dict"
        return locusByChromDict
    elif lengthCDS > 0 and downstream > 0:
        print "getLocusByChromDict: asked for discontinuous partial CDS from start and downstream - returning empty dict"
        return locusByChromDict
    elif lengthCDS < 0 and upstream > 0:
        print "getLocusByChromDict: asked for discontinuous partial CDS from stop and upstream - returning empty dict"
        return locusByChromDict

    genomeName = genome.genome
    featuresDict = getGeneFeatures(genome, additionalRegionsDict)
    for gid in featuresDict:
        featureList = featuresDict[gid]
        newFeatureList = []
        for (ftype, chrom, start, stop, sense) in featureList:
            newFeatureList.append((start, stop))

        if ignorePseudo and ftype == "PSEUDO":
            continue

        newFeatureList.sort()

        sense = featureList[0][-1]
        gstart = newFeatureList[0][0]
        gstop = newFeatureList[-1][1]
        glen = abs(gstart - gstop)
        if sense == "F":
            if not useCDS and upstream > 0:
                if upstreamSpanTSS:
                    if gstop > (gstart + upstream / 2):
                        gstop = gstart + upstream / 2
                else:
                    gstop = gstart
            elif not useCDS and downstream > 0:
                gstart = gstop

            if upstream > 0:
                if upstreamSpanTSS:
                    distance = upstream / 2
                else:
                    distance = upstream

                if adjustToNeighbor:
                    nextGene = genome.leftGeneDistance((genomeName, gid), distance * 2)
                    if nextGene < distance * 2:
                        distance = nextGene / 2

                distance = max(distance, 1)
                gstart -= distance

            if downstream > 0:
                distance = downstream
                if adjustToNeighbor:
                    nextGene = genome.rightGeneDistance((genomeName, gid), downstream * 2)
                    if nextGene < downstream * 2:
                        distance = nextGene / 2

                distance = max(distance, 1)
                gstop += distance

            if 0 < lengthCDS < glen:
                gstop = newFeatureList[0][0] + lengthCDS

            if lengthCDS < 0 and abs(lengthCDS) < glen:
                gstart = newFeatureList[-1][1] + lengthCDS
        else:
            if not useCDS and upstream > 0:
                if upstreamSpanTSS:
                    if gstart < (gstop - upstream / 2):
                        gstart = gstop - upstream / 2
                else:
                    gstart = gstop
            elif not useCDS and downstream > 0:
                    gstop = gstart

            if upstream > 0:
                if upstreamSpanTSS:
                    distance = upstream /2
                else:
                    distance = upstream

                if adjustToNeighbor:
                    nextGene = genome.rightGeneDistance((genomeName, gid), distance * 2)
                    if nextGene < distance * 2:
                        distance = nextGene / 2

                distance = max(distance, 1)
                gstop += distance

            if downstream > 0:
                distance = downstream
                if adjustToNeighbor:
                    nextGene = genome.leftGeneDistance((genomeName, gid), downstream * 2)
                    if nextGene < downstream * 2:
                        distance = nextGene / 2

                distance = max(distance, 1)
                gstart -= distance

            if 0 < lengthCDS < glen:
                gstart = newFeatureList[-1][-1] - lengthCDS

            if lengthCDS < 0 and abs(lengthCDS) < glen:
                gstop = newFeatureList[0][0] - lengthCDS

        if chrom not in locusByChromDict:
            locusByChromDict[chrom] = []

        if keepSense:
            locusByChromDict[chrom].append((gstart, gstop, gid, glen, sense))
        else:
            locusByChromDict[chrom].append((gstart, gstop, gid, glen))

    for chrom in locusByChromDict:
        locusByChromDict[chrom].sort()

    return locusByChromDict


def getGeneFeatures(genome, additionalRegionsDict):
    """ return genome.getallGeneFeatures(), supplemented with custom regions.

    additionalRegionsDict maps chromosome to objects with label/start/stop
    attributes. Each becomes a ("custom", chrom, start, stop, sense) feature of
    the gene named by its label. The sense is copied from the gene's first
    feature, or is "+" for a label the genome does not know. Every gene that
    received custom features has its feature list re-sorted by start. The dict
    returned by the genome is modified in place and returned.
    """
    featuresDict = genome.getallGeneFeatures()
    if len(additionalRegionsDict) == 0:
        return featuresDict

    touchedLabels = set()
    for chrom, regionList in additionalRegionsDict.items():
        for region in regionList:
            label = region.label
            touchedLabels.add(label)
            if label in featuresDict:
                sense = featuresDict[label][0][-1]
            else:
                featuresDict[label] = []
                sense = "+"

            featuresDict[label].append(("custom", chrom, region.start, region.stop, sense))

    for label in touchedLabels:
        featuresDict[label].sort(key=lambda feature: feature[2])

    return featuresDict


def computeRegionBins(regionsByChromDict, hitDict, bins, readlen, regionList=[],
                      normalizedTag=1., defaultRegionFormat=True, fixedFirstBin=-1,
                      binLength=-1):
    """ returns 2 dictionaries of bin counts and region lengths, given a dictionary of predefined regions,
        a dictionary of reads, a number of bins, the length of reads, and optionally a list of regions
        or a different weight / tag.
    """
    index = 0
    regionsBins = {}
    regionsLen = {}

    if defaultRegionFormat:
        regionIDField = 0
        startField = 1
        stopField = 2
        lengthField = 3
    else:
        startField = 0
        stopField = 1
        regionIDField = 2
        lengthField = 3

    senseField = 4

    print "entering computeRegionBins"
    if len(regionList) > 0:
        for readID in regionList:
            regionsBins[readID] = [0.] * bins
    else:
        for chrom in regionsByChromDict:
            for regionTuple in regionsByChromDict[chrom]:
                regionID = regionTuple[regionIDField]
                regionsBins[regionID] = [0.] * bins

    for chrom in hitDict:
        if chrom not in regionsByChromDict:
            continue

        for regionTuple in regionsByChromDict[chrom]:
            regionID = regionTuple[regionIDField]
            regionsLen[regionID] = regionTuple[lengthField]

        print "%s\n" % chrom
        startRegion = 0
        for read in hitDict[chrom]:
            tagStart = read["start"]
            weight = read["weight"]
            index += 1
            if index % 100000 == 0:
                print "read %d " % index,

            stopPoint = tagStart + readlen
            if startRegion < 0:
                startRegion = 0

            for regionTuple in regionsByChromDict[chrom][startRegion:]:
                start = regionTuple[startField]
                stop = regionTuple[stopField]
                regionID = regionTuple[regionIDField]
                rlen = regionTuple[lengthField]
                try:
                    rsense = regionTuple[senseField]
                except IndexError:
                    rsense = "F"

                if tagStart > stop:
                    startRegion += 1
                    continue

                if start > stopPoint:
                    startRegion -= 10
                    break

                if start <= tagStart <= stop:
                    if binLength < 1:
                        regionBinLength = rlen / bins
                    else:
                        regionBinLength = binLength

                    startdist = tagStart - start
                    if rsense == "F":
                        # we are relying on python's integer division quirk
                        binID = startdist / regionBinLength
                        if (fixedFirstBin > 0) and (startdist < fixedFirstBin):
                            binID = 0
                        elif fixedFirstBin > 0:
                            binID = 1

                        if binID >= bins:
                            binID = bins - 1

                        try:
                            regionsBins[regionID][binID] += normalizedTag * weight
                        except KeyError:
                            print "%s %s" % (regionID, str(binID))
                    else:
                        rdist = rlen - startdist
                        binID = rdist / regionBinLength
                        if (fixedFirstBin > 0) and (rdist < fixedFirstBin):
                            binID = 0
                        elif fixedFirstBin > 0:
                            binID = 1

                        if binID >= bins:
                            binID = bins - 1

                        try:
                            regionsBins[regionID][binID] += normalizedTag * weight
                        except KeyError:
                            print "%s %s" % (regionID, str(binID))

                    stopPoint = stop

    return (regionsBins, regionsLen)
