# psyco is optional: enable it when it can be imported and started, and carry
# on without it otherwise.
try:
    import psyco
    psyco.full()
except Exception:
    # Best effort only. Catching Exception (not a bare except) keeps
    # KeyboardInterrupt and SystemExit propagating as usual.
    pass

import sys
import optparse
import tempfile
import shutil
import os
import string
import sqlite3 as sqlite
from commoncode import getConfigParser, getConfigOption

# Version banner. This is a module-level statement, so it is written to stdout
# whenever the module is imported as well as when it is run as a script.
print "chksnp: version 3.7"


def main(argv=None):
    """Command-line entry point: parse the arguments and annotate a SNP file.

    argv is the full argument vector, program name first; when it is None or
    empty, sys.argv is used instead.

    Three positional arguments are required (dbfile, snpsfile, dbsnp_outfile);
    any further positional arguments are ignored. Recognised options:

        --cache numPages   integer, forwarded to chkSNPFile as cachePages
        --snpDB dbfile     may be repeated; the values are forwarded in the
                           order given as the additional database list

    If fewer than three positional arguments are supplied, the usage message
    is printed and the process exits with status 1.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog dbfile snpsfile dbsnp_outfile [--cache numPages] [--snpDB dbfile]"

    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--cache", type="int", dest="cachePages")
    parser.add_option("--snpDB", action="append", dest="snpDBList",
                      help="additional snp db files to check will be searched in order given")
    parser.set_defaults(cachePages=None, snpDBList=[])
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        # Let optparse render the usage text so that %prog is expanded to the
        # program name instead of being printed literally.
        parser.print_usage()
        sys.exit(1)

    (dbfile, infile, outfile) = args[:3]

    chkSNPFile(dbfile, infile, outfile, options.cachePages, options.snpDBList)


def chkSNPFile(dbfile, inputFileName, outputFileName, cachePages=None, snpDBList=None):
    """Annotate the SNPs listed in a file and write the result to another file.

    dbfile is the first sqlite database to search; the entries of snpDBList
    (none when omitted or None) are searched after it, in order.
    inputFileName is read line by line through getSNPLocationInfo(), and the
    annotated entries returned by annotateSNPFromDBList() are written to
    outputFileName, one per line. cachePages is passed through unchanged.

    Entries are written in the iteration order of the annotated dictionary,
    not in input-file order.
    """
    if snpDBList is None:
        snpDBList = []

    # Close the input as soon as it has been parsed, even if parsing fails.
    snpInputFile = open(inputFileName)
    try:
        snpLocationList, snpDict = getSNPLocationInfo(snpInputFile)
    finally:
        snpInputFile.close()

    dbList = [dbfile] + list(snpDBList)
    annotatedSnpDict = annotateSNPFromDBList(snpLocationList, snpDict, dbList, cachePages)

    outputFile = open(outputFileName, "w")
    try:
        for annotatedEntry in annotatedSnpDict.itervalues():
            outputFile.write("%s\n" % str(annotatedEntry))
    finally:
        # Make sure buffered output is flushed and the handle released even
        # when a write fails part-way through.
        outputFile.close()


def chkSNP(dbList, snpPropertiesList, cachePages=None):
    """Annotate in-memory SNP records against a list of databases.

    snpPropertiesList is an iterable of SNP record lines; it is indexed with
    getSNPLocationInfo() and the result is handed to annotateSNPFromDBList()
    together with dbList and cachePages. Returns the annotated dictionary
    produced by annotateSNPFromDBList().
    """
    locationInfo = getSNPLocationInfo(snpPropertiesList)
    (sortedLocations, entriesByLocation) = locationInfo

    return annotateSNPFromDBList(sortedLocations, entriesByLocation, dbList, cachePages)


def getSNPLocationInfo(snpPropertiesList):
    """Index SNP record lines by (chromosome, position).

    Every line of snpPropertiesList that doNotProcessLine() does not reject is
    stripped and split on tabs. The chromosome is the third field with its
    first three characters removed (presumably a "chr" prefix - confirm
    against the input format) and the position is the fourth field parsed as
    an integer.

    Returns a tuple (snpLocationList, snpDict):
        snpLocationList  sorted list of (chromosome, position) tuples, one
                         per processed line
        snpDict          maps each (chromosome, position) tuple to its
                         stripped line; a later line at the same location
                         replaces an earlier one
    """
    snpDict = {}
    snpLocationList = []

    for record in snpPropertiesList:
        if not doNotProcessLine(record):
            strippedRecord = record.strip()
            columns = strippedRecord.split("\t")
            locationKey = (columns[2][3:], int(columns[3]))
            snpLocationList.append(locationKey)
            snpDict[locationKey] = strippedRecord

    snpLocationList.sort()

    return snpLocationList, snpDict


def doNotProcessLine(line):
    """Return True when a line carries no SNP record and must be skipped.

    A line is skipped when its first character is "#" (a comment/header line)
    or when it is empty or contains only whitespace. Blank lines have no
    fields to parse, so letting them through would fail in the caller.
    """
    return line.startswith("#") or not line.strip()


def annotateSNPFromDB(snpLocationList, snpDict, dbFileName, cachePages=None):
    """Annotate SNP entries against a single database file.

    Convenience wrapper: wraps dbFileName in a one-element list and delegates
    to annotateSNPFromDBList(), returning whatever it returns.
    """
    singleDBList = [dbFileName]

    return annotateSNPFromDBList(snpLocationList, snpDict, singleDBList, cachePages)


def annotateSNPFromDBList(snpLocationList, snpDict, dbList, cachePages=None):

    configParser = getConfigParser()
    cisTemp = getConfigOption(configParser, "general", "cistematic_temp", default="/tmp")
    tempfile.tempdir = cisTemp

    for dbFileName in dbList:
        if cachePages is not None:
            print "caching locally..."
            cachefile = "%s.db" % tempfile.mktemp()
            shutil.copyfile(dbFileName, cachefile)
            db = sqlite.connect(cachefile)
            doCache = True
            print "cached..."
        else:
            db = sqlite.connect(dbFileName)
            doCache = False

        cacheSize = max(cachePages, 500000)
        sql = db.cursor()
        sql.execute("PRAGMA CACHE_SIZE = %d" % cacheSize)
        sql.execute("PRAGMA temp_store = MEMORY")

        index = 0
        foundEntries = []
        for chromosomePosition in snpLocationList:
            (chromosome, position) = chromosomePosition
            found = False
            results = []
            index += 1
            startPosition = position - 1
            sql.execute("select func, name from snp where chrom = '%s' and start = %d and stop = %d" % (chromosome, startPosition, position)) 
            results = sql.fetchall()
            try:
                (func, name) = results[0]
                found = True
            except IndexError:
                sql.execute("select func, name from snp where chrom = '%s' and start <= %d and stop >= %d" % (chromosome, startPosition, position))
                results = sql.fetchall()
                try:
                    (func, name) = results[0]
                    found = True
                except IndexError:
                    pass

            if found:
                snpEntry = snpDict[chromosomePosition]
                snpDict[chromosomePosition] = string.join([snpEntry, str(name), str(func)], "\t")
                foundEntries.append(chromosomePosition)

            if index % 100 == 0:
                print ".",
                sys.stdout.flush()

        for chromosomePosition in foundEntries:
            del snpLocationList[snpLocationList.index(chromosomePosition)]

        if doCache:
            print "\nremoving cache"
            del db
            os.remove(cachefile)

    for chromosomePosition in snpLocationList:
        snpEntry = snpDict[chromosomePosition]
        snpDict[chromosomePosition] = string.join([snpEntry, "N\A", "N\A"], "\t")

    return snpDict


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main(sys.argv)