# Optional: the psyco JIT speeds up CPython 2.x when installed; skip silently
# when it is absent. Catch only ImportError — the original bare "except" would
# also have hidden genuine errors raised by psyco.full().
try:
    import psyco
    psyco.full()
except ImportError:
    pass

import sys
import tempfile
import shutil
import os
import optparse
import sqlite3 as sqlite
from commoncode import getConfigParser, getConfigOption, getConfigBoolOption

# Route temporary files (e.g. the local db cache made in chkSNPrmask) to the
# configured scratch directory, falling back to /tmp.
configParser = getConfigParser()
cisTemp = getConfigOption(configParser, "general", "cistematic_temp", default="/tmp")
tempfile.tempdir = cisTemp

print "chkSNPrmask: version 3.4"


def main(argv=None):
    """Command-line entry point.

    Expects three positional arguments (dbfile, snpsfile, outfile) plus the
    optional --cache / --repeats flags; exits with status 1 on bad usage.
    """
    if not argv:
        argv = sys.argv

    # Fill in the script name: the original left the bare "%s" placeholder
    # unsubstituted, so the literal "%s" was printed in the usage message.
    usage = "usage: python %s dbfile snpsfile nr_snps_outfile [--cache numPages] [--repeats]" % argv[0]

    parser = makeParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        print(usage)
        sys.exit(1)

    dbfile = args[0]
    filename = args[1]
    outfile = args[2]

    chkSNPrmask(dbfile, filename, outfile, options.repeats, options.cachePages)


def makeParser(usage=""):
    """Build the option parser, seeding option defaults from the config file.

    Command-line flags (--repeats, --cache) override the values read from the
    "checkSNPrmask" config section.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--repeats", action="store_true", dest="repeats")
    parser.add_option("--cache", type="int", dest="cachePages")

    config = getConfigParser()
    section = "checkSNPrmask"
    parser.set_defaults(
        repeats=getConfigBoolOption(config, section, "repeats", False),
        cachePages=getConfigOption(config, section, "cachePages", None),
    )

    return parser


def chkSNPrmask(dbfile, filename, outfile, repeats=False, cachePages=None):
    print dbfile

    if cachePages is not None:
        if cachePages < 250000:
            cachePages = 250000

        print "caching locally..."
        cachefile = tempfile.mktemp() + ".db"
        shutil.copyfile(dbfile, cachefile)
        db = sqlite.connect(cachefile)
        doCache = True
        print "cached..."
    else:
        cachePages = 500000
        doCache = False
        db = sqlite.connect(dbfile)

    sql = db.cursor()
    sql.execute("PRAGMA CACHE_SIZE = %d" % cachePages)
    sql.execute("PRAGMA temp_store = MEMORY")
    sql.execute("ANALYZE")

    infile = open(filename)
    featureList = []
    featureDict = {}

    for line in infile:
        if doNotProcessLine(line):
            continue

        fields = line.strip().split("\t")
        chrom = fields[2][3:]
        pos = int(fields[3])
        featureList.append((chrom,pos))
        featureDict[(chrom, pos)] = line.strip()

    featureList.sort()

    index = 0
    currentChrom=None
    for (chrom, pos) in featureList:
        index += 1
        if chrom != currentChrom:
            print "\n%s" % chrom
            currentChrom = chrom

        results = []
        try:
            sql.execute("select family from repeats where chrom = '%s' and %d between start and stop" % (chrom, pos)) 
            results = sql.fetchall()
        except:
            pass

        if repeats: # if user wants to keep track of the SNPs in repeats
            featureDict[(chrom,pos)] += "\tN\A" 
            for x in results:
                featureDict[(chrom,pos)] += "\t" + str(x)
        else:
            for x in results:
                try:
                    del featureDict[(chrom,pos)]
                except KeyError:
                    pass

        if index % 100 == 0:
            print ".",
            sys.stdout.flush()

    if doCache:
        print "removing cache"
        del db
        os.remove(cachefile)

    outFile = open(outfile, "w") 
    for key, value in featureDict.iteritems():
        outStr = str(value) + "\n"
        outFile.write(outStr)

    outFile.close()


def doNotProcessLine(line):
    """Return True for comment/header lines (starting with '#') to be skipped.

    Uses startswith so an empty string no longer raises IndexError the way
    the original line[0] lookup did; empty input is simply not a comment.
    """
    return line.startswith("#")


# Script entry point: delegate to main() with the raw command line.
if __name__ == "__main__":
    main(sys.argv)