#
#  makewiggle.py
#  ENRAGE
#
import sys
import optparse
import ReadDataset
from commoncode import getConfigParser, getConfigOption, getConfigBoolOption, getConfigIntOption, getConfigFloatOption

print("makewiggle: version 6.8")

# psyco is optional: any failure importing or enabling it is reported and
# the script continues without it.
try:
    import psyco
    psyco.full()
except Exception:
    # Narrower than a bare except so KeyboardInterrupt/SystemExit still propagate.
    print('psyco not running')


def main(argv=None):
    """Command-line entry point: parse argv and run makewiggle.

    argv defaults to sys.argv; argv[0] is skipped when parsing. Prints the
    usage line and exits with status 1 when fewer than the three positional
    arguments (name, rdsfile, outfilename) are supplied.
    """
    if not argv:
        argv = sys.argv

    # optparse expands %prog to the script name; the previous "%s" was never
    # substituted and appeared literally in --help and in the error path.
    usage = "usage: python %prog name rdsfile outfilename [options]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        parser.print_usage()
        sys.exit(1)

    name, hitfilename, outfilename = args[:3]

    # Keyword arguments keep this call correct even if makewiggle's
    # parameter order changes.
    makewiggle(name, hitfilename, outfilename,
               doNormalize=options.doNormalize, color=options.color, altColor=options.altColor,
               limitChrom=options.limitChrom, shift=options.shift, doSplit=options.doSplit,
               listfilename=options.listfilename, listPrefix=options.listPrefix,
               group=options.group, startPriority=options.startPriority,
               skipRandom=options.skipRandom, withMulti=options.withMulti,
               withSplices=options.withSplices, doSingle=options.doSingle,
               cachePages=options.cachePages, enforceChr=options.enforceChr,
               strand=options.strand, chunk=options.chunk)


def getParser(usage):
    """Return the OptionParser used by makewiggle's command line.

    Options are registered in the order of the table below. Each option's
    default is read from the "makewiggle" section of the configuration
    returned by getConfigParser(). The option's dest name is the config key,
    and the listed value is the fallback when the key is absent.
    """
    parser = optparse.OptionParser(usage=usage)

    # (flag, add_option keyword arguments), in registration order.
    optionTable = (
        ("--raw", dict(action="store_false", dest="doNormalize")),
        ("--color", dict(dest="color")),
        ("--altcolor", dict(dest="altColor")),
        ("--chrom", dict(dest="limitChrom")),
        ("--shift", dict(type="int", dest="shift")),
        ("--split", dict(action="store_true", dest="doSplit")),
        ("--listfile", dict(dest="listfilename")),
        ("--listprefix", dict(dest="listPrefix")),
        ("--group", dict(dest="group")),
        ("--startPriority", dict(type="float", dest="startPriority")),
        ("--skiprandom", dict(action="store_true", dest="skipRandom")),
        ("--nomulti", dict(action="store_false", dest="withMulti")),
        ("--splices", dict(action="store_true", dest="withSplices")),
        ("--singlebase", dict(action="store_true", dest="doSingle")),
        ("--cache", dict(type="int", dest="cachePages")),
        ("--enforceChr", dict(action="store_true", dest="enforceChr")),
        ("--stranded", dict(dest="strand")),
        ("--maxchunk", dict(type="int", dest="chunk")),
    )
    for flag, settings in optionTable:
        parser.add_option(flag, **settings)

    configParser = getConfigParser()
    section = "makewiggle"

    # (config key == option dest, typed reader, fallback), read in this order.
    defaultTable = (
        ("doNormalize", getConfigBoolOption, True),
        ("color", getConfigOption, None),
        ("altColor", getConfigOption, ""),
        ("limitChrom", getConfigOption, None),
        ("shift", getConfigIntOption, 0),
        ("doSplit", getConfigBoolOption, False),
        ("listfilename", getConfigOption, None),
        ("listPrefix", getConfigOption, ""),
        ("group", getConfigOption, ""),
        ("startPriority", getConfigFloatOption, 0.01),
        ("skipRandom", getConfigBoolOption, False),
        ("withMulti", getConfigBoolOption, True),
        ("withSplices", getConfigBoolOption, False),
        ("doSingle", getConfigBoolOption, False),
        ("cachePages", getConfigIntOption, -1),
        ("enforceChr", getConfigBoolOption, False),
        ("strand", getConfigOption, None),
        ("chunk", getConfigIntOption, 20),
    )
    configDefaults = dict((key, reader(configParser, section, key, fallback))
                          for key, reader, fallback in defaultTable)
    parser.set_defaults(**configDefaults)

    return parser


