import sys, optparse
try:
    import psyco
    psyco.full()
except:
    print 'psyco not running'

from cistematic.core.motif import Motif, hasMotifExtension
from cistematic.core import complement
from cistematic.genomes import Genome
from commoncode import getMergedRegions, findPeak, getConfigParser, getConfigOption, getConfigBoolOption
import ReadDataset

# Announce the script version; this runs at module load time, including on import.
print "getallsites: version 2.5"


def main(argv=None):
    """Command-line entry point for getallsites.

    argv follows the sys.argv convention (argv[0] is the program name and is
    skipped); when argv is empty or None, sys.argv is used. Exactly five
    positional arguments are consumed -- genome, motif file, motif threshold,
    region file and output file -- and the process exits with status 1 after
    printing the usage line when fewer are supplied.
    """
    argv = argv or sys.argv

    usage = "usage: python %prog genome motifFile motThreshold regionfile siteOutfile [options]"
    (options, args) = getParser(usage).parse_args(argv[1:])

    if len(args) < 5:
        print(usage)
        sys.exit(1)

    (genome, motfilename, thresholdText, infilename, outfilename) = args[:5]

    getallsites(genome, motfilename, float(thresholdText), infilename, outfilename,
                chipfilename=options.chipfilename,
                doCache=options.doCache,
                bestOnly=options.bestOnly,
                usePeak=options.usePeak,
                printSeq=options.printSeq,
                doMarkov1=options.doMarkov1,
                useRank=options.useRank,
                noMerge=options.noMerge)


