###########################################################################
#                                                                         #
# C O P Y R I G H T   N O T I C E                                         #
#  Copyright (c) 2003-10 by:                                              #
#    * California Institute of Technology                                 #
#                                                                         #
#    All Rights Reserved.                                                 #
#                                                                         #
# Permission is hereby granted, free of charge, to any person             #
# obtaining a copy of this software and associated documentation files    #
# (the "Software"), to deal in the Software without restriction,          #
# including without limitation the rights to use, copy, modify, merge,    #
# publish, distribute, sublicense, and/or sell copies of the Software,    #
# and to permit persons to whom the Software is furnished to do so,       #
# subject to the following conditions:                                    #
#                                                                         #
# The above copyright notice and this permission notice shall be          #
# included in all copies or substantial portions of the Software.         #
#                                                                         #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,         #
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF      #
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND                   #
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS     #
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN      #
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN       #
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE        #
# SOFTWARE.                                                               #
###########################################################################
#
try:
    from pysqlite2 import dbapi2 as sqlite
except:
    from sqlite3 import dbapi2 as sqlite

import tempfile, shutil, os
from os import environ

# Installation root for Cistematic data. A non-empty $CISTEMATIC_ROOT
# overrides the built-in default.
cisRoot = environ.get("CISTEMATIC_ROOT") or "/proj/genome"

# SQLite database file that geneinfoDB queries by default.
dbPath = "%s/db/gene_info.db" % cisRoot

# Scratch directory. A non-empty $CISTEMATIC_TEMP overrides the default.
cisTemp = environ.get("CISTEMATIC_TEMP") or "/tmp"

# NOTE: assigning tempfile.tempdir changes the default temporary directory
# for every tempfile user in this process, not only for this module.
tempfile.tempdir = cisTemp

# Maps the species key found in the first column of the input data file
# (these appear to be NCBI taxonomy IDs) to the genome name stored in the
# database.
speciesMap = {
    "3702": "athaliana",
    "4932": "scerevisiae",
    "6239": "celegans",
    "7227": "dmelanogaster",
    "7668": "spurpuratus",
    "7955": "drerio",
    "8364": "xtropicalis",
    "9031": "ggallus",
    "9606": "hsapiens",
    "9615": "cfamiliaris",
    "9796": "ecaballus",
    "9913": "btaurus",
    "10090": "mmusculus",
    "10116": "rnorvegicus",
    "13616": "mdomestica",
}


