# Standard-library modules used for argument handling.
import sys, optparse

# psyco is optional: if importing or enabling it fails for any reason the
# script only reports that fact and carries on without it.
try:
    import psyco
    psyco.full()
except:
    print 'psyco not running'

# Project / third-party modules: motif scoring, genome sequence access,
# region and peak helpers, configuration helpers, read datasets, plotting.
from cistematic.core import complement
from cistematic.core.motif import Motif
from cistematic.genomes import Genome
from commoncode import getMergedRegions, findPeak, getConfigParser, getConfigOption, getConfigBoolOption, getConfigFloatOption
import ReadDataset
# NOTE(review): pylab is imported here, before getallNRSE() calls
# matplotlib.use("Agg"); matplotlib expects use() to run before pylab/pyplot
# is imported -- confirm the backend switch actually takes effect.
from pylab import *
import matplotlib

# Version banner, printed whenever this module is loaded.
print 'getallNRSE: version 3.5'

def main(argv=None):
    if not argv:
        argv = sys.argv

    usage = "usage: python %s genome regionfile siteOutfile [options]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        print usage
        sys.exit(1)

    genome = argv[0]
    infilename = args[1]
    outfilename = args[2]

    getallNRSE(genome, infilename, outfilename, options.chipfilename,
               options.minHeight, options.minFraction, options.plotname,
               options.doCache, options.normalize, options.doVerbose,
               options.doMarkov1, options.maxpeakdist, options.fullOnly,
               options.motifDir)


def getParser(usage):
    """Return the optparse parser for the getallNRSE command line.

    usage -- usage string handed to optparse.OptionParser.

    Each option's default is looked up in the "getallNRSE" section of the
    object returned by getConfigParser().  The last argument of every
    getConfig*Option call is presumably the value used when the
    configuration has no entry -- confirm against commoncode.
    """
    parser = optparse.OptionParser(usage=usage)

    # (flag, add_option keyword arguments), in registration order
    optionTable = [
        ("--dataset", {"dest": "chipfilename"}),
        ("--min", {"type": "float", "dest": "minHeight"}),
        ("--minfraction", {"type": "float", "dest": "minFraction"}),
        ("--plot", {"dest": "plotname"}),
        ("--cache", {"action": "store_true", "dest": "doCache"}),
        ("--raw", {"action": "store_false", "dest": "normalize"}),
        ("--verbose", {"action": "store_true", "dest": "doVerbose"}),
        ("--markov1", {"action": "store_true", "dest": "doMarkov1"}),
        ("--peakdist", {"type": "int", "dest": "maxpeakdist"}),
        ("--fullOnly", {"action": "store_true", "dest": "fullOnly"}),
        ("--motifdir", {"dest": "motifDir"}),
    ]
    for (flag, settings) in optionTable:
        parser.add_option(flag, **settings)

    config = getConfigParser()
    sectionName = "getallNRSE"

    # configuration lookups are made in the same order as the options above
    defaults = {}
    defaults["chipfilename"] = getConfigOption(config, sectionName, "chipfilename", "")
    defaults["minHeight"] = getConfigFloatOption(config, sectionName, "minHeight", -2.)
    defaults["minFraction"] = getConfigFloatOption(config, sectionName, "minFraction", -2.)
    defaults["plotname"] = getConfigOption(config, sectionName, "plotname", "")
    defaults["doCache"] = getConfigBoolOption(config, sectionName, "doCache", False)
    defaults["normalize"] = getConfigBoolOption(config, sectionName, "normalize", True)
    defaults["doVerbose"] = getConfigBoolOption(config, sectionName, "doVerbose", False)
    defaults["doMarkov1"] = getConfigBoolOption(config, sectionName, "doMarkov1", False)
    defaults["maxpeakdist"] = getConfigOption(config, sectionName, "maxpeakdist", None)
    defaults["fullOnly"] = getConfigBoolOption(config, sectionName, "fullOnly", False)
    defaults["motifDir"] = getConfigOption(config, sectionName, "motifDir", "./")

    parser.set_defaults(**defaults)

    return parser


