import sqlite3 as sqlite
import string
import tempfile
import shutil
import os
from array import array
from commoncode import getReverseComplement, getConfigParser, getConfigOption

# RDS schema version: the initial value of ReadDataset.rdsVersion and the value
# stored as "rdsVersion" metadata when an opened dataset does not have one.
currentRDSVersion = "2.0"


class ReadDatasetError(Exception):
    """ Error raised by ReadDataset, e.g. for an invalid datasetType at
    initialization or a missing readsize metadata entry.
    """


class ReadDataset():
    """ Class for storing reads from experiments. Assumes that custom scripts
    will translate incoming data into a format that can be inserted into the
    class using the insert* methods. Default class subtype ('DNA') includes
    tables for unique and multireads, whereas 'RNA' subtype also includes a
    splices table.
    """

    def __init__(self, datafile, initialize=False, datasetType="DNA", verbose=False,
                 cache=False, reportCount=True):
        """ opens the rds datafile; creates the table schema first if
        initialize is set to true, otherwise works with the existing tables.

        datafile:    path of the sqlite file to open.
        initialize:  create the tables; datasetType must be 'DNA' or 'RNA',
                     otherwise ReadDatasetError is raised.
        datasetType: only used when initialize is true; for existing files
                     the type is read from the 'dataType' metadata entry.
        verbose:     print the metadata, cache size and index status.
        cache:       work on a copy of datafile made with cacheDB().
        reportCount: with verbose, also print the read counts.
        """
        self.dbcon = ""
        self.memcon = ""
        self.dataType = ""
        self.rdsVersion = currentRDSVersion
        self.memBacked = False
        self.memChrom = ""
        self.memCursor = ""
        self.cachedDBFile = ""

        if cache:
            if verbose:
                print("caching ....")

            self.cacheDB(datafile)
            dbFile = self.cachedDBFile
        else:
            dbFile = datafile

        self.dbcon = sqlite.connect(dbFile)
        self.dbcon.row_factory = sqlite.Row
        self.dbcon.execute("PRAGMA temp_store = MEMORY")
        if initialize:
            if datasetType not in ["DNA", "RNA"]:
                raise ReadDatasetError("failed to initialize: datasetType must be 'DNA' or 'RNA'")

            self.dataType = datasetType
            self.initializeTables(self.dbcon)
        else:
            metadata = self.getMetadata("dataType")
            self.dataType = metadata["dataType"]

        try:
            metadata = self.getMetadata("rdsVersion")
            self.rdsVersion = metadata["rdsVersion"]
        except (KeyError, sqlite.Error):
            # no rdsVersion recorded yet: try to stamp the file with the current one
            try:
                self.insertMetadata([("rdsVersion", float(currentRDSVersion))])
            except (IOError, sqlite.Error):
                # sqlite reports a read-only database through its own exception
                # hierarchy rather than IOError, so both have to be handled here
                print("could not add rdsVersion - read-only ?")
                self.rdsVersion = "pre-1.0"

        if verbose:
            if initialize:
                print("INITIALIZED dataset %s" % datafile)
            else:
                print("dataset %s" % datafile)

            metadata = self.getMetadata()
            print("metadata:")
            for pname in sorted(metadata.keys()):
                # %s formatting so that a non-string value cannot break the report
                print("\t%s\t%s" % (pname, metadata[pname]))

            if reportCount:
                ucount = self.getUniqsCount()
                mcount = self.getMultiCount()
                if self.dataType == "DNA" and not initialize:
                    try:
                        print("\n%d unique reads and %d multireads" % (int(ucount), int(mcount)))
                    except ValueError:
                        print("\n%s unique reads and %s multireads" % (ucount, mcount))
                elif self.dataType == "RNA" and not initialize:
                    scount = self.getSplicesCount()
                    try:
                        print("\n%d unique reads, %d spliced reads and %d multireads" % (int(ucount), int(scount), int(mcount)))
                    except ValueError:
                        print("\n%s unique reads, %s spliced reads and %s multireads" % (ucount, scount, mcount))

            print("default cache size is %d pages" % self.getDefaultCacheSize())
            if self.hasIndex():
                print("found index")
            else:
                print("not indexed")


    def __len__(self):
        """ return the number of usable reads in the dataset.
        """
        total = self.getUniqsCount()
        total += self.getMultiCount()

        if self.dataType == "RNA":
            total += self.getSplicesCount()

        total = int(total)

        return total


    def __del__(self):
        """ cleanup copy in local cache, if present.
        """
        if self.cachedDBFile != "":
            self.uncacheDB()


    def cacheDB(self, filename):
        """ copy the database file filename to a local cache and record the
        location of the copy in self.cachedDBFile.

        The cache directory is taken from the 'cistematic_temp' option of
        the 'general' config section, with '/tmp' as default.
        """
        configParser = getConfigParser()
        cisTemp = getConfigOption(configParser, "general", "cistematic_temp", default="/tmp")
        # kept from the original implementation: the module-wide default temp dir is redirected too
        tempfile.tempdir = cisTemp
        # mkstemp creates the file atomically; mktemp only returned a name that
        # another process could claim before the copy was made
        (handle, self.cachedDBFile) = tempfile.mkstemp(suffix=".db", dir=cisTemp)
        os.close(handle)
        shutil.copyfile(filename, self.cachedDBFile)


    def saveCacheDB(self, filename):
        """ copy geneinfoDB to a local cache.
        """
        shutil.copyfile(self.cachedDBFile, filename)


    def uncacheDB(self):
        """ delete geneinfoDB from local cache.
        """
        global cachedDBFile
        if self.cachedDBFile != "":
            try:
                os.remove(self.cachedDBFile)
            except:
                print "could not delete %s" % self.cachedDBFile

            self.cachedDB = ""


    def attachDB(self, filename, asname):
        """ attach another database file to the readDataset.
        """
        stmt = "attach '%s' as %s" % (filename, asname)
        self.execute(stmt)


    def detachDB(self, asname):
        """ detach a database file to the readDataset.
        """
        stmt = "detach %s" % (asname)
        self.execute(stmt)


    def importFromDB(self, asname, table, ascolumns="*", destcolumns="", flagged=""):
        """ import into current RDS the table (with columns destcolumns,
            with default all columns) from the database file asname,
            using the column specification of ascolumns (default all).
        """
        stmt = "insert into %s %s select %s from %s.%s" % (table, destcolumns, ascolumns, asname, table)
        if flagged != "":
            stmt += " where flag = '%s' " % flagged

        self.executeCommit(stmt)


    def getTables(self, asname=""):
        """ get a list of table names in a particular database file.
        """
        resultList = []
        sql = self.getSqlCursor()

        if asname != "":
            asname += "."

        stmt = "select name from %ssqlite_master where type='table'" % asname
        sql.execute(stmt)
        results = sql.fetchall()

        for row in results:
            resultList.append(row["name"])

        return resultList


    def getSqlCursor(self):
        if self.memBacked:
            sql = self.getMemCursor()
        else:
            sql = self.getFileCursor()

        return sql


    def hasIndex(self):
        """ check whether the RDS file has at least one index.
        """
        stmt = "select count(*) from sqlite_master where type='index'"
        count = int(self.execute(stmt, returnResults=True)[0][0])
        if count > 0:
            return True

        return False


    def initializeTables(self, dbConnection, cache=100000):
        """ creates table schema in a database connection, which is
        typically a database file or an in-memory database.
        """
        dbConnection.execute("PRAGMA DEFAULT_CACHE_SIZE = %d" % cache)
        dbConnection.execute("create table metadata (name varchar, value varchar)")
        dbConnection.execute("insert into metadata values('dataType','%s')" % self.dataType)
        positionSchema = "start int, stop int"
        tableSchema = "(ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, %s, sense varchar, weight real, flag varchar, mismatch varchar)" % positionSchema
        dbConnection.execute("create table uniqs %s" % tableSchema)
        dbConnection.execute("create table multi %s" % tableSchema)
        if self.dataType == "RNA":
            positionSchema = "startL int, stopL int, startR int, stopR int"
            tableSchema = "(ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, %s, sense varchar, weight real, flag varchar, mismatch varchar)" % positionSchema
            dbConnection.execute("create table splices %s" % tableSchema)

        dbConnection.commit()


    def getFileCursor(self):
        """ returns a cursor to file database for low-level (SQL)
        access to the data.
        """
        return self.dbcon.cursor()


    def getMemCursor(self):
        """ returns a cursor to memory database for low-level (SQL)
        access to the data.
        """
        return self.memcon.cursor()


    def getMetadata(self, valueName=""):
        """ returns a dictionary of metadata.
        """
        whereClause = ""
        resultsDict = {}

        if valueName != "":
            whereClause = " where name='%s'" % valueName

        sql = self.getSqlCursor()

        sql.execute("select name, value from metadata %s" % whereClause)
        results = sql.fetchall()

        for row in results:
            parameterName = row["name"]
            parameterValue = row["value"]
            if parameterName not in resultsDict:
                resultsDict[parameterName] = parameterValue
            else:
                trying = True
                index = 2
                while trying:
                    newName = string.join([parameterName, str(index)], ":")
                    if newName not in resultsDict:
                        resultsDict[newName] = parameterValue
                        trying = False

                    index += 1

        return resultsDict


    def getReadSize(self):
        """ returns readsize if defined in metadata.
        """
        metadata = self.getMetadata()
        if "readsize" not in metadata:
            raise ReadDatasetError("no readsize parameter defined")
        else:
            mysize = metadata["readsize"]
            if "import" in mysize:
                mysize = mysize.split()[0]

            return int(mysize)


    def getDefaultCacheSize(self):
        """ returns the default cache size.
        """
        return int(self.execute("PRAGMA DEFAULT_CACHE_SIZE", returnResults=True)[0][0])


    def getChromosomes(self, table="uniqs", fullChrom=True):
        """ returns a list of distinct chromosomes in table.
        """
        statement = "select distinct chrom from %s" % table
        sql = self.getSqlCursor()

        sql.execute(statement)
        results = []
        for row in sql:
            if fullChrom:
                if row["chrom"] not in results:
                    results.append(row["chrom"])
            else:
                if  len(row["chrom"][3:].strip()) < 1:
                    continue

                if row["chrom"][3:] not in results:
                    results.append(row["chrom"][3:])

        results.sort()

        return results


    def getMaxCoordinate(self, chrom, verbose=False, doUniqs=True,
                         doMulti=False, doSplices=False):
        """ returns the maximum coordinate for reads on a given chromosome.
        """
        maxCoord = 0
        sql = self.getSqlCursor()

        if doUniqs:
            try:
                sql.execute("select max(start) from uniqs where chrom = '%s'" % chrom)
                maxCoord = int(sql.fetchall()[0][0])
            except:
                print "couldn't retrieve coordMax for chromosome %s" % chrom

        if doSplices:
            sql.execute("select max(startR) from splices where chrom = '%s'" % chrom)
            try:
                spliceMax = int(sql.fetchall()[0][0])
                if spliceMax > maxCoord:
                    maxCoord = spliceMax
            except:
                pass

        if doMulti:
            sql.execute("select max(start) from multi where chrom = '%s'" % chrom)
            try:
                multiMax = int(sql.fetchall()[0][0])
                if multiMax > maxCoord:
                    maxCoord = multiMax
            except:
                pass

        if verbose:
            print "%s maxCoord: %d" % (chrom, maxCoord)

        return maxCoord


    def getReadsDict(self, bothEnds=False, noSense=False, fullChrom=False, chrom="",
                     flag="", withWeight=False, withFlag=False, withMismatch=False, withID=False,
                     withChrom=False, withPairID=False, doUniqs=True, doMulti=False, findallOptimize=False,
                     readIDDict=False, readLike="", start=-1, stop=-1, limit=-1, hasMismatch=False,
                     flagLike=False, strand='', combine5p=False):
        """ returns a dictionary of reads in a variety of formats
        and which can be restricted by chromosome or custom-flag.
        Returns unique reads by default, but can return multireads
        with doMulti set to True.

        Row filters (combined with 'and'):
            chrom        - exact chromosome; not applied when equal to self.memChrom
            flag         - exact flag, or substring match when flagLike is set
            start, stop  - keep rows with start > start and stop < stop
            readLike     - keep readIDs beginning with this text
            hasMismatch  - keep rows with a non-empty mismatch column
            strand       - '+' or '-'; any other value means no strand filter

        Table selection:
            doUniqs      - read the uniqs table (plus multi when doMulti is set);
                           when False the multi table is read regardless of doMulti

        Result layout:
            default      - {chrom: [entry, ...]} ordered by chrom, start; unless
                           fullChrom is set the first three characters of the
                           chromosome name are dropped from the key
            readIDDict   - {readID: [entry, ...]} ordered by readID, start; the
                           key is the readID up to '::' and, with withPairID,
                           up to '/'
            findallOptimize - {chrom: [{'start', 'sense', 'weight'}, ...]} keyed by
                           the chrom argument, with weight summed per start and
                           sense; the with* and bothEnds options are ignored

        Every default/readIDDict entry holds 'start'; bothEnds adds 'stop',
        noSense drops 'sense', and withWeight, withFlag, withMismatch,
        withID ('readID'), withChrom and withPairID add their column.

        limit adds a LIMIT clause unless combine5p is set. combine5p wraps
        the select in sub-queries that are grouped by chrom, start.

        Need to rethink original design 1: Cannot have pairID without exporting as a readIDDict
        """
        whereClause = []
        resultsDict = {}

        # the chromosome restriction is left out when chrom matches self.memChrom
        if chrom != "" and chrom != self.memChrom:
            whereClause.append("chrom = '%s'" % chrom)

        if flag != "":
            if flagLike:
                flagLikeClause = string.join(['flag LIKE "%', flag, '%"'], "")
                whereClause.append(flagLikeClause)
            else:
                whereClause.append("flag = '%s'" % flag)

        # both bounds are strict comparisons
        if start > -1:
            whereClause.append("start > %d" % start)

        if stop > -1:
            whereClause.append("stop < %d" % stop)

        if len(readLike) > 0:
            readIDClause = string.join(["readID LIKE  '", readLike, "%'"], "")
            whereClause.append(readIDClause)

        if hasMismatch:
            whereClause.append("mismatch != ''")

        if strand in ["+", "-"]:
            whereClause.append("sense = '%s'" % strand)

        if len(whereClause) > 0:
            whereStatement = string.join(whereClause, " and ")
            whereQuery = "where %s" % whereStatement
        else:
            whereQuery = ""

        # groupBy collects the trailing clauses of each select (GROUP BY and LIMIT)
        groupBy = []
        if findallOptimize:
            selectClause = ["select start, sense, sum(weight)"]
            groupBy = ["GROUP BY start, sense"]
        else:
            selectClause = ["select ID, chrom, start, readID"]
            if bothEnds:
                selectClause.append("stop")

            if not noSense:
                selectClause.append("sense")

            if withWeight:
                selectClause.append("weight")

            if withFlag:
                selectClause.append("flag")

            if withMismatch:
                selectClause.append("mismatch")

        if limit > 0 and not combine5p:
            groupBy.append("LIMIT %d" % limit)

        selectQuery = string.join(selectClause, ",")
        groupQuery = string.join(groupBy)
        if doUniqs:
            stmt = [selectQuery, "from uniqs", whereQuery, groupQuery]
            if doMulti:
                # NOTE(review): with limit > 0 the LIMIT of the first select ends up
                # in front of UNION ALL - confirm that sqlite accepts this statement
                stmt.append("UNION ALL")
                stmt.append(selectQuery)
                stmt.append("from multi")
                stmt.append(whereQuery)
                stmt.append(groupQuery)
        else:
            # NOTE(review): groupQuery (GROUP BY / LIMIT) is not appended on this
            # multi-only path, unlike the uniqs path above - confirm this is intended
            stmt = [selectQuery, "from multi", whereQuery]

        if combine5p:
            # the statement built above is discarded and replaced below
            if findallOptimize:
                selectQuery = "select start, sense, weight, chrom"

            if doUniqs:
                subSelect = [selectQuery, "from uniqs", whereQuery]
                if doMulti:
                    subSelect.append("union all")
                    subSelect.append(selectQuery)
                    subSelect.append("from multi")
                    subSelect.append(whereQuery)
            else:
                subSelect = [selectQuery, "from multi", whereQuery]

            sqlStmt = string.join(subSelect)
            if findallOptimize:
                selectQuery = "select start, sense, sum(weight)"

            # union of the (chrom, start) groups holding several rows with the
            # groups holding exactly one row
            stmt = [selectQuery, "from (", sqlStmt, ") group by chrom,start having ( count(start) > 1 and count(chrom) > 1) union",
                    selectQuery, "from(", sqlStmt, ") group by chrom, start having ( count(start) = 1 and count(chrom) = 1)"]

        if findallOptimize:
            # rows are read by position below, so the Row factory is switched off
            # for this query and restored once the rows have been consumed
            if self.memBacked:
                self.memcon.row_factory = None
                sql = self.memcon.cursor()
            else:
                self.dbcon.row_factory = None
                sql = self.dbcon.cursor()

            stmt.append("order by start")
        elif readIDDict:
            if self.memBacked:
                sql = self.memcon.cursor()
            else:
                sql = self.dbcon.cursor()

            stmt.append("order by readID, start")
        else:
            if self.memBacked:
                sql = self.memcon.cursor()
            else:
                sql = self.dbcon.cursor()

            stmt.append("order by chrom, start")

        sqlQuery = string.join(stmt)
        sql.execute(sqlQuery)

        if findallOptimize:
            # everything is stored under the chrom argument, which may be ""
            resultsDict[chrom] = [{"start": int(row[0]), "sense": row[1], "weight": float(row[2])} for row in sql]
            if self.memBacked:
                self.memcon.row_factory = sqlite.Row
            else:
                self.dbcon.row_factory = sqlite.Row
        else:
            currentChrom = ""
            currentReadID = ""
            pairID = 0
            for row in sql:
                readID = row["readID"]
                # the chrom argument is reused as the per-row chromosome name
                if fullChrom:
                    chrom = row["chrom"]
                else:
                    chrom = row["chrom"][3:]

                # a new result list is started whenever the key differs from the
                # previous row; this relies on the 'order by' appended above
                if not readIDDict and chrom != currentChrom:
                    resultsDict[chrom] = []
                    currentChrom = chrom
                    dictKey = chrom
                elif readIDDict:
                    theReadID = readID
                    if "::" in readID:
                        theReadID = readID.split("::")[0]

                    if "/" in theReadID and withPairID:
                        # NOTE(review): splits readID rather than theReadID, and the
                        # unpacking assumes exactly one "/" - confirm against the data
                        (theReadID, pairID) = readID.split("/")

                    if theReadID != currentReadID:
                        resultsDict[theReadID] = []
                        currentReadID = theReadID
                        dictKey = theReadID

                newrow = {"start": int(row["start"])}
                if bothEnds:
                    newrow["stop"] = int(row["stop"])

                if not noSense:
                    newrow["sense"] = row["sense"]

                if withWeight:
                    newrow["weight"] = float(row["weight"])

                if withFlag:
                    newrow["flag"] = row["flag"]

                if withMismatch:
                    newrow["mismatch"] = row["mismatch"]

                if withID:
                    newrow["readID"] = readID

                if withChrom:
                    newrow["chrom"] = chrom

                if withPairID:
                    # pairID is not reset per row: it keeps the value of the last
                    # readID that was split on "/"
                    newrow["pairID"] = pairID

                # NOTE(review): dictKey is unbound if the very first key equals the
                # initial "" (e.g. a chromosome named "chr" with fullChrom False)
                resultsDict[dictKey].append(newrow)

        return resultsDict


    def getSplicesDict(self, noSense=False, fullChrom=False, chrom="",
                       flag="", withWeight=False, withFlag=False, withMismatch=False,
                       withID=False, withChrom=False, withPairID=False, readIDDict=False,
                       splitRead=False, hasMismatch=False, flagLike=False, start=-1,
                       stop=-1, strand=""):
        """ returns a dictionary of spliced reads in a variety of
        formats and which can be restricted by chromosome or custom-flag.
        Returns unique spliced reads for now.
        """
        whereClause = []
        resultsDict = {}

        if chrom != "" and chrom != self.memChrom:
            whereClause = ["chrom = '%s'" % chrom]

        if flag != "":
            if flagLike:
                flagLikeClause = string.join(['flag LIKE "%', flag, '%"'], "")
                whereClause.append(flagLikeClause)
            else:
                whereClause.append("flag = '%s'" % flag)

        if hasMismatch:
            whereClause.append("mismatch != ''")

        if strand != "":
            whereClause.append("sense = '%s'" % strand)

        if start > -1:
            whereClause.append("startL > %d" % start)

        if stop > -1:
            whereClause.append("stopR < %d" % stop)

        if len(whereClause) > 0:
            whereStatement = string.join(whereClause, " and ")
            whereQuery = "where %s" % whereStatement
        else:
            whereQuery = ""

        selectClause = ["select ID, chrom, startL, stopL, startR, stopR, readID"]
        if not noSense:
            selectClause.append("sense")

        if withWeight:
            selectClause.append("weight")

        if withFlag:
            selectClause.append("flag")

        if withMismatch:
            selectClause.append("mismatch")

        selectQuery = string.join(selectClause, " ,")
        if self.memBacked:
            sql = self.memcon.cursor()
        else:
            sql = self.dbcon.cursor()

        stmt = "%s from splices %s order by chrom, startL" % (selectQuery, whereQuery)
        sql.execute(stmt)
        currentReadID = ""
        currentChrom = ""
        for row in sql:
            pairID = 0
            readID = row["readID"]
            if fullChrom:
                chrom = row["chrom"]
            else:
                chrom = row["chrom"][3:]

            if not readIDDict and chrom != currentChrom:
                resultsDict[chrom] = []
                currentChrom = chrom
                dictKey = chrom
            elif readIDDict:
                if "/" in readID:
                    (theReadID, pairID) = readID.split("/")
                else:
                    theReadID = readID

                if theReadID != currentReadID:
                    resultsDict[theReadID] = []
                    currentReadID = theReadID
                    dictKey = theReadID

            newrow = {"startL": int(row["startL"])}
            newrow["stopL"] = int(row["stopL"])
            newrow["startR"] = int(row["startR"])
            newrow["stopR"] = int(row["stopR"])
            if not noSense:
                newrow["sense"] = row["sense"]

            if withWeight:
                newrow["weight"] = float(row["weight"])

            if withFlag:
                newrow["flag"] = row["flag"]

            if withMismatch:
                newrow["mismatch"] = row["mismatch"]

            if withID:
                newrow["readID"] = readID

            if withChrom:
                newrow["chrom"] = chrom

            if withPairID:
                newrow["pairID"] = pairID

            if splitRead:
                leftDict = newrow.copy()
                del leftDict["startR"]
                del leftDict["stopR"]
                rightDict = newrow
                del rightDict["startL"]
                del rightDict["stopL"]
                resultsDict[dictKey].append(leftDict)
                resultsDict[dictKey].append(rightDict)
            else:
                resultsDict[dictKey].append(newrow)

        return resultsDict


    def getCounts(self, chrom="", rmin="", rmax="", uniqs=True, multi=False,
                  splices=False, reportCombined=True, sense="both"):
        """ return read counts for a given region.
        """
        ucount = 0
        mcount = 0
        scount = 0
        restrict = ""
        if sense in ["+", "-"]:
            restrict = " sense ='%s' " % sense

        if uniqs:
            try:
                ucount = float(self.getUniqsCount(chrom, rmin, rmax, restrict))
            except:
                ucount = 0

        if multi:
            try:
                mcount = float(self.getMultiCount(chrom, rmin, rmax, restrict))
            except:
                mcount = 0

        if splices:
            try:
                scount = float(self.getSplicesCount(chrom, rmin, rmax, restrict))
            except:
                scount = 0

        if reportCombined:
            total = ucount + mcount + scount
            return total
        else:
            return (ucount, mcount, scount)


    def getTotalCounts(self, chrom="", rmin="", rmax=""):
        """ return read counts for a given region.
        """
        return self.getCounts(chrom, rmin, rmax, uniqs=True, multi=True, splices=True, reportCombined=True, sense="both")


    def getTableEntryCount(self, table, chrom="", rmin="", rmax="", restrict="", distinct=False, startField="start"):
        """ returns the number of row in the uniqs table.
        """
        whereClause = []
        count = 0

        if chrom !=""  and chrom != self.memChrom:
            whereClause = ["chrom='%s'" % chrom]

        if rmin != "":
            whereClause.append("%s >= %s" % (startField, str(rmin)))

        if rmax != "":
            whereClause.append("%s <= %s" % (startField, str(rmax)))

        if restrict != "":
            whereClause.append(restrict)

        if len(whereClause) > 0:
            whereStatement = string.join(whereClause, " and ")
            whereQuery = "where %s" % whereStatement
        else:
            whereQuery = ""

        if self.memBacked:
            sql = self.memcon.cursor()
        else:
            sql = self.dbcon.cursor()

        if distinct:
            sql.execute("select count(distinct chrom+%s+sense) from %s %s" % (startField, table, whereQuery))
        else:
            sql.execute("select sum(weight) from %s %s" % (table, whereQuery))

        result = sql.fetchone()

        try:
            count = int(result[0])
        except:
            count = 0

        return count


    def getSplicesCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
        """ returns the number of row in the splices table.
        """
        return self.getTableEntryCount("splices", chrom, rmin, rmax, restrict, distinct, startField="startL")


    def getUniqsCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
        """ returns the number of distinct readIDs in the uniqs table.
        """
        return self.getTableEntryCount("uniqs", chrom, rmin, rmax, restrict, distinct)


    def getMultiCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
        """ returns the total weight of readIDs in the multi table.
        """
        return self.getTableEntryCount("multi", chrom, rmin, rmax, restrict, distinct)


    def getReadIDs(self, uniqs=True, multi=False, splices=False, paired=False, limit=-1):
        """ get readID's.
        """
        stmt = []
        limitPart = ""
        if limit > 0:
            limitPart = "LIMIT %d" % limit

        if uniqs:
            stmt.append("select readID from uniqs")

        if multi:
            stmt.append("select readID from multi")

        if splices:
            stmt.append("select readID from splices")

        if len(stmt) > 0:
            selectPart = string.join(stmt, " union ")
        else:
            selectPart = ""

        sqlQuery = "%s group by readID %s" % (selectPart, limitPart)
        if self.memBacked:
            sql = self.memcon.cursor()
        else:
            sql = self.dbcon.cursor()

        sql.execute(sqlQuery)
        result = sql.fetchall()

        if paired:
            return [x[0].split("/")[0] for x in result]
        else:
            return [x[0] for x in result]


    def getMismatches(self, mischrom=None, verbose=False, useSplices=True):
        """ returns the uniq and spliced mismatches in a dictionary.
        """
        readlen = self.getReadSize()
        if mischrom:
            hitChromList = [mischrom]
        else:
            hitChromList = self.getChromosomes()
            hitChromList.sort()

        snpDict = {}
        for achrom in hitChromList:
            if verbose:
                print "getting mismatches from chromosome %s" % (achrom)

            snpDict[achrom] = []
            hitDict = self.getReadsDict(fullChrom=True, chrom=achrom, withMismatch=True, hasMismatch=True)
            if useSplices and self.dataType == "RNA":
                spliceDict = self.getSplicesDict(fullChrom=True, chrom=achrom, withMismatch=True, readIDDict=True, hasMismatch=True)
                spliceIDList = spliceDict.keys()
                for k in spliceIDList:
                    spliceEntry = spliceDict[k][0]
                    startpos = spliceEntry["startL"]
                    lefthalf = spliceEntry["stopL"]
                    rightstart = spliceEntry["startR"]
                    sense = spliceEntry["sense"]
                    mismatches = spliceEntry["mismatch"]
                    spMismatchList = mismatches.split(",")
                    for mismatch in spMismatchList:
                        if "N" in mismatch:
                            continue

                        change_len = len(mismatch)
                        if sense == "+":
                            change_from = mismatch[0]
                            change_base = mismatch[change_len-1]
                            change_pos = int(mismatch[1:change_len-1])
                        elif sense == "-":
                            change_from = getReverseComplement([mismatch[0]])
                            change_base = getReverseComplement([mismatch[change_len-1]])
                            change_pos = readlen - int(mismatch[1:change_len-1]) + 1

                        firsthalf = int(lefthalf)-int(startpos)+1
                        secondhalf = 0
                        if int(change_pos) <= int(firsthalf):
                            change_at = startpos + change_pos - 1
                        else:
                            secondhalf = change_pos - firsthalf
                            change_at = rightstart + secondhalf

                        snpDict[achrom].append([startpos, change_at, change_base, change_from])

            if achrom not in hitDict.keys():
                continue

            for readEntry in hitDict[achrom]:
                start = readEntry["start"]
                sense = readEntry["sense"]
                mismatches = readEntry["mismatch"]
                mismatchList = mismatches.split(",")
                for mismatch in mismatchList:
                    if "N" in mismatch:
                        continue

                    change_len = len(mismatch)
                    if sense == "+":
                        change_from = mismatch[0]
                        change_base = mismatch[change_len-1]
                        change_pos = int(mismatch[1:change_len-1])
                    elif sense == "-":
                        change_from = getReverseComplement([mismatch[0]])
                        change_base = getReverseComplement([mismatch[change_len-1]])
                        change_pos = readlen - int(mismatch[1:change_len-1]) + 1

                    change_at = start + change_pos - 1
                    snpDict[achrom].append([start, change_at, change_base, change_from])

        return snpDict


    def getChromProfile(self, chromosome, cstart=-1, cstop=-1, useMulti=True,
                        useSplices=False, normalizationFactor = 1.0, trackStrand=False,
                        keepStrand="both", shiftValue=0):
        """return a profile of the chromosome as an array of per-base read coverage....
            keepStrand = 'both', 'plusOnly', or 'minusOnly'.
            Will also shift position of unique and multireads (but not splices) if shift is a natural number
        """
        metadata = self.getMetadata()
        try:
            readlen = int(metadata["readsize"])
        except KeyError:
            readlen = 0

        dataType = metadata["dataType"]
        scale = 1. / normalizationFactor
        shift = {}
        shift['+'] = int(shiftValue)
        shift['-'] = -1 * int(shiftValue)

        if cstop > 0:
            lastNT = self.getMaxCoordinate(chromosome, doMulti=useMulti, doSplices=useSplices) + readlen
        else:
            lastNT = cstop - cstart + readlen + shift["+"]

        chromModel = array("f",[0.] * lastNT)
        hitDict = self.getReadsDict(fullChrom=True, chrom=chromosome, withWeight=True, doMulti=useMulti, start=cstart, stop=cstop, findallOptimize=True)
        if cstart < 0:
            cstart = 0

        for readEntry in hitDict[chromosome]:
            hstart = readEntry["start"]
            sense =  readEntry ["sense"]
            weight = readEntry["weight"]
            hstart = hstart - cstart + shift[sense]
            for currentpos in range(hstart,hstart+readlen):
                try:
                    if not trackStrand or (sense == "+" and keepStrand != "minusOnly"):
                        chromModel[currentpos] += scale * weight
                    elif sense == "-" and keepStrand != "plusOnly":
                        chromModel[currentpos] -= scale * weight
                except:
                    continue

        del hitDict
        if useSplices and dataType == "RNA":
            if cstop > 0:
                spliceDict = self.getSplicesDict(fullChrom=True, chrom=chromosome, withID=True, start=cstart, stop=cstop)
            else:
                spliceDict = self.getSplicesDict(fullChrom=True, chrom=chromosome, withID=True)
   
            if chromosome in spliceDict:
                for spliceEntry in spliceDict[chromosome]:
                    Lstart = spliceEntry["startL"]
                    Lstop = spliceEntry["stopL"]
                    Rstart = spliceEntry["startR"]
                    Rstop = spliceEntry["stopR"]
                    rsense = spliceEntry["sense"]
                    if (Rstop - cstart) < lastNT:
                        for index in range(abs(Lstop - Lstart)):
                            currentpos = Lstart - cstart + index
                            # we only track unique splices
                            if not trackStrand or (rsense == "+" and keepStrand != "minusOnly"):
                                chromModel[currentpos] += scale
                            elif rsense == "-" and keepStrand != "plusOnly":
                                chromModel[currentpos] -= scale

                        for index in range(abs(Rstop - Rstart)):
                            currentpos = Rstart - cstart + index
                            # we only track unique splices
                            if not trackStrand or (rsense == "+" and keepStrand != "minusOnly"):
                                chromModel[currentpos] += scale
                            elif rsense == "-" and keepStrand != "plusOnly":
                                chromModel[currentpos] -= scale

            del spliceDict

        return chromModel


    def insertMetadata(self, valuesList):
        """ inserts a list of (pname, pvalue) into the metadata
        table.
        """
        self.dbcon.executemany("insert into metadata(name, value) values (?,?)", valuesList)
        self.dbcon.commit()


    def updateMetadata(self, pname, newValue, originalValue=""):
        """ update a metadata field given the original value and the new value.
        """
        stmt = "update metadata set value='%s' where name='%s'" % (str(newValue), pname)
        if originalValue != "":
            stmt += " and value='%s' " % str(originalValue)

        self.dbcon.execute(stmt)
        self.dbcon.commit()


    def insertUniqs(self, valuesList):
        """ inserts a list of (readID, chrom, start, stop, sense, weight, flag, mismatch)
        into the uniqs table.
        """
        self.dbcon.executemany("insert into uniqs(ID, readID, chrom, start, stop, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)", valuesList)
        self.dbcon.commit()


    def insertMulti(self, valuesList):
        """ inserts a list of (readID, chrom, start, stop, sense, weight, flag, mismatch)
        into the multi table.
        """
        self.dbcon.executemany("insert into multi(ID, readID, chrom, start, stop, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)", valuesList)
        self.dbcon.commit()


    def insertSplices(self, valuesList):
        """ inserts a list of (readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch)
        into the splices table.
        """
        self.dbcon.executemany("insert into splices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?,?,?)", valuesList)
        self.dbcon.commit()


    def flagReads(self, regionsList, uniqs=True, multi=False, splices=False, sense="both"):
        """ update reads on file database in a list region of regions for a chromosome to have a new flag.
            regionsList must have 4 fields per region of the form (flag, chrom, start, stop) or, with
            sense set to '+' or '-', 5 fields per region of the form (flag, chrom, start, stop, sense).
        """
        restrict = ""
        if sense != "both":
            restrict = " and sense = ? "

        if uniqs:
            self.dbcon.executemany("UPDATE uniqs SET flag = ? where chrom = ? and start >= ? and start < ? " + restrict, regionsList)

        if multi:
            self.dbcon.executemany("UPDATE multi SET flag = ? where chrom = ? and start >= ? and start < ? " + restrict, regionsList)

        if self.dataType == "RNA" and splices:
            self.dbcon.executemany("UPDATE splices SET flag = flag || ' L:' || ? where chrom = ? and startL >= ? and startL < ? " + restrict, regionsList)
            self.dbcon.executemany("UPDATE splices SET flag = flag || ' R:' || ? where chrom = ? and startR >= ? and startR < ? " + restrict, regionsList)

        self.dbcon.commit()


    def setFlags(self, flag, uniqs=True, multi=True, splices=True):
        """ set the flag fields in the entire dataset.
        """
        if uniqs:
            self.dbcon.execute("UPDATE uniqs SET flag = '%s'" % flag)

        if multi:
            self.dbcon.execute("UPDATE multi SET flag = '%s'" % flag)

        if self.dataType == "RNA" and splices:
            self.dbcon.execute("UPDATE splices SET flag = '%s'" % flag)

        self.dbcon.commit()


    def resetFlags(self, uniqs=True, multi=True, splices=True):
        """ reset the flag fields in the entire dataset to clear. Useful for rerunning an analysis from scratch.
        """
        self.setFlags("", uniqs, multi, splices)


    def reweighMultireads(self, readList):
        self.dbcon.executemany("UPDATE multi SET weight = ? where chrom = ? and start = ? and readID = ? ", readList)


    def setSynchronousPragma(self, value="ON"):
        try:
            self.dbcon.execute("PRAGMA SYNCHRONOUS = %s" % value)
        except:
            print "warning: couldn't set PRAGMA SYNCHRONOUS = %s" % value


    def setDBcache(self, cache, default=False):
        self.dbcon.execute("PRAGMA CACHE_SIZE = %d" % cache)
        if default:
            self.dbcon.execute("PRAGMA DEFAULT_CACHE_SIZE = %d" % cache)


    def execute(self, statement, returnResults=False):
        sql = self.getSqlCursor()

        sql.execute(statement)
        if returnResults:
            result = sql.fetchall()
            return result


    def executeCommit(self, statement):
        self.execute(statement)

        if self.memBacked:
            self.memcon.commit()
        else:
            self.dbcon.commit()


    def buildIndex(self, cache=100000):
        """ Builds the file indeces for the main tables.
            Cache is the number of 1.5 kb pages to keep in memory.
            100000 pages translates into 150MB of RAM, which is our default.
        """
        if cache > self.getDefaultCacheSize():
            self.setDBcache(cache)
        self.setSynchronousPragma("OFF")
        self.dbcon.execute("CREATE INDEX uPosIndex on uniqs(chrom, start)")
        print "built uPosIndex"
        self.dbcon.execute("CREATE INDEX uChromIndex on uniqs(chrom)")
        print "built uChromIndex"
        self.dbcon.execute("CREATE INDEX mPosIndex on multi(chrom, start)")
        print "built mPosIndex"
        self.dbcon.execute("CREATE INDEX mChromIndex on multi(chrom)")
        print "built mChromIndex"

        if self.dataType == "RNA":
            self.dbcon.execute("CREATE INDEX sPosIndex on splices(chrom, startL)")
            print "built sPosIndex"
            self.dbcon.execute("CREATE INDEX sPosIndex2 on splices(chrom, startR)")
            print "built sPosIndex2"
            self.dbcon.execute("CREATE INDEX sChromIndex on splices(chrom)")
            print "built sChromIndex"

        self.dbcon.commit()
        self.setSynchronousPragma("ON")


    def dropIndex(self):
        """ drops the file indices for the main tables.
        """
        try:
            self.setSynchronousPragma("OFF")
            self.dbcon.execute("DROP INDEX uPosIndex")
            self.dbcon.execute("DROP INDEX uChromIndex")
            self.dbcon.execute("DROP INDEX mPosIndex")
            self.dbcon.execute("DROP INDEX mChromIndex")

            if self.dataType == "RNA":
                self.dbcon.execute("DROP INDEX sPosIndex")
                try:
                    self.dbcon.execute("DROP INDEX sPosIndex2")
                except:
                    pass

                self.dbcon.execute("DROP INDEX sChromIndex")

            self.dbcon.commit()
        except:
            print "problem dropping index"

        self.setSynchronousPragma("ON")


    def memSync(self, chrom="", index=False):
        """ makes a copy of the dataset into memory for faster access.
        Can be restricted to a "full" chromosome. Can also build the
        memory indices.
        """
        self.memcon = ""
        self.memcon = sqlite.connect(":memory:")
        self.initializeTables(self.memcon)
        cursor = self.dbcon.cursor()
        whereclause = ""
        if chrom != "":
            print "memSync %s" % chrom
            whereclause = " where chrom = '%s' " % chrom
            self.memChrom = chrom
        else:
            self.memChrom = ""

        self.memcon.execute("PRAGMA temp_store = MEMORY")
        self.memcon.execute("PRAGMA CACHE_SIZE = 1000000")
        # copy metadata to memory
        self.memcon.execute("delete from metadata")
        results = cursor.execute("select name, value from metadata")
        results2 = []
        for row in results:
            results2.append((row["name"], row["value"]))

        self.memcon.executemany("insert into metadata(name, value) values (?,?)", results2)

        self.copyDBEntriesToMemory("uniqs", whereclause)
        self.copyDBEntriesToMemory("multi", whereclause)
        if self.dataType == "RNA":
            self.copySpliceDBEntriesToMemory(whereclause)

        if index:
            if chrom != "":
                self.memcon.execute("CREATE INDEX uPosIndex on uniqs(start)")
                self.memcon.execute("CREATE INDEX mPosIndex on multi(start)")
                if self.dataType == "RNA":
                    self.memcon.execute("CREATE INDEX sPosLIndex on splices(startL)")
                    self.memcon.execute("CREATE INDEX sPosRIndex on splices(startR)")
            else:
                self.memcon.execute("CREATE INDEX uPosIndex on uniqs(chrom, start)")
                self.memcon.execute("CREATE INDEX mPosIndex on multi(chrom, start)")
                if self.dataType == "RNA":
                    self.memcon.execute("CREATE INDEX sPosLIndex on splices(chrom, startL)")
                    self.memcon.execute("CREATE INDEX sPosRIndex on splices(chrom, startR)")

        self.memBacked = True
        self.memcon.row_factory = sqlite.Row
        self.memcon.commit()


    def copyDBEntriesToMemory(self, dbName, whereClause=""):
        cursor = self.dbcon.cursor()
        sourceEntries = cursor.execute("select chrom, start, stop, sense, weight, flag, mismatch, readID from %s %s" % (dbName, whereClause))
        destinationEntries = []
        for row in sourceEntries:
            destinationEntries.append((row["readID"], row["chrom"], int(row["start"]), int(row["stop"]), row["sense"], row["weight"], row["flag"], row["mismatch"]))

        self.memcon.executemany("insert into %s(ID, readID, chrom, start, stop, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)" % dbName, destinationEntries)


    def copySpliceDBEntriesToMemory(self, whereClause=""):
        cursor = self.dbcon.cursor()
        sourceEntries = cursor.execute("select chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch, readID from splices %s" % whereClause)
        destinationEntries = []
        for row in sourceEntries:
            destinationEntries.append((row["readID"], row["chrom"], int(row["startL"]), int(row["stopL"]), int(row["startR"]), int(row["stopR"]), row["sense"],
                                       row["weight"], row["flag"], row["mismatch"]))

        self.memcon.executemany("insert into splices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)", destinationEntries)