class geneinfoDB:
    """ The geneinfoDB class allows for the querying of NCBI gene data.

        Queries run against the SQLite database at the module-level dbPath,
        or against a private temporary copy of it when caching is enabled
        (see cacheDB). Gene identifiers are (genome, gID) tuples, where genome
        is one of the genome names used as values in speciesMap.
    """
    startingGenome = ""
    targetGenomes  = []
    cachedDB = ""


    def __init__(self,  tGenomes=None, cache=False):
        """ initialize the geneinfoDB object with a list of target genomes and
            copy the database to a local cache, if desired.

            tGenomes defaults to a new empty list for each instance.
        """
        # A shared default list ([]) would be aliased by every instance.
        if tGenomes is None:
            tGenomes = []

        self.targetGenomes = tGenomes
        if cache:
            self.cacheDB()


    def __del__(self):
        """ cleanup copy in local cache, if present.
        """
        if self.cachedDB != "":
            self.uncacheDB()


    def cacheDB(self):
        """ copy geneinfoDB to a local cache.

            The copy is written to a uniquely named ".db" file in the tempfile
            directory. The instance only switches to the copy once it has been
            written successfully; on failure the partial file is removed and
            the error is re-raised.
        """
        # mkstemp creates the file atomically, unlike the race-prone mktemp().
        (fd, cachePath) = tempfile.mkstemp(suffix=".db")
        os.close(fd)
        try:
            shutil.copyfile(dbPath, cachePath)
        except Exception:
            os.remove(cachePath)
            raise

        self.cachedDB = cachePath


    def uncacheDB(self):
        """ delete geneinfoDB from local cache.

            A failure to delete the file is reported, not raised; either way
            the instance goes back to querying dbPath afterwards.
        """
        if self.cachedDB != "":
            try:
                os.remove(self.cachedDB)
            except OSError:
                print("could not delete %s" % self.cachedDB)

            self.cachedDB = ""


    def connectDB(self):
        """ return a handle to the database: the cached copy if one exists,
            otherwise dbPath.
        """
        path = self.cachedDB or dbPath
        return sqlite.connect(path, timeout=60)


    def _fetch(self, sql, params, fetchAll=True):
        """ run one query and return fetchall() rows (or the fetchone() row
            when fetchAll is False), always closing cursor and connection.
        """
        db = self.connectDB()
        try:
            cursor = db.cursor()
            try:
                cursor.execute(sql, params)
                if fetchAll:
                    return cursor.fetchall()

                return cursor.fetchone()
            finally:
                cursor.close()
        finally:
            db.close()


    def getGeneInfo(self, geneID):
        """ returns (symbol, locustag, dbxrefs, chromosome, map_location) as
            strings for geneID, a (genome, gID) tuple.

            Only the first matching row is used. Returns an empty list if the
            gene is not found.
        """
        (gen, gid) = geneID
        entry = self._fetch("select symbol, locustag, dbxrefs, chromosome, map_location from gene_info where genome = :gen and gID = :gid",
                            {"gen": gen, "gid": gid}, fetchAll=False)
        # fetchone() returns None when nothing matches; check explicitly
        # instead of letting tuple-unpacking raise TypeError.
        if entry is None:
            return []

        return tuple([str(value) for value in entry])


    def getallGeneInfo(self, genome, infoKey="gid"):
        """ returns a dictionary of one or more (symbol, locustag, dbxrefs, chromosome, map_location) per gID.
            acceptable infoKey arguments are: 'locus', and 'gid'.

            With infoKey 'locus' the dictionary is keyed by locustag and each
            tuple holds (symbol, gID, dbxrefs, chromosome, map_location)
            instead. Any other infoKey value behaves like 'gid'.
        """
        results = self._fetch("select gid, symbol, locustag, dbxrefs, chromosome, map_location from gene_info where genome = :genome",
                              {"genome": genome})
        resDict = {}
        for (gid, symbol, locustag, dbxrefs, chromosome, map_location) in results:
            if infoKey == "locus":
                key = str(locustag)
                info = (str(symbol), str(gid), str(dbxrefs), str(chromosome), str(map_location))
            else:
                key = str(gid)
                info = (str(symbol), str(locustag), str(dbxrefs), str(chromosome), str(map_location))

            resDict.setdefault(key, []).append(info)

        return resDict


    def getDescription(self, geneID):
        """ returns a list of zero or more gene descriptions for geneID, a
            (genome, gID) tuple.
        """
        (gen, gid) = geneID
        entries = self._fetch("select description from gene_description where genome = :gen and gID = :gid",
                              {"gen": gen, "gid": gid})
        return [str(entry[0]) for entry in entries]


    def geneIDSynonyms(self, geneID):
        """ returns a list of synonyms for geneID, a (genome, gID) tuple.
        """
        (gen, gid) = geneID
        entries = self._fetch("select synonym from gene_synonyms where genome = :gen and gID = :gid",
                              {"gen": gen, "gid": gid})
        return [str(entry[0]) for entry in entries]


    def getGeneID(self, genome, synonym):
        """ returns a (genome, gID) geneID given a genome and a synonym.

            The synonym is matched, in order, against gene_info.symbol,
            gene_synonyms.synonym and gene_info.locustag; the first hit wins.
            Returns an empty list if nothing matches.
        """
        queries = ("select gID from gene_info where genome = :genome and symbol = :synonym",
                   "select gID from gene_synonyms where genome = :genome and synonym = :synonym",
                   "select gID from gene_info where genome = :genome and locustag = :synonym")
        params = {"genome": genome, "synonym": synonym}
        db = self.connectDB()
        try:
            cursor = db.cursor()
            try:
                for sql in queries:
                    cursor.execute(sql, params)
                    entry = cursor.fetchone()
                    if entry:
                        return (genome, str(entry[0]))
            finally:
                cursor.close()
        finally:
            db.close()

        return []


