#
#  getSNPGeneInfo.py
#  ENRAGE
#
# This script looks up the gene information and expression level for each SNP.
# Written by: Wendy Lee
# Written on: August 7th, 2008

try:
    import psyco
    psyco.full()
except:
    print 'psyco not running'

import sys
import optparse
import string
from cistematic.core import genesIntersecting, cacheGeneDB, uncacheGeneDB
from commoncode import getGeneInfoDict, getConfigParser, getConfigBoolOption, getConfigIntOption

# Announce the script version. This is a module-level statement, so it also
# prints when the module is imported rather than run as a script.
print "getSNPGeneInfo: version 4.6"

def main(argv=None):
    """Command-line entry point.

    argv -- full argument vector (program name first); defaults to sys.argv.

    Expects four positional arguments: genome, snpsfile, rpkmfile (or the
    literal NONE to skip expression values) and the output file name.
    Prints the usage and exits with status 2 when fewer are given.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog genome snpsfile rpkmfile dbsnp_geneinfo_outfile [--cache] [--withoutsense] [--flank bp]"

    parser = makeParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 4:
        # Let optparse print the usage so "%prog" is expanded to the script
        # name; printing the raw string showed a literal "%prog".
        parser.print_usage()
        sys.exit(2)

    genome, infilename, rpkmfilename, outfilename = args[:4]

    writeSNPGeneInfo(genome, infilename, rpkmfilename, outfilename,
                     options.doCache, options.withSense, options.flankBP)


def makeParser(usage=""):
    """Build the command-line option parser for getSNPGeneInfo.

    usage -- usage template handed to optparse (may contain "%prog").

    Defaults for doCache, withSense and flankBP are read from the
    "getSNPGeneInfo" section of the configuration returned by
    getConfigParser(), falling back to False, True and 0 respectively.

    Returns the configured optparse.OptionParser.
    """
    parser = optparse.OptionParser(usage=usage)
    # dest must be "doCache": that is the attribute main() reads and the
    # name whose default is set below. It used to be "cachePages", which
    # made --cache a silent no-op.
    parser.add_option("--cache", action="store_true", dest="doCache",
                      help="cache the genome's gene database during the lookup")
    parser.add_option("--withoutsense", action="store_false", dest="withSense",
                      help="omit the gene sense column from the output")
    parser.add_option("--flank", type="int", dest="flankBP",
                      help="flank size in bp passed to the gene intersection")

    configParser = getConfigParser()
    section = "getSNPGeneInfo"
    doCache = getConfigBoolOption(configParser, section, "doCache", False)
    withSense = getConfigBoolOption(configParser, section, "withSense", True)
    flankBP = getConfigIntOption(configParser, section, "flankBP", 0)

    parser.set_defaults(doCache=doCache, withSense=withSense, flankBP=flankBP)

    return parser


def writeSNPGeneInfo(genome, infilename, rpkmfilename, outfilename, doCache=False, withSense=True, flankBP=0):
    """Compute the SNP gene info via getSNPGeneInfo() and write it to a file.

    outfilename -- file that receives one output line per entry, each
    terminated by a newline. All other arguments are passed unchanged to
    getSNPGeneInfo().
    """
    outList = getSNPGeneInfo(genome, infilename, rpkmfilename, doCache, withSense, flankBP)

    outfile = open(outfilename, "w")
    try:
        for outputLine in outList:
            outfile.write("%s\n" % outputLine)
    finally:
        # Close the handle even if a write fails part-way.
        outfile.close()


def getSNPGeneInfo(genome, infilename, rpkmfilename, doCache=False, withSense=True, flankBP=0):
    """Find the gene matched to each SNP and return the formatted output lines.

    genome -- genome name passed to the cistematic gene lookups.
    infilename -- tab-separated SNP file; column 2 holds the chromosome
        (its first three characters are dropped) and column 3 the integer
        position. Lines for which doNotProcessLine() is true are skipped.
    rpkmfilename -- whitespace-separated file mapping gene ID (column 0) to
        RPKM (column 3), or the literal "NONE" to skip expression values.
    doCache -- cache the gene database while the lookups run.
    withSense -- include the gene sense column in the output.
    flankBP -- when > 0, passed to genesIntersecting() as its flank.

    Returns the list of tab-joined lines built by getSNPGeneOutputList().
    """
    rpkmDict = _readRPKMDict(rpkmfilename)
    (snpPositionList, snpDict) = _readSNPPositions(infilename)

    if doCache:
        cacheGeneDB(genome)
        print("cached %s" % genome)

    try:
        geneinfoDict = getGeneInfoDict(genome, cache=doCache)

        if flankBP > 0:
            matchingGenesDict = genesIntersecting(genome, snpPositionList, flank=flankBP)
        else:
            matchingGenesDict = genesIntersecting(genome, snpPositionList)

        geneDict = _groupPositionsByGene(matchingGenesDict, geneinfoDict)
    finally:
        # Release the cached gene database even if a lookup raises.
        if doCache:
            uncacheGeneDB(genome)

    return getSNPGeneOutputList(geneDict, snpDict, rpkmDict, withSense)


def _readRPKMDict(rpkmfilename, rpkmField=3):
    """Map gene ID (column 0) to the RPKM string in column rpkmField; "NONE" gives {}."""
    rpkmDict = {}
    if rpkmfilename == "NONE":
        return rpkmDict

    rpkmfile = open(rpkmfilename, "r")
    try:
        for line in rpkmfile:
            lineFields = line.split()
            # Skip blank or truncated lines instead of raising IndexError.
            if len(lineFields) <= rpkmField:
                continue

            rpkmDict[lineFields[0]] = lineFields[rpkmField]
    finally:
        rpkmfile.close()

    return rpkmDict


def _readSNPPositions(infilename):
    """Return ([(chrom, start), ...], {(chrom, start): raw SNP line}) for the SNP file."""
    snpPositionList = []
    snpDict = {}

    infile = open(infilename)
    try:
        for line in infile:
            if doNotProcessLine(line):
                continue

            fields = line.split("\t")
            # The first three characters of the chromosome column are
            # dropped (presumably a "chr" prefix -- confirm against input).
            chrom = fields[2][3:]
            start = int(fields[3])
            chromosomePosition = (chrom, start)
            snpPositionList.append(chromosomePosition)
            snpDict[chromosomePosition] = line
    finally:
        infile.close()

    return (snpPositionList, snpDict)


def _groupPositionsByGene(matchingGenesDict, geneinfoDict):
    """Group matched positions by (symbol, geneID), keeping the sense of the first match seen."""
    geneDict = {}
    for pos in matchingGenesDict:
        # Only the first gene reported for a position is used.
        geneMatch = matchingGenesDict[pos][0]
        geneID = geneMatch[0]
        try:
            symbol = geneinfoDict[geneID][0][0]
        except (KeyError, IndexError, TypeError):
            # No usable symbol for this gene: fall back to a LOC name.
            symbol = "LOC%s" % geneID

        geneDescriptor = (symbol, geneID)
        if geneDescriptor in geneDict:
            geneDict[geneDescriptor]["position"].append(pos)
        else:
            geneDict[geneDescriptor] = {"position": [pos],
                                        "sense": geneMatch[-1]}

    return geneDict


def doNotProcessLine(line):
    """Return True if a SNP-file line should be skipped.

    Comment lines (starting with "#") are skipped, as are empty and
    whitespace-only lines. Previously an empty string raised IndexError,
    and a blank line was accepted and then crashed when its tab-separated
    columns were indexed.
    """
    return not line.strip() or line.startswith("#")


def getSNPGeneOutputList(geneDict, snpDict, rpkmDict, withSense):
    """Format the records from getSNPGeneInfoList() as tab-separated lines.

    Each line is: SNP line, gene symbol, gene ID, RPKM and, when withSense
    is true, the gene sense. All fields are expected to be strings.

    Returns the lines sorted in reverse lexicographic order.
    """
    snpGeneOutputList = []

    for snpEntry in getSNPGeneInfoList(geneDict, snpDict, rpkmDict, withSense):
        outputItems = [snpEntry["snpDescription"], snpEntry["symbol"],
                       snpEntry["geneID"], snpEntry["rpkm"]]
        if withSense:
            outputItems.append(snpEntry["sense"])

        # str.join replaces the deprecated string.join (removed in Python 3).
        snpGeneOutputList.append("\t".join(outputItems))

    snpGeneOutputList.sort(reverse=True)

    return snpGeneOutputList


def getSNPGeneInfoList(geneDict, snpDict, rpkmDict, withSense):
    """Expand geneDict into one record per distinct (gene, SNP line) pair.

    geneDict -- maps (symbol, geneID) to {"position": [...], "sense": ...}.
        It is not modified.
    snpDict -- maps each position to the raw SNP line it came from.
    rpkmDict -- maps geneID to its RPKM; genes missing from it get the
        literal text N\\A.
    withSense -- when true, each record also carries the gene's "sense".

    Returns a list of dicts with keys "symbol", "geneID", "rpkm",
    "snpDescription" (the SNP line without its trailing newline) and,
    optionally, "sense". The list is sorted in reverse order of
    (snpDescription, symbol, geneID, rpkm, sense).
    """
    snpGeneInfoList = []

    for (symbol, geneID), geneEntry in geneDict.items():
        # The RPKM depends only on the gene, so look it up once per gene.
        # "N\\A" is the runtime text N\A (presumably meant as N/A); it is
        # kept verbatim so existing output files do not change.
        if geneID in rpkmDict:
            rpkm = str(rpkmDict[geneID])
        else:
            rpkm = "N\\A"

        doneLines = set()
        # sorted() rather than list.sort() so the caller's list is untouched.
        for position in sorted(geneEntry["position"]):
            snpLine = snpDict[position]
            if snpLine in doneLines:
                continue
            doneLines.add(snpLine)

            snpGeneInfoDict = {"symbol": symbol,
                               "geneID": geneID,
                               "rpkm": rpkm,
                               # rstrip rather than [:-1]: a final line
                               # without a newline must not lose a character.
                               "snpDescription": snpLine.rstrip("\n")}
            if withSense:
                snpGeneInfoDict["sense"] = geneEntry["sense"]

            snpGeneInfoList.append(snpGeneInfoDict)

    # Sorting dicts directly relied on Python 2's arbitrary dict ordering
    # (and raises TypeError on Python 3); use an explicit key instead.
    snpGeneInfoList.sort(key=lambda entry: (entry["snpDescription"],
                                            entry["symbol"],
                                            entry["geneID"],
                                            entry["rpkm"],
                                            entry.get("sense", "")),
                         reverse=True)

    return snpGeneInfoList


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main(sys.argv)