#
#  getSNPs.py
#  ENRAGE
#
# Originally written by: Wendy Lee
# Last modified: May 11th, 2009 by Ali Mortazavi

"""
    Get the matches and mismatches from the RDS file, and calculate the SNP thresholds uniqStartMin (Sl * readlength) and totalRatioMin (Cl).
    For each mismatch position, choose the base change that occurs most frequently (i.e. the one supported by the highest
    total number of reads).
    The thresholds for Sl and Cl come from user input.
    Sl = # of independent reads supporting a base change at position S, divided by the read length
    Cl = total # of all reads supporting a base change at position S / total # of all reads that pass through position S

    usage: python getSNPs.py samplerdsfile uniqStartMin totalRatioMin outfile [--nosplices] [--enforceChr] [--cache pages] where

    uniqStartMin = # of independent reads supporting a base change at position S
    totalRatioMin = total # of reads supporting a base change at position S / total # reads that pass through position S
"""

import sys
import optparse
from commoncode import writeLog, getConfigParser, getConfigBoolOption, getConfigIntOption
import ReadDataset

print("getSNPs: version 3.6")

# Optional speed-up: use psyco if it can be imported and enabled, otherwise just
# report that it is not running. Catch Exception rather than using a bare except so
# Ctrl-C / sys.exit() during import are not silently swallowed.
try:
    import psyco
    psyco.full()
except Exception:
    print("psyco is not running")

def usage():
    """Print the module docstring, which doubles as the command-line help text."""
    print(__doc__)


def main(argv=None):
    """Command-line entry point: parse arguments and write candidate SNPs to a file.

    argv follows sys.argv conventions (argv[0] is the program name) and defaults to
    sys.argv. Expects four positional arguments: samplerdsfile, uniqStartMin,
    totalRatioMin and outfile. If fewer are given, prints the usage text and exits
    with status 2.
    """
    if not argv:
        argv = sys.argv

    # Keep the help text in its own name: binding it to "usage" would shadow the
    # module-level usage() function and make the call below raise TypeError.
    usageText = __doc__

    parser = makeParser(usageText)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 4:
        usage()
        sys.exit(2)

    hitfile = args[0]
    uniqStartMin = float(args[1])
    totalRatioMin = float(args[2])
    outfilename = args[3]

    # Any positive page count turns on ReadDataset caching.
    doCache = options.cachePages > 0

    writeSNPsToFile(hitfile, uniqStartMin, totalRatioMin, outfilename, doCache, options.cachePages, options.doSplices, options.forceChr)