def _keepChromosome(achrom, limitChrom, skipRandom, enforceChr):
    """Return True if achrom passes the --enforceChr, --chrom and --skiprandom filters."""
    if enforceChr and "chr" not in achrom:
        return False
    if limitChrom is not None and achrom != limitChrom:
        return False
    if skipRandom and "random" in achrom:
        return False
    return True


def _writeChromosomeProfile(hitRDS, outfile, achrom, lastNT, maxSpan, withMulti, withSplices,
                            normalizeBy, isStranded, strandedDirection, shift, doSingle):
    """Write one chromosome's coverage to outfile and return the number of data lines written.

    The profile is requested from hitRDS.getChromProfile one window of at
    most maxSpan bases at a time. Each window is released before the next is
    fetched. Values are placed at spanStart + offset.

    With doSingle, every position is written as "chrom pos value". Otherwise,
    runs of equal non-zero value are written as "chrom start end value" and
    zero-valued runs are omitted.
    """
    linesWritten = 0
    previousVal = 0
    previousStart = 0
    endPos = 0
    spanStart = 0
    for spanStop in range(maxSpan, lastNT + maxSpan, maxSpan):
        spanStop = min(spanStop, lastNT)
        print("%s %s %s" % (achrom, spanStart, spanStop))
        chromModel = hitRDS.getChromProfile(achrom, spanStart, spanStop, withMulti, withSplices,
                                            normalizeBy, isStranded, strandedDirection,
                                            shiftValue=shift)

        for offset, currentVal in enumerate(chromModel):
            position = spanStart + offset
            if doSingle:
                outfile.write("%s %d %.4f\n" % (achrom, position, currentVal))
                linesWritten += 1
            elif currentVal != previousVal:
                # The previous run ends here; only non-zero runs are emitted.
                if previousVal != 0:
                    outfile.write("%s %d %d %.4f\n" % (achrom, previousStart, position, previousVal))
                    linesWritten += 1
                previousVal = currentVal
                previousStart = position

        endPos = spanStart + len(chromModel)
        del chromModel
        spanStart = spanStop + 1

    # Flush a non-zero run still open at the end of the chromosome. Previously
    # it was silently dropped. It ends just past the last position examined.
    if not doSingle and previousVal != 0:
        outfile.write("%s %d %d %.4f\n" % (achrom, previousStart, endPos, previousVal))
        linesWritten += 1

    return linesWritten


