#
#  getfasta.py
#  ENRAGE
#

# psyco is an optional accelerator: the script works without it, so a missing
# or non-functional install is ignored. Catch Exception rather than using a
# bare except so that Ctrl-C / SystemExit are not silently swallowed here.
try:
    import psyco
    psyco.full()
except Exception:
    pass

import sys
import optparse
from commoncode import getMergedRegions, findPeak, getConfigParser, getConfigOption, getConfigIntOption, getConfigBoolOption
import ReadDataset
from cistematic.genomes import Genome

# Version banner: a module-level statement, so it is printed both when the
# script is run and when the module is imported.
print "getfasta: version 3.5"


def main(argv=None):
    """Command-line entry point: parse the arguments and run getfasta().

    argv is the full argument vector with the program name at index 0; when
    it is None or empty, sys.argv is used instead. Three positional arguments
    are required (genome, regionfile, outfilename); with fewer, the usage
    line is printed and the process exits with status 1.
    """
    if not argv:
        argv = sys.argv

    # Fill in the script name here: optparse only expands "%prog", so a bare
    # "%s" would otherwise be shown to the user literally.
    usage = "usage: python %s genome regionfile outfilename [options]" % argv[0]

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        print(usage)
        sys.exit(1)

    genome = args[0]
    regionfile = args[1]
    outfilename = args[2]

    # The --dataset option is registered with dest="hitfile" (lower-case "f"),
    # so that is the attribute to read; "options.hitFile" does not exist and
    # raised AttributeError.
    getfasta(genome, regionfile, outfilename, options.seqsize, options.minHitThresh,
             options.topRegions, options.maxsize, options.usePeaks, options.hitfile,
             options.doCache, options.doCompact)


def getParser(usage):
    """Return an optparse.OptionParser configured for getfasta.

    The command-line flags are registered first; their defaults are then read
    from the "getfasta" section of the parser returned by getConfigParser(),
    each with the literal fallback passed to the getConfig*Option helper.
    """
    # (flag, add_option keyword arguments), in the order they are registered.
    optionTable = (
        ("--seqradius", {"type": "int", "dest": "seqsize"}),
        ("--minreads", {"type": "int", "dest": "minHitThresh"}),
        ("--returnTop", {"type": "int", "dest": "topRegions"}),
        ("--maxsize", {"type": "int", "dest": "maxsize"}),
        ("--usepeak", {"action": "store_true", "dest": "usePeaks"}),
        ("--dataset", {"dest": "hitfile"}),
        ("--cache", {"action": "store_true", "dest": "doCache"}),
        ("--compact", {"action": "store_true", "dest": "doCompact"}),
    )

    parser = optparse.OptionParser(usage=usage)
    for flag, settings in optionTable:
        parser.add_option(flag, **settings)

    config = getConfigParser()
    sectionName = "getfasta"

    # Keys are the option "dest" names; note the config key for the dataset
    # is spelled "hitFile" while the dest is "hitfile".
    defaults = {}
    defaults["seqsize"] = getConfigIntOption(config, sectionName, "seqsize", 50)
    defaults["minHitThresh"] = getConfigIntOption(config, sectionName, "minHitThresh", -1)
    defaults["topRegions"] = getConfigIntOption(config, sectionName, "topRegions", 0)
    defaults["maxsize"] = getConfigIntOption(config, sectionName, "maxsize", 300000000)
    defaults["usePeaks"] = getConfigBoolOption(config, sectionName, "usePeaks", False)
    defaults["hitfile"] = getConfigOption(config, sectionName, "hitFile", None)
    defaults["doCache"] = getConfigBoolOption(config, sectionName, "doCache", False)
    defaults["doCompact"] = getConfigBoolOption(config, sectionName, "doCompact", False)

    parser.set_defaults(**defaults)

    return parser

def getfasta(genome, regionfile, outfilename, seqsize=50, minHitThresh=-1, topRegions=0,
             maxsize=300000000, usePeaks=False, hitfile=None, doCache=False, doCompact=False):
    """Merge the regions in regionfile and write their sequences as FASTA.

    Regions are loaded with getMergedRegions() and then narrowed by one of
    three strategies, in priority order:
      * usePeaks: use the peak kept on each merged region;
      * hitfile given: open it as a ReadDataset and locate peaks from reads;
      * otherwise: use each region as-is.
    When both usePeaks and hitfile are supplied the dataset is ignored and a
    notice is printed. The result is passed to writeFastaFile().
    """
    haveDataset = hitfile is not None
    if haveDataset and usePeaks:
        print("ignoring dataset and relying on peak data")

    mergeOptions = {
        "minHits": minHitThresh,
        "verbose": True,
        "keepPeak": usePeaks,
        "returnTop": topRegions,
    }
    if doCompact:
        # Compact region files are read with the chromosome in column 0.
        mergeOptions["chromField"] = 0
        mergeOptions["compact"] = True

    mergedRegions = getMergedRegions(regionfile, **mergeOptions)

    if usePeaks:
        ncregions = getRegionUsingPeaks(mergedRegions, minHitThresh, maxsize)
    elif haveDataset:
        hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
        ncregions = getRegionUsingRDS(mergedRegions, hitRDS, minHitThresh, maxsize)
    else:
        ncregions = getDefaultRegion(mergedRegions, maxsize)

    writeFastaFile(ncregions, genome, outfilename, seqsize)