def makeParser(usage=""):
    """Build the optparse parser for getSNPs.

    Option defaults come from the "getSNPs" section of the configuration returned
    by getConfigParser(). The built-in fallbacks are doSplices=True,
    forceChr=False and cachePages=0.

    --nosplices  sets doSplices to False
    --enforceChr sets forceChr to True
    --cache N    sets cachePages to N
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--nosplices", action="store_false", dest="doSplices")
    parser.add_option("--enforceChr", action="store_true", dest="forceChr")
    parser.add_option("--cache", type="int", dest="cachePages")

    configParser = getConfigParser()
    section = "getSNPs"
    doSplices = getConfigBoolOption(configParser, section, "doSplices", True)
    forceChr = getConfigBoolOption(configParser, section, "forceChr", False)
    cachePages = getConfigIntOption(configParser, section, "cachePages", 0)

    # Use the configured values as defaults. They were previously read and then
    # discarded in favour of hard-coded literals.
    parser.set_defaults(doSplices=doSplices, forceChr=forceChr, cachePages=cachePages)

    return parser


def writeSNPsToFile(hitfile, uniqStartMin, totalRatioMin, outfilename, doCache, cachePages=0, doSplices=True, forceChr=False):
    """Compute candidate SNPs for hitfile and write them as tab-separated lines.

    Writes a "#"-prefixed header line, then one line per SNP tuple returned by
    getSNPs(). Each line is also echoed to stdout. The run parameters and the
    number of candidates are appended to "snp.log" via writeLog().
    """
    writeLog("snp.log", sys.argv[0], "rdsfile: %s uniqStartMin: %1.2f totalRatioMin: %1.2f" % (hitfile, uniqStartMin, totalRatioMin))

    header = "#Sl\tCl\tchrom\tpos\tmatch\tuniqMis\t\ttotalMis\tchange"
    outfile = open(outfilename, "w")
    try:
        outfile.write(header + "\n")

        snpPropertiesList = getSNPs(hitfile, uniqStartMin, totalRatioMin, doCache, cachePages, doSplices, forceChr)
        for snpEntry in snpPropertiesList:
            # No trailing newline here: write() adds exactly one and print adds its
            # own. Previously each record was followed by a blank line.
            outline = "%1.2f\t%1.2f\t%s\t%d\t%d\t%d\t\t%d\t%s" % snpEntry
            print(outline)
            outfile.write(outline + "\n")
    finally:
        # Close even if getSNPs() or a write fails.
        outfile.close()

    writeLog("snp.log", sys.argv[0], "%d candidate SNPs\n" % len(snpPropertiesList))


def getSNPs(hitfile, uniqStartMin, totalRatioMin, doCache, cachePages=0, doSplices=True, forceChr=False):
    """Return a list of candidate SNP tuples found in the RDS file hitfile.

    For every mismatch position on each processed chromosome, the base change with
    the highest total read count is selected. It is reported when:
      * its independent-read count (highestCount) is >= uniqStartMin, and
      * Cl = highestTotalCount / (highestTotalCount + matchCount) is >= totalRatioMin.
    matchCount is the number of reads covering the position minus all mismatch
    reads there, clamped at 0.

    Each tuple is (Sl, Cl, chrom, position, matchCount, highestCount,
    highestTotalCount, highestBaseChange), where Sl = highestCount / readLength.
    Positions are reported in ascending order within each chromosome.
    """
    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
    if cachePages > 20000:
        hitRDS.setDBcache(cachePages)

    snpPropertiesList = []
    readLength = hitRDS.getReadSize()
    chromList = hitRDS.getChromosomes()

    for chrom in chromList:
        if doNotProcessChromosome(forceChr, chrom):
            continue

        matchDict = getMatchDict(hitRDS, chrom, doSplices)
        print("got match dict for %s " % chrom)
        mismatchDict = getMismatchDict(hitRDS, chrom, doSplices)
        print("got mismatch dict for %s " % chrom)

        for position in sorted(mismatchDict):
            entry = mismatchDict[position]
            (highestBaseChange, highestCount, highestTotalCount) = _dominantBaseChange(entry["uniqBaseDict"], entry["totalBaseDict"])
            if highestCount < uniqStartMin:
                continue

            # Reads carrying any mismatch here are not counted as matches.
            matchCount = _countCoveringReads(matchDict, position, readLength) - entry["totalCount"]
            if matchCount < 0:
                matchCount = 0

            coverage = highestTotalCount + matchCount
            if coverage == 0:
                # Nothing supports this position; avoid dividing by zero.
                continue

            Sl = highestCount / float(readLength)
            Cl = highestTotalCount / float(coverage)
            if Cl >= totalRatioMin:
                snpPropertiesList.append((Sl, Cl, chrom, position, matchCount, highestCount, highestTotalCount, highestBaseChange))

    return snpPropertiesList


def _dominantBaseChange(uniqBaseDict, totalBaseDict):
    """Return (baseChange, uniqueCount, totalCount) for the change with the most total reads.

    Ties keep the first change encountered in uniqBaseDict iteration order.
    Returns ("N-N", 0, 0) when uniqBaseDict is empty.
    """
    highestBaseChange = "N-N"
    highestCount = 0
    highestTotalCount = 0
    for baseChange in uniqBaseDict:
        baseTotal = totalBaseDict.get(baseChange, 0)
        if baseTotal > highestTotalCount:
            highestBaseChange = baseChange
            highestCount = uniqBaseDict[baseChange]
            highestTotalCount = baseTotal

    return (highestBaseChange, highestCount, highestTotalCount)


def _countCoveringReads(matchDict, position, readLength):
    """Count reads starting within readLength - 1 bases before position whose stop is >= position."""
    count = 0
    for matchpos in range(position - readLength + 1, position + 1):
        count += len([mstop for mstop in matchDict.get(matchpos, ()) if position <= mstop])

    return count


def doNotProcessChromosome(forceChr, chromosome):
    """Return True when chromosome should be skipped.

    Only applies when forceChr is set: chromosomes whose name does not start with
    "chr" are skipped. Always returns a bool. The original fell through and
    returned None for "chr*" names when forceChr was set.
    """
    return bool(forceChr) and not chromosome.startswith("chr")


def getMatchDict(rds, chrom, withSplices=True):
    """Return a dict mapping read start -> list of read stops for reads on chrom.

    Unspliced reads come from rds.getReadsDict(). When withSplices is true, spliced
    reads from rds.getSplicesDict(..., splitRead=True) are added too. Each such
    entry contributes its startL/stopL pair, or startR/stopR when the L keys are
    absent.

    A failing lookup, or a result without an entry for chrom, is treated as
    "no reads" so the caller can continue with other chromosomes.
    """
    finalDict = {}

    try:
        readDict = rds.getReadsDict(fullChrom=True, bothEnds=True, noSense=True, chrom=chrom)
    except Exception:
        readDict = {}

    for read in readDict.get(chrom, []):
        finalDict.setdefault(read["start"], []).append(read["stop"])

    if withSplices:
        try:
            spliceDict = rds.getSplicesDict(noSense=True, fullChrom=True, chrom=chrom, splitRead=True)
        except Exception:
            spliceDict = {}

        for read in spliceDict.get(chrom, []):
            try:
                start = read["startL"]
                stop = read["stopL"]
            except KeyError:
                start = read["startR"]
                stop = read["stopR"]

            finalDict.setdefault(start, []).append(stop)

    return finalDict


def getMismatchDict(rds, chrom, withSplices=True):
    """Aggregate the mismatches reported by rds.getMismatches() for chrom by position.

    Each row is (start, change_at, change_base, change_from). A read counts as
    "independent" at a position when its (start, change) pair has not been seen
    there before.

    Returns {change_at: {"uniqueReadCount": int, "totalCount": int,
    "back": "start:change,...", "uniqBaseDict": {change: independent reads},
    "totalBaseDict": {change: all reads}}}, where change is "change_base-change_from".
    """
    mismatchDict = {}
    seenReads = {}   # change_at -> set of "start:change" keys already counted
    backLists = {}   # change_at -> ordered list of independent "start:change" keys

    mismatchList = rds.getMismatches(mischrom=chrom, useSplices=withSplices).get(chrom, [])
    for (start, change_at, change_base, change_from) in mismatchList:
        change = "%s-%s" % (change_base, change_from)
        readKey = "%s:%s" % (str(start), change)

        entry = mismatchDict.get(change_at)
        if entry is None:
            mismatchDict[change_at] = {"uniqueReadCount": 1,
                                       "totalCount": 1,
                                       "back": readKey,
                                       "uniqBaseDict": {change: 1},
                                       "totalBaseDict": {change: 1}
            }
            seenReads[change_at] = set([readKey])
            backLists[change_at] = [readKey]
            continue

        entry["totalCount"] += 1
        totalBaseDict = entry["totalBaseDict"]
        # A base change first seen on a later read must still be recorded.
        totalBaseDict[change] = totalBaseDict.get(change, 0) + 1

        # Exact set membership: a substring test on the joined string would make
        # "12:A-G" match "112:A-G" and undercount independent reads.
        if readKey not in seenReads[change_at]:
            seenReads[change_at].add(readKey)
            backLists[change_at].append(readKey)
            entry["uniqueReadCount"] += 1
            uniqBaseDict = entry["uniqBaseDict"]
            uniqBaseDict[change] = uniqBaseDict.get(change, 0) + 1

    # Build each "back" string once instead of re-concatenating per read.
    for change_at in backLists:
        mismatchDict[change_at]["back"] = ",".join(backLists[change_at])

    return mismatchDict


if __name__ == "__main__":
    # Run as a script; main() reads sys.argv itself when no argv is passed.
    main()