def getallNRSE(genome, infilename, outfilename, chipfilename="", minHeight=-2.,
               minFraction=-2., plotname="", doCache=False, normalize=True,
               doVerbose=False, doMarkov1=False, maxpeakdist=None, fullOnly=False,
               motifDir="./"):

    doPlot = False
    if plotname:
        matplotlib.use("Agg")
        doPlot = True

    if motifDir[-1] != "/":
        motifDir += "/"

    doDataset = False
    normalizeBy = 1
    if chipfilename:
        hitRDS = ReadDataset.ReadDataset(chipfilename, verbose=doVerbose, cache=doCache)
        doDataset = True
        if normalize:
            normalizeBy = len(hitRDS) / 1000000.

    if minFraction > 1.:
        minFraction /= 100.
        print "scaling minFraction to %.2f" % minFraction

    if maxpeakdist is not None:
        enforcePeakDist = True
    else:
        enforcePeakDist = False
        maxpeakdist = 101

    mot = Motif("", motifFile="%sNRSE3.mot" % motifDir)
    motL = Motif("", motifFile="%sNRSE3left.mot" % motifDir)
    motR = Motif("", motifFile="%sNRSE3right.mot" % motifDir)
    bestScore = mot.bestConsensusScore()
    bestLeft = motL.bestConsensusScore()
    bestRight = motR.bestConsensusScore()

    hg = Genome(genome)

    regions = getMergedRegions(infilename, maxDist=0, minHits=-1, verbose=doVerbose, doMerge=False)

    outfile = open(outfilename, "w")
    outfile.write("#dataset: %s\tregions:%s\tnormalize: %s\tmarkov1: %s\n" % (chipfilename, infilename, normalize, doMarkov1))
    outfile.write("#enforcePeakDist: %s\tpeakdist: %d bp\tfullOnly: %d bp\n" % (enforcePeakDist, maxpeakdist, fullOnly))
    outfile.write("#site\tscore\tleftscore\trightscore\tRPM\tpeakDist\ttype\theight\tfractionHeight\tregion\tsense\tseq\n")

    index = 0
    regionList = []

    for rchrom in regions:
        if "rand" in rchrom or "M" in rchrom or "hap" in rchrom:
            continue

        for region in regions[rchrom]:
            regionList.append((rchrom, region.start, region.length))

    notFoundIndex = 0
    currentChrom = ""
    for (rchrom, start, length) in regionList:
        seq = hg.sequence(rchrom, start, length)
        if doDataset:
            if rchrom != currentChrom:
                fullchrom = "chr" + rchrom
                hitDict = hitRDS.getReadsDict(chrom=fullchrom, withWeight=True, doMulti=True)
                currentChrom = rchrom

            peak = findPeak(hitDict[rchrom], start, length, doWeight=True)
            topPos = peak.topPos
            numHits = peak.numHits
            if len(topPos) == 0:
                print "topPos error"

            peakpos = topPos[0]
            peakscore = peak.smoothArray[peakpos]
            if peakscore == 0.:
                peakscore = -1.

            if normalize:
                numHits /= normalizeBy
                peakscore /= normalizeBy
        else:
            peakpos = length
            peakscore = -1
            numHits = 0
            smoothArray = [0.] * length

        found = []
        if doMarkov1:
            lefts = motL.locateMarkov1(seq, 3.)
            rights = motR.locateMarkov1(seq, 3.)
        else:
            lefts = motL.locateMotif(seq, 70)
            rights = motR.locateMotif(seq, 70)

        allhalfs = [(v0, v1, "L") for (v0, v1) in lefts] + [(v0, v1, "R") for (v0, v1) in rights]
        allhalfs.sort()

        # look for canonicals and non-canonicals
        if len(allhalfs) > 1:
            (firstpos, firstsense, firsttype) = allhalfs[0]
            for (secondpos, secondsense, secondtype) in allhalfs[1:]:
                if enforcePeakDist:
                    withinDistance = False
                    for aPos in topPos:
                        if abs(firstpos - aPos) < maxpeakdist or abs(secondpos - aPos) < maxpeakdist:
                            withinDistance = True
                    if not withinDistance:
                        firstpos = secondpos
                        firstsense = secondsense
                        firsttype = secondtype
                        continue

                if firsttype == "L":
                    dist = secondpos - firstpos + 2
                else:
                    dist = secondpos - firstpos -1

                if firstsense == secondsense and dist in [9, 10, 11, 16, 17, 18, 19]:
                    if (firsttype == "L" and secondtype == "R" and secondsense == "F"):
                        found.append((start + firstpos, firstpos - peakpos + (dist + 10)/2, dist))

                    if (firsttype == "R" and secondtype == "L" and secondsense == "R"):
                        found.append((start + firstpos, firstpos  - peakpos + (dist + 10)/2, dist))

                firstpos = secondpos
                firstsense = secondsense
                firsttype = secondtype

        # did we miss any 70%+ matches ?
        if doMarkov1:
            matches = mot.locateMarkov1(seq, 3.5)
        else:
            matches = mot.locateMotif(seq, 70)

        for (pos, sense) in matches:
            alreadyFound = False
            for (fpos, fpeakdist, fdist) in found:
                if pos + start == fpos:
                    alreadyFound = True

            if not alreadyFound:
                if enforcePeakDist:
                    withinDistance = False
                    for aPos in topPos:
                        if abs(firstpos - aPos) < maxpeakdist or abs(secondpos - aPos) < maxpeakdist:
                            withinDistance = True
                            thePos = aPos

                    if withinDistance:
                        found.append((start + pos, pos - thePos + 10, 11))

                else:
                    found.append((start + pos, pos - peakpos + 10, 11))

        # we'll now accept half-sites within maxpeakdist bp of peak if using a dataset, else all
        if len(found) == 0 and not fullOnly:
            bestone = -1
            if not doDataset:
                bestdist = maxpeakdist
            else:
                bestdist = length

            index = 0
            for (pos, sense, type) in allhalfs:
                if doDataset:
                    for aPos in topPos:
                        if abs(pos - aPos) < bestdist:
                            bestdist = abs(pos - aPos)
                            bestone = index
                            peakpos = aPos
                else:
                    found.append((start + allhalfs[index][0], allhalfs[index][0] + 5 - peakpos, 0))

                index += 1

            if (doDataset and bestdist < 101):
                try:
                    found.append((start + allhalfs[bestone][0], allhalfs[bestone][0] + 5 - peakpos, 0))
                except:
                    continue

        # see if we found an acceptable match
        foundValue = False
        for (foundpos, posdist, dist) in found:
            # get a score for 21-mer, report
            seq = hg.sequence(rchrom, foundpos, 21)
            # height will be measured from the center of the motif
            height = -2.
            for pos in range(10 + dist):
                try:
                    currentHeight = smoothArray[int(peakpos + posdist + pos)]
                except: 
                    pass

                if currentHeight > height:
                    height = currentHeight

            if normalize:
                height /= normalizeBy

            fractionHeight = height / peakscore
            if height < minHeight or fractionHeight < minFraction:
                continue

            foundValue = True
            (front, back) = mot.scoreMotif(seq)
            sense = "+"
            if front > back:
                score = int(100 * front / bestScore)
                theseq = hg.sequence(rchrom, foundpos, 10 + dist)
            else:
                score = int(100 * back / bestScore)
                theseq = complement(hg.sequence(rchrom, foundpos, 10 + dist))
                sense = "-"
                foundpos + 1

            leftScore = -1.
            rightScore = -1.
            leftseq = ""
            rightseq = ""
            if dist > 0:
                testseq = hg.sequence(rchrom, foundpos, 10 + dist)
                if sense == "-":
                    testseq = complement(testseq)

                leftseq = testseq[:9]
                rightseq = testseq[dist-2:]
            elif dist == 0:
                testseq = hg.sequence(rchrom, foundpos, 12)
                if sense == "-":
                    testseq = complement(testseq)
                    leftseq = testseq[3:]
                else:
                    leftseq = testseq[:9]

                rightseq = testseq

            (lfront, lback) = motL.scoreMotif(leftseq)
            (rfront, rback) = motR.scoreMotif(rightseq)
            if lfront > lback:
                leftScore = int(100 * lfront) / bestLeft
                leftSense = "+"
            else:
                leftScore = int(100 * lback) / bestLeft
                leftSense = "-"

            if rfront > rback:
                rightScore = int(100 * rfront) / bestRight
                rightSense = "+"
            else:
                rightScore = int(100 * rback) / bestRight
                rightSense = "-"

            if dist != 11:
                if rightScore > leftScore:
                    sense = rightSense
                else:
                    sense = leftSense

                if sense == "-" and dist > 0:
                    theseq = complement(hg.sequence(rchrom, foundpos, 10 + dist))

            outline = "chr%s:%d-%d\t%d\t%d\t%d\t%d\t%d\t%d\t%.2f\t%.2f\tchr%s:%d-%d\t%s\t%s" % (rchrom, foundpos, foundpos + 9 + dist, score, leftScore, rightScore, numHits, posdist, dist, height, fractionHeight, rchrom, start, start + length, sense, theseq)
            if doVerbose:
                print outline

            outfile.write(outline + "/n")

        # we didn't find a site - draw region
        if not foundValue and doVerbose:
            outline = "#no predictions for %s:%d-%d %d %.2f" % (rchrom, start, start + length, numHits, peakscore)
            print outline
            outfile.write(outline + "\n")

        if not foundValue and doPlot:
            drawarray = [val + notFoundIndex for val in smoothArray]
            drawpos = [drawarray[val] for val in topPos]
            plot(drawarray, "b")
            plot(topPos, drawpos, "r.")
            goodmatches = mot.locateMotif(seq, 75)
            if len(goodmatches) > 0:
                print topPos
                print goodmatches
                drawgood = []
                drawgoody = []
                for (mstart, sense) in goodmatches:
                    drawgood.append(mstart)
                    drawgoody.append(drawarray[mstart])

                plot(drawgood, drawgoody, "g.")

            notFoundIndex -= 30

    outfile.close()
    if doPlot:
        savefig(plotname)


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main(sys.argv)