def makewiggle(name, hitfilename, outfilename, doNormalize=True, color=None, altColor="",
               limitChrom=None, shift=0, doSplit=False, listfilename=None, listPrefix="",
               group="", startPriority=0.01, skipRandom=False, withMulti=True, withSplices=False,
               doSingle=False, cachePages=-1, enforceChr=False, strand=None, chunk=20):
    """Write read coverage from an RDS file as a bedGraph track.

    name          -- track name; "shift=<n>" is appended when shift is non-zero.
    hitfilename   -- RDS file opened with ReadDataset.ReadDataset.
    outfilename   -- output file; with doSplit it is a prefix and each
                     chromosome goes to "<outfilename>.<chrom>".
    doNormalize   -- if true, pass len(hitRDS) / 1e6 to getChromProfile as the
                     normalization value; otherwise pass 1.
    color, altColor -- optional track-line color / altcolor attributes.
    limitChrom    -- only process this chromosome.
    shift         -- passed to getChromProfile as shiftValue.
    doSplit       -- one track file per chromosome; priority increases by
                     0.01 per file starting at startPriority.
    listfilename  -- if set, write the name of every output file here,
                     each prefixed with listPrefix.
    group         -- optional track-line group attribute.
    skipRandom    -- skip chromosomes whose name contains "random".
    withMulti, withSplices -- passed to getMaxCoordinate / getChromProfile.
    doSingle      -- write one line per base instead of runs.
    cachePages    -- > 0 enables the dataset cache; a value larger than the
                     default cache size is applied with setDBcache.
    enforceChr    -- skip chromosomes whose name does not contain "chr".
    strand        -- "plus" or "minus" restricts to that strand; any other
                     non-None value tracks both strands.
    chunk         -- window size, in millions of bases, for profile requests.

    For each chromosome written, the number of data lines is printed.
    """
    priorityIncrement = 0.01
    wigType = "bedGraph"
    trackTemplate = 'track type=%s name="%s"%s priority=%.3f visibility=full%s\n'

    colorString = ""
    if color is not None:
        colorString = " color=%s" % color
    if altColor:
        colorString += " altcolor=%s" % altColor

    # Always defined. The header interpolates it even without a group, so the
    # default group="" used to raise NameError.
    groupString = ""
    if group:
        groupString = " group=%s" % group

    maxSpan = chunk * 1000000

    isStranded = strand is not None
    strandedDirection = "both"
    if isStranded:
        if strand == "plus":
            strandedDirection = "plusOnly"
        elif strand == "minus":
            strandedDirection = "minusOnly"
        print("will keep track of %s strand(s)" % strandedDirection)

    if shift:
        print("Will shift reads by +/- %d bp according to their sense" % shift)
        name += "shift=%d" % shift

    hitRDS = ReadDataset.ReadDataset(hitfilename, verbose=True, cache=(cachePages > 0))
    if cachePages > hitRDS.getDefaultCacheSize():
        hitRDS.setDBcache(cachePages)

    readlen = hitRDS.getReadSize()
    if doNormalize:
        normalizeBy = len(hitRDS) / 1000000.
    else:
        normalizeBy = 1.

    listfile = None
    outfile = None
    try:
        if listfilename is not None:
            listfile = open(listfilename, "w")

        priority = startPriority
        if not doSplit:
            outfile = open(outfilename, "w")
            if listfile is not None:
                listfile.write("%s%s\n" % (listPrefix, outfilename))
            outfile.write(trackTemplate % (wigType, name, groupString, priority, colorString))

        for achrom in sorted(hitRDS.getChromosomes()):
            if not _keepChromosome(achrom, limitChrom, skipRandom, enforceChr):
                continue

            if doSplit:
                outfile = open("%s.%s" % (outfilename, achrom), "w")
                if listfile is not None:
                    listfile.write("%s%s.%s\n" % (listPrefix, outfilename, achrom))
                trackName = "%s %s" % (name, achrom)
                outfile.write(trackTemplate % (wigType, trackName, groupString, priority, colorString))
                priority += priorityIncrement

            lastNT = hitRDS.getMaxCoordinate(achrom, doMulti=withMulti, doSplices=withSplices) + readlen
            linesWritten = _writeChromosomeProfile(hitRDS, outfile, achrom, lastNT, maxSpan,
                                                   withMulti, withSplices, normalizeBy,
                                                   isStranded, strandedDirection, shift, doSingle)

            if doSplit:
                outfile.close()
                outfile = None

            print(linesWritten)
    finally:
        # Close whatever is still open, including when an error interrupts the run.
        if outfile is not None:
            outfile.close()
        if listfile is not None:
            listfile.close()


if __name__ == "__main__":
    main(sys.argv)