def buildgeneinfoDB(datafile, path=dbPath):
    """ populate geneinfo database from NCBI gene information.

        datafile is a tab-delimited text file. Only rows whose first column
        is a key of speciesMap are loaded; columns 1-8 supply gID, symbol,
        locustag, synonyms ("|"-separated, "-" for none), dbxrefs,
        chromosome, map_location and description. Rows with fewer than nine
        columns are reported and skipped.

        Creates the gene_info, gene_description and gene_synonyms tables and
        their indexes in the SQLite database at path (default: dbPath). The
        tables must not already exist.
    """
    inFile = open(datafile, "r")
    db = sqlite.connect(path, timeout=60)
    try:
        cursor = db.cursor()
        cursor.execute("create table gene_info(ID INTEGER PRIMARY KEY, genome varchar, gID varchar, symbol varchar, locustag varchar, dbxrefs varchar, chromosome varchar, map_location varchar)")
        cursor.execute("create table gene_description(ID INTEGER PRIMARY KEY, genome varchar, gID varchar, description varchar)")
        cursor.execute("create table gene_synonyms(ID INTEGER PRIMARY KEY, genome varchar, gID varchar, synonym varchar)")

        for line in inFile:
            # Kept for compatibility with databases built by earlier versions,
            # which stored "'" as "prime"; queries below are parameterized,
            # so the substitution is no longer needed for SQL safety.
            line = line.replace("'", "prime")
            field = line.split("\t")
            if field[0] not in speciesMap:
                continue

            # A short row would otherwise raise IndexError and abort the
            # whole (uncommitted) build.
            if len(field) < 9:
                print("could not register %s" % line)
                continue

            genome = speciesMap[field[0]]
            gID = field[1]
            try:
                cursor.execute("INSERT into gene_info(ID, genome, gID, symbol, locustag, dbxrefs, chromosome, map_location) values (NULL, ?, ?, ?, ?, ?, ?, ?)",
                               (genome, gID, field[2], field[3], field[5], field[6], field[7]))
                descr = field[8].strip()
                if len(descr) > 1:
                    cursor.execute("INSERT into gene_description(ID, genome, gID, description) values (NULL, ?, ?, ?)",
                                   (genome, gID, descr))

                # The gene ID itself is always registered as a synonym.
                cursor.execute("INSERT into gene_synonyms(ID, genome, gID, synonym) values (NULL, ?, ?, ?)",
                               (genome, gID, gID.strip()))
                for entry in field[4].split("|"):
                    if entry != "-" and entry != gID.strip():
                        try:
                            cursor.execute("INSERT into gene_synonyms(ID, genome, gID, synonym) values (NULL, ?, ?, ?)",
                                           (genome, gID, entry.strip()))
                        except sqlite.OperationalError:
                            # Best-effort per synonym, but no longer silent.
                            print("could not register synonym %s for %s" % (entry.strip(), gID))
            except sqlite.OperationalError:
                print("could not register %s" % line)

        cursor.execute("create index genIdx1 on gene_info(genome)")
        cursor.execute("create index genIdx2 on gene_description(genome)")
        cursor.execute("create index genIdx3 on gene_synonyms(genome)")
        cursor.execute("create index gIDIdx1 on gene_info(gID)")
        cursor.execute("create index gIDIdx2 on gene_description(gID)")
        cursor.execute("create index gIDIdx3 on gene_synonyms(gID)")
        cursor.execute("create index synIdx on gene_synonyms(synonym)")
        db.commit()
        cursor.close()
    finally:
        inFile.close()
        db.close()