def writeFastaFile(ncregions, genome, outfilename, seqsize=50):
    """Write one FASTA record per region to outfilename.

    ncregions maps a chromosome name to a list of dicts with the keys
    "start", "length" and "topPos". When topPos[0] is non-negative it is
    taken as an offset from "start" and a window of 2 * seqsize + 1 bases
    centred on it is extracted; otherwise the whole region is extracted.
    Each record header has the form ">chr<chrom>:<start>-<start + length>".

    genome is passed to Genome() to obtain the sequence source.
    """
    hg = Genome(genome)
    outfile = open(outfilename, "w")
    # try/finally so the output handle is released even if a sequence lookup
    # or a malformed region entry raises part-way through.
    try:
        for chrom in ncregions:
            for regionDict in ncregions[chrom]:
                rstart = regionDict["start"]
                rlen = regionDict["length"]
                topPos = regionDict["topPos"]
                if topPos[0] >= 0:
                    # NOTE(review): newrstart can go negative when the peak
                    # lies within seqsize of coordinate 0 - confirm how
                    # Genome.sequence() handles that.
                    newrstart = rstart + topPos[0] - seqsize
                    newrlen = 2 * seqsize + 1
                else:
                    newrstart = rstart
                    newrlen = rlen

                seq2 = hg.sequence(chrom, newrstart, newrlen)
                outfile.write(">chr%s:%d-%d\n%s\n" % (chrom, newrstart, newrstart + newrlen, seq2))
    finally:
        outfile.close()


def getDefaultRegion(regionDict, maxsize):
    """Convert merged regions to plain region dicts without peak information.

    regionDict maps a chromosome name to a sequence of objects exposing
    start, stop and length. Regions longer than maxsize are reported and
    skipped. Returns a dict with the same chromosome keys, each mapping to a
    list of {"start", "length", "topPos"} dicts where topPos is [-1]
    (meaning "no peak").
    """
    kept = dict((name, []) for name in regionDict)

    for name in regionDict:
        candidates = regionDict[name]
        print("%s: processing %d regions" % (name, len(candidates)))
        bucket = kept[name]
        for candidate in candidates:
            begin = candidate.start
            span = candidate.length

            if span > maxsize:
                print("%s:%d-%d length %d > %d max region size - skipping" % (name, begin, candidate.stop, span, maxsize))
                continue

            bucket.append({"start": begin, "length": span, "topPos": [-1]})

    return kept


def getRegionUsingPeaks(regionDict, minHitThresh=-1, maxsize=300000000):
    """Convert merged regions to region dicts using each region's own peak.

    regionDict maps a chromosome name to a sequence of objects exposing
    start, stop, length, peakPos and peakHeight. Regions longer than maxsize
    are reported and skipped; the rest are kept only when peakHeight is
    strictly greater than minHitThresh. Returns a dict with the same
    chromosome keys, each mapping to a list of {"start", "length", "topPos"}
    dicts where topPos is a one-element list holding peakPos - start.
    """
    kept = dict((name, []) for name in regionDict)

    for name in regionDict:
        candidates = regionDict[name]
        print("%s: processing %d regions" % (name, len(candidates)))
        bucket = kept[name]
        for candidate in candidates:
            begin = candidate.start
            span = candidate.length

            if span > maxsize:
                print("%s:%d-%d length %d > %d max region size - skipping" % (name, begin, candidate.stop, span, maxsize))
                continue

            # Peak position expressed relative to the region start.
            peakOffset = candidate.peakPos - begin
            if not candidate.peakHeight > minHitThresh:
                continue

            bucket.append({"start": begin, "length": span, "topPos": [peakOffset]})

    return kept


def getRegionUsingRDS(regionDict, hitRDS, minHitThresh=-1, maxsize=300000000):
    """Convert merged regions to region dicts, locating peaks from a dataset.

    regionDict maps a chromosome name (without the "chr" prefix) to a
    sequence of objects exposing start, stop and length. hitRDS must provide
    getReadSize() and getReadsDict(). For every region no longer than
    maxsize, the reads between start and stop are fetched and handed to
    findPeak(); the region is kept only when the peak's numHits is strictly
    greater than minHitThresh.

    Returns a dict with the same chromosome keys, each mapping to a list of
    {"start", "length", "topPos"} dicts where topPos is the peak's topPos.
    """
    readlen = hitRDS.getReadSize()

    ncregions = {}
    for chromosome in regionDict:
        ncregions[chromosome] = []

    for chromosome in regionDict:
        print("%s: processing %d regions" % (chromosome, len(regionDict[chromosome])))
        # The dataset is queried with "chr"-prefixed chromosome names.
        thechrom = "chr%s" % chromosome
        for region in regionDict[chromosome]:
            start = region.start
            stop = region.stop
            length = region.length

            if length > maxsize:
                print("%s:%d-%d length %d > %d max region size - skipping" % (chromosome, start, stop, length, maxsize))
                continue

            print(".")
            hitDict = hitRDS.getReadsDict(chrom=thechrom, withWeight=True, doMulti=True, findallOptimize=True, start=start, stop=stop)
            # NOTE(review): assumes getReadsDict() always returns an entry for
            # thechrom, even when no reads fall in the window - confirm.
            chromHits = hitDict[thechrom]
            # Use "%" to format the count; the previous "," printed the
            # literal "%d" followed by the number.
            print("hitDict length: %d" % len(chromHits))
            peak = findPeak(chromHits, start, length, readlen)
            if peak.numHits > minHitThresh:
                resultDict = {"start": start,
                              "length": length,
                              "topPos": peak.topPos
                }
                ncregions[chromosome].append(resultDict)

    return ncregions


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main(sys.argv)