def getParser(usage):
    """Build the optparse parser for the getallsites command line.

    Defaults for every option are read from the "getallsites" section of the
    configuration returned by commoncode.getConfigParser(), so a config file
    can switch behaviors on without passing flags.

    Returns the configured optparse.OptionParser.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--dataset", dest="chipfilename",
                      help="read dataset used to locate each region's peak when --usepeak is not given")
    parser.add_option("--cache", action="store_true", dest="doCache",
                      help="open the --dataset read dataset with caching enabled")
    parser.add_option("--best", action="store_true", dest="bestOnly",
                      help="only report the best position for each region")
    parser.add_option("--usepeak", action="store_true", dest="usePeak",
                      help="use peak position and height from regions file")
    parser.add_option("--printseq", action="store_true", dest="printSeq",
                      help="print each site sequence and each region without a site to stdout")
    parser.add_option("--nomerge", action="store_true", dest="noMerge",
                      help="do not merge regions read from the region file")
    parser.add_option("--markov1", action="store_true", dest="doMarkov1",
                      help="search with the motif's Markov1 model (threshold 1.0-10.0) instead of the PSFM (threshold 55-99)")
    # NOTE(review): --rank is declared type="int" (it consumes a value) while its
    # config default below comes from getConfigBoolOption and getallsites only
    # tests its truthiness. It is kept as-is so existing "--rank N" invocations
    # keep parsing; confirm whether it should become action="store_true".
    parser.add_option("--rank", type="int", dest="useRank",
                      help="return region ranking based on peak height ranking [requires --usepeak]")

    configParser = getConfigParser()
    section = "getallsites"
    chipfilename = getConfigOption(configParser, section, "chipfilename", "")
    doCache = getConfigBoolOption(configParser, section, "doCache", False)
    bestOnly = getConfigBoolOption(configParser, section, "bestOnly", False)
    usePeak = getConfigBoolOption(configParser, section, "usePeak", False)
    printSeq = getConfigBoolOption(configParser, section, "printSeq", False)
    doMarkov1 = getConfigBoolOption(configParser, section, "doMarkov1", False)
    useRank = getConfigBoolOption(configParser, section, "useRank", False)
    noMerge = getConfigBoolOption(configParser, section, "noMerge", False)

    parser.set_defaults(chipfilename=chipfilename, doCache=doCache, bestOnly=bestOnly, usePeak=usePeak,
                        printSeq=printSeq, doMarkov1=doMarkov1, useRank=useRank, noMerge=noMerge)

    return parser


def getallsites(genome, motfilename, motThreshold, infilename, outfilename, chipfilename="",
                doCache=False, bestOnly=False, usePeak=False, printSeq=False, doMarkov1=False,
                useRank=False, noMerge=False):
    """Find every motif site in a set of regions and write them to outfilename.

    Parameters:
      genome       -- genome name passed to cistematic Genome().
      motfilename  -- motif file loaded with cistematic Motif().
      motThreshold -- match threshold; the process exits with status 1 if it is
                      below 1.0 with doMarkov1, or below 55.0 without it.
      infilename   -- region file read with commoncode.getMergedRegions().
      outfilename  -- output path (overwritten).
      chipfilename -- optional ReadDataset path; when given and usePeak is
                      False, each region's peak is located with findPeak().
      doCache      -- passed as cache= when opening the ReadDataset.
      bestOnly     -- write only the site with the smallest |peakdist| per region.
      usePeak      -- take peak height/position from the region file and process
                      regions in descending peak-height order.
      printSeq     -- print site sequences and regions without a site to stdout;
                      also suppresses the progress counter.
      doMarkov1    -- use Motif.locateMarkov1 instead of Motif.locateMotif.
      useRank      -- with usePeak, report the processing rank instead of the
                      peak height (ignored without usePeak).
      noMerge      -- pass doMerge=False to getMergedRegions().

    Regions on chromosomes whose name contains "rand" or "M" are skipped.
    Each output line is tab-separated:
      chr<chrom>:<siteStart>-<siteEnd>  score  numHits  peakdist
      chr<chrom>:<regionStart>-<regionEnd>  sense
    where score is the better-strand score as a percentage of the motif's best
    consensus score, and peakdist is -1 when no peak position is known.
    """
    if motThreshold < 1.0 and doMarkov1:
        print("motThreshold should be between 1.0 and 10.0 for markov1")
        sys.exit(1)
    elif motThreshold < 55.0 and not doMarkov1:
        print("motThreshold should be between 55 and 99 for a regular PSFM")
        sys.exit(1)

    if hasMotifExtension:
        print("will use cistematic.core.motif C-extension to speed up motif search")

    if useRank and usePeak:
        print("will return region ranking based on peak height ranking")
        useRank = True
    else:
        # Only warn when ranking was actually requested; previously this message
        # was printed on every run that did not use ranking.
        if useRank:
            print("ignoring '-rank': can only use ranking when using a region file with peak position and height")
        useRank = False

    mot = Motif("", motifFile=motfilename)
    motLen = len(mot)
    # Floor division: identical to Python 2 int "/" and stays an int on Python 3.
    halfMotLen = motLen // 2
    bestScore = mot.bestConsensusScore()

    hg = Genome(genome)

    # minHits=-1 will force regions to be used regardless
    # maxDist= 0 prevents merging of non-overlapping regions
    if noMerge:
        regions = getMergedRegions(infilename, maxDist=0, minHits=-1, verbose=True, doMerge=False, keepPeak=usePeak)
    else:
        regions = getMergedRegions(infilename, maxDist=0, minHits=-1, verbose=True, keepPeak=usePeak)

    doRDS = bool(chipfilename)
    if doRDS:
        hitRDS = ReadDataset.ReadDataset(chipfilename, verbose=True, cache=doCache)

    outfile = open(outfilename, "w")
    notFoundIndex = 0
    try:
        regionList = []
        for chrom in regions:
            if "rand" in chrom or "M" in chrom:
                continue

            for region in regions[chrom]:
                if usePeak:
                    regionList.append((region.peakHeight, chrom, region.start, region.length, region.peakPos))
                else:
                    regionList.append((chrom, region.start, region.length))

        if usePeak:
            # Highest peak first (tuples are compared on peakHeight first).
            regionList.sort(reverse=True)

        currentChrom = ""
        count = 0
        for regionTuple in regionList:
            if usePeak:
                (rpeakheight, rchrom, start, length, rpeakpos) = regionTuple
            else:
                (rchrom, start, length) = regionTuple

            try:
                seq = hg.sequence(rchrom, start, length)
            except Exception:
                print("couldn't retrieve %s %d %d - skipping" % (rchrom, start, length))
                continue

            count += 1
            numHits = -1
            if usePeak:
                peakpos = rpeakpos
                if useRank:
                    numHits = count
                else:
                    numHits = rpeakheight
            elif doRDS:
                if rchrom != currentChrom:
                    # Reads are requested with a "chr" prefix but the returned
                    # dict is indexed by the bare chromosome name below.
                    hitDict = hitRDS.getReadsDict(chrom="chr" + rchrom)
                    currentChrom = rchrom

                (topPos, numHits, _smoothArray, _numPlus) = findPeak(hitDict[rchrom], start, length)
                if len(topPos) == 0:
                    # Previously this fell through to topPos[0] and crashed.
                    print("topPos error")
                    continue

                peakpos = topPos[0]

            if doMarkov1:
                matches = mot.locateMarkov1(seq, motThreshold)
            else:
                matches = mot.locateMotif(seq, motThreshold)

            # Keep the first hit at each genomic position; the set makes the
            # duplicate check O(1) instead of rescanning the list every time.
            found = []
            seenPositions = set()
            for (pos, _) in matches:
                sitePos = start + pos
                if sitePos in seenPositions:
                    continue

                seenPositions.add(sitePos)
                if usePeak:
                    found.append((sitePos, sitePos + halfMotLen - peakpos))
                elif doRDS:
                    # NOTE(review): assumes findPeak positions are relative to
                    # the region start, matching the region-relative pos.
                    found.append((sitePos, pos + halfMotLen - peakpos))
                else:
                    found.append((sitePos, -1))

            bestList = []
            for (foundpos, peakdist) in found:
                siteSeq = hg.sequence(rchrom, foundpos, motLen)
                (front, back) = mot.scoreMotif(siteSeq)
                if front >= back:
                    score = int(100 * front / bestScore)
                    sense = "+"
                else:
                    score = int(100 * back / bestScore)
                    sense = "-"
                    siteSeq = complement(siteSeq)

                if printSeq:
                    print(siteSeq)

                outline = "chr%s:%d-%d\t%d\t%d\t%d\tchr%s:%d-%d\t%s\n" % (rchrom, foundpos, foundpos + motLen - 1, score, numHits, peakdist, rchrom, start, start + length, sense)
                if bestOnly:
                    bestList.append((abs(peakdist), outline))
                else:
                    outfile.write(outline)

            if not found:
                if printSeq:
                    print("could not find a %s site for %s:%d-%d" % (mot.tagID, rchrom, start, start + length))

                notFoundIndex += 1
            elif bestOnly:
                # Closest site to the peak; ties broken by the output line text.
                outfile.write(min(bestList)[1])

            if (count % 10000) == 0 and not printSeq:
                print(count)
    finally:
        outfile.close()

    print("did not find motif in %d regions" % notFoundIndex)


if __name__ == "__main__":
    # Script entry: main() falls back to sys.argv when called without arguments.
    main()
