###########################################################################
#                                                                         #
# C O P Y R I G H T   N O T I C E                                         #
#  Copyright (c) 2003-10 by:                                              #
#    * California Institute of Technology                                 #
#                                                                         #
#    All Rights Reserved.                                                 #
#                                                                         #
# Permission is hereby granted, free of charge, to any person             #
# obtaining a copy of this software and associated documentation files    #
# (the "Software"), to deal in the Software without restriction,          #
# including without limitation the rights to use, copy, modify, merge,    #
# publish, distribute, sublicense, and/or sell copies of the Software,    #
# and to permit persons to whom the Software is furnished to do so,       #
# subject to the following conditions:                                    #
#                                                                         #
# The above copyright notice and this permission notice shall be          #
# included in all copies or substantial portions of the Software.         #
#                                                                         #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,         #
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF      #
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND                   #
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS     #
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN      #
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN       #
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE        #
# SOFTWARE.                                                               #
###########################################################################
#
# Names exported by a wildcard import of this package (presumably its
# submodules -- confirm against the package layout).
__all__ = ["motif", "homology", "geneinfo", "protein"]

import cistematic
from cistematic.genomes import Genome, geneDB
import shutil, tempfile, os
from os import environ

# Scratch directory for temporary files, including cached gene databases.
# A non-empty CISTEMATIC_TEMP environment variable overrides the /tmp default.
cisTemp = environ.get("CISTEMATIC_TEMP") or "/tmp"
tempfile.tempdir = cisTemp

# Per-genome tables filled lazily by loadGOInfo() and loadAnnotInfo().
goDict = {}
annotDict = {}

# genome name -> path of a local copy of that genome's gene database.
cache = {}


def cacheGeneDB(genome):
    """ save a copy of a genome's gene database to a local cache.

        The copy goes into a new ".db" file in the cistematic temp directory.
        Its path is recorded in the module-level cache dict, so later calls
        for the same genome reuse it.

        Returns the path of the cached copy. Returns "" if the copy could not
        be made; "" is also the value the other helpers here use to mean
        "no explicit database".
    """
    if genome in cache:
        return cache[genome]

    tempgen = ""
    try:
        # mkstemp creates the file atomically, unlike the race-prone mktemp().
        # It honours tempfile.tempdir, so the copy lands in cisTemp.
        (fd, tempgen) = tempfile.mkstemp(suffix=".db")
        os.close(fd)
        shutil.copyfile(geneDB[genome], tempgen)
    except (KeyError, EnvironmentError):
        print("could not cache genome %s" % genome)
        if tempgen:
            # Don't leave an empty or partial copy behind.
            try:
                os.remove(tempgen)
            except OSError:
                pass

        return ""

    cache[genome] = tempgen
    return tempgen


def uncacheGeneDB(genome=""):
    """ remove the local copy of a genome's gene database.

        genome -- the genome whose cached copy should be deleted. With the
                  default "", every cached copy is deleted and the cache is
                  emptied. Naming a genome that has no cached copy does
                  nothing; it no longer wipes the whole cache.
    """
    if genome in cache:
        targets = [genome]
    elif genome == "":
        targets = list(cache)
    else:
        return

    for gen in targets:
        try:
            os.remove(cache[gen])
        except OSError:
            print("could not delete %s" % cache[gen])

        # Mutate in place rather than rebinding the global, so any other
        # reference to the cache dict stays in sync.
        del cache[gen]


def cachedGenomes():
    """ return the genomes whose gene database currently has a local cached copy.
    """
    genomeNames = cache.keys()
    return genomeNames


def chooseDB(genome, dbfile=""):
    """ pick the gene database path to use for genome.

        An explicitly supplied dbfile always wins. Otherwise, the locally
        cached copy is used if there is one; failing that, "" is returned.
    """
    if dbfile != "":
        return dbfile

    return cache.get(genome, dbfile)


def readChromosome(genome, chrom, db=""):
    """ return the complete sequence of chromosome chrom in genome.

        db optionally names the gene database; chooseDB() falls back to the
        cached copy when it is "".
    """
    dbPath = chooseDB(genome, db)
    chromGenome = Genome(genome, chrom, dbFile=dbPath)
    return chromGenome.getChromosomeSequence()


def getGenomeEntries(genome, db=""):
    """ return the entries for a given genome.

        Returns the pair (genome, aGenome.allGIDs()). db optionally names the
        gene database; when it is "", chooseDB() substitutes the locally
        cached copy if one exists.
    """
    # chooseDB() already performs the cache lookup, so no separate
    # `genome in cache` check is needed here.
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    return (genome, aGenome.allGIDs())


def getGenomeGeneIDs(genome, db=""):
    """ return the gene IDs (aGenome.allGIDs()) for a given genome.

        Unlike getGenomeEntries(), the genome name is not included in the
        result. db optionally names the gene database; when it is "",
        chooseDB() substitutes the locally cached copy if one exists.
    """
    # chooseDB() already performs the cache lookup.
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    return aGenome.allGIDs()


def getChromoGeneEntries(chromosome, lowerbound=-1, upperbound=-1, db=""):
    """ return the gene entries on one chromosome.

        chromosome is a (genome, chromID) pair. lowerbound and upperbound
        (default -1) are forwarded unchanged to Genome.chromGeneEntries.
    """
    genome, chromID = chromosome
    dbPath = chooseDB(genome, db)
    chromGenome = Genome(genome, chromID, dbFile=dbPath)
    return chromGenome.chromGeneEntries(chromID, lowerbound, upperbound)


def getChromosomeNames(genome, db="", partition=1, slice=0):
    """ return the chromosome names of a genome.

        partition and slice are handed unchanged to Genome.allChromNames.
    """
    dbPath = chooseDB(genome, db)
    genomeObj = Genome(genome, dbFile=dbPath)
    return genomeObj.allChromNames(partition, slice)


def geneEntry(geneID, db="", version="1"):
    """ look up where a gene lies.

        geneID is a (genome, locus) pair. Returns Genome.geneInfo's result,
        documented as (chrom, start, stop, length, sense).
    """
    genomeName = geneID[0]
    genomeObj = Genome(genomeName, dbFile=chooseDB(genomeName, db))
    return genomeObj.geneInfo(geneID, version)


# Base -> complementary base for the IUPAC nucleotide codes, in both upper
# case and lower case. Lower case is the soft-masked form, so its case is
# preserved. The "z" placeholder from the original table is kept. The table
# is built once at import time because compNT is called once per base.
_NT_COMPLEMENT = {
    "A": "T", "T": "A", "G": "C", "C": "G",
    "S": "S", "W": "W", "R": "Y", "Y": "R",
    "M": "K", "K": "M", "H": "D", "D": "H",
    "B": "V", "V": "B", "N": "N",
    "a": "t", "t": "a", "g": "c", "c": "g",
    "s": "s", "w": "w", "r": "y", "y": "r",
    "m": "k", "k": "m", "h": "d", "d": "h",
    "b": "v", "v": "b", "n": "n",
    "z": "z",
}


def compNT(nt):
    """ returns the complementary basepair to base nt.

        Upper- and lower-case IUPAC codes keep their case. Any character
        not in the table (including "") maps to "N".
    """
    return _NT_COMPLEMENT.get(nt, "N")


def complement(sequence, length=-1):
    """ returns the reverse complement of the sequence.

        If length is negative (the default), the whole sequence is
        reverse-complemented. Otherwise, only the last `length` bases are
        used. If length exceeds len(sequence), the missing positions come out
        as "N" at the end of the result; they no longer wrap around to the
        other end of the sequence.
    """
    seqLength = len(sequence)
    if length < 0:
        length = seqLength

    # Clamp the slice start: a negative index would silently wrap around to
    # the end of the sequence instead of meaning "before the first base".
    tail = sequence[max(seqLength - length, 0):]
    padding = "N" * max(length - seqLength, 0)
    return "".join([compNT(nt) for nt in reversed(tail)]) + padding


def upstreamToNextGene(geneID, radius, version="1", db=""):
    """ return distance to gene immediately upstream.

        Upstream is strand-aware: Genome.leftGeneDistance is used for forward
        ("F") genes and Genome.rightGeneDistance otherwise. radius is passed
        to those helpers. It is also the return value when geneID is not in
        the genome or the lookup raises.
    """
    upstream = radius
    genome = geneID[0]
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    try:
        if aGenome.checkGene(geneID):
            (chrom, start, stop, length, sense) = aGenome.geneInfo(geneID, version)
            if sense == "F":
                upstream = aGenome.leftGeneDistance(geneID, upstream, version)
            else:
                upstream = aGenome.rightGeneDistance(geneID, upstream, version)
    except Exception:
        # Best effort: fall back to radius on lookup errors, but let
        # KeyboardInterrupt/SystemExit propagate (a bare except ate them).
        upstream = radius

    return upstream


def downstreamToNextGene(geneID, radius, version="1", db=""):
    """ return distance to gene immediately downstream.

        Downstream is strand-aware: Genome.rightGeneDistance is used for
        forward ("F") genes and Genome.leftGeneDistance otherwise. radius is
        passed to those helpers. It is also the return value when geneID is
        not in the genome or the lookup raises.
    """
    downstream = radius
    genome = geneID[0]
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))

    try:
        if aGenome.checkGene(geneID):
            (chrom, start, stop, length, sense) = aGenome.geneInfo(geneID, version)
            if sense == "F":
                downstream = aGenome.rightGeneDistance(geneID, downstream, version)
            else:
                downstream = aGenome.leftGeneDistance(geneID, downstream, version)
    except Exception:
        # Best effort: fall back to radius on lookup errors, but let
        # KeyboardInterrupt/SystemExit propagate (a bare except ate them).
        downstream = radius

    return downstream


def retrieveFeatures(match, radius, featureType="", db=""):
    """ return the features around a given match.

        match is ((genome, chromID), hit), where hit[0] is the match
        position. Features intersecting [hit - radius, hit + radius] are
        returned, with the lower edge clipped at 0.
    """
    (chromosome, hit) = match
    (genome, chromID) = chromosome
    center = int(hit[0])
    radius = int(radius)
    lowerboundHit = max(center - radius, 0)
    # Measure the window from the (possibly clipped) lower bound. This way,
    # clipping at 0 does not push the upper edge past hit + radius.
    windowLength = center + radius - lowerboundHit

    aGenome = Genome(genome, chromID, dbFile=chooseDB(genome, db))
    results = aGenome.getFeaturesIntersecting(chromID, lowerboundHit, windowLength, featureType)

    return results


def _localFeatureCoords(allresults, seqstart, seqlen, reverse):
    """ convert feature rows into unique (ftype, localStart, localStop, orientation) tuples.

        Each row is (name, version, chromosome, start, stop, orientation, type).
        Forward regions are measured rightwards from seqstart. Reverse regions
        are measured leftwards from seqstart, their right edge. localStart is
        clipped below at 0 and localStop above at seqlen. Duplicates are
        dropped; first-seen order is kept.
    """
    results = []
    for entry in allresults:
        (fname, fversion, fchromosome, fstart, fstop, forientation, ftype) = entry
        if fstop < fstart:
            (fstart, fstop) = (fstop, fstart)

        if reverse:
            (localStart, localStop) = (seqstart - fstop, seqstart - fstart)
        else:
            (localStart, localStop) = (fstart - seqstart, fstop - seqstart)

        feature = (ftype, max(localStart, 0), min(localStop, seqlen), forientation)
        if feature not in results:
            results.append(feature)

    return results


def retrieveSeqFeatures(geneID, upstream, cds, downstream, boundToNextGene = False, geneDB=""):
    """ retrieve CDS features upstream, all or none of the cds, and downstream of a geneID.
        Feature positions are normalized and truncated to local sequence coordinates.

        Returns a list of unique (ftype, localStart, localStop, orientation)
        tuples. On the reverse strand, local coordinates are counted leftwards
        from the region's right edge. Returns [] if geneID is not in the
        genome.
    """
    (genome, gID) = geneID
    aGenome = Genome(genome, dbFile=chooseDB(genome, geneDB))
    if not aGenome.checkGene(geneID):
        return []

    (chrom, start, stop, length, sense) = aGenome.geneInfo(geneID)
    if stop < start:
        (start, stop) = (stop, start)

    seqstart = 0
    seqlen = 0
    if sense == "F":
        # The region grows rightwards from seqstart: upstream, CDS, downstream.
        if upstream > 0:
            if boundToNextGene:
                upstream = aGenome.leftGeneDistance(geneID, upstream)

            seqstart = start - upstream
            if seqstart < 0:
                seqstart = 0
                upstream = start

            seqlen = upstream

        if cds > 0:
            if seqlen == 0:
                seqstart = start

            seqlen += length

        if downstream > 0:
            if boundToNextGene:
                downstream = aGenome.rightGeneDistance(geneID, downstream)

            if seqlen == 0:
                seqstart = stop

            seqlen += downstream

        allresults = aGenome.getFeaturesIntersecting(chrom, seqstart, seqlen, "CDS")
        return _localFeatureCoords(allresults, seqstart, seqlen, False)

    # Reverse strand: seqstart is the region's right edge; it grows leftwards.
    if upstream > 0:
        if boundToNextGene:
            upstream = aGenome.rightGeneDistance(geneID, upstream)

        seqstart = stop + upstream
        seqlen = upstream

    if cds > 0:
        if seqlen == 0:
            seqstart = stop

        seqlen += length

    if downstream > 0:
        if boundToNextGene:
            downstream = aGenome.leftGeneDistance(geneID, downstream)

        if seqlen == 0:
            seqstart = start

        seqlen += downstream

    allresults = aGenome.getFeaturesIntersecting(chrom, seqstart - seqlen, seqlen, "CDS")
    return _localFeatureCoords(allresults, seqstart, seqlen, True)


def getFeaturesIntersecting(genome, chrom, start, length, db="", ftype="CDS"):
    """ list the features of type ftype overlapping a region of chrom.

        The region is `length` bases beginning at `start`, as interpreted by
        Genome.getFeaturesIntersecting.
    """
    dbPath = chooseDB(genome, db)
    genomeObj = Genome(genome, dbFile=dbPath)
    return genomeObj.getFeaturesIntersecting(chrom, start, length, ftype)


def retrieveSequence(genome, chrom, start, stop, sense="F", db=""):
    """ retrieve a sequence given a genome, chromosome, start, stop, and sense.

        abs(stop - start) + 1 bases are requested. The read begins at
        coordinate start for sense "F" and at stop otherwise (presumably
        callers pass stop < start for the reverse strand). The coordinate is
        converted to an offset with "- 1" and clamped at 0. The sequence is
        returned as read; it is not reverse-complemented here. Returns "" if
        the genome database raises IOError.
    """
    entrySeq = ""
    length = abs(stop - start) + 1
    if sense == "F":
        firstBase = start
    else:
        firstBase = stop

    try:
        aGenome = Genome(genome, dbFile=chooseDB(genome, db))
        # Clamp at the chromosome start on both strands; the reverse branch
        # previously could request a negative offset.
        entrySeq = aGenome.sequence(chrom, max(firstBase - 1, 0), length)
    except IOError:
        print("Couldn't retrieve sequence %s %s %s %s %s" % (genome, chrom, start, stop, sense))

    return entrySeq


def retrieveCDS(geneID, maskCDS=False, maskLower=False, db="", version="1"):
    """ retrieveCDS() - retrieve a sequence given a gene identifier

        Returns Genome.geneSeq(geneID, maskCDS, maskLower, version). Returns ""
        when the gene is not in the genome or the database raises IOError.
    """
    genomeName = geneID[0]
    genomeObj = Genome(genomeName, dbFile=chooseDB(genomeName, db))
    try:
        if not genomeObj.checkGene(geneID):
            return ""

        return genomeObj.geneSeq(geneID, maskCDS, maskLower, version)
    except IOError:
        print("Could not find %s " % str(geneID))

    return ""


def retrieveUpstream(geneID, upstream, maskCDS=False, maskLower=False, boundToNextGene=False, db="", version="1"):
    """ retrieve sequence 5' of cds of length upstream for a given a gene identifier

        For forward ("F") genes, the region ends just before start. For other
        genes, it begins just after stop and is reverse-complemented. With
        boundToNextGene, the length is first limited by the Genome
        neighbour-distance helpers. On the forward strand near the chromosome
        start, the region is truncated rather than allowed to run into the
        CDS. Returns "" for unknown genes or when the database raises IOError.
    """
    entrySeq = ""
    genome = geneID[0]
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    try:
        if aGenome.checkGene(geneID):
            (chrom, start, stop, length, sense) = aGenome.geneInfo(geneID, version)
            if sense == "F":
                if boundToNextGene:
                    upstream = aGenome.leftGeneDistance(geneID, upstream, version)

                if (start - upstream) > 1:
                    seqStart = start - upstream - 1
                    seqLength = upstream
                else:
                    # Only start - 1 bases precede the gene. Reading `upstream`
                    # bases from offset 0 would include CDS bases.
                    seqStart = 0
                    seqLength = max(start - 1, 0)
            else:
                if boundToNextGene:
                    upstream = aGenome.rightGeneDistance(geneID, upstream, version)

                seqStart = stop
                seqLength = upstream

            sequence = aGenome.sequence(chrom, seqStart, seqLength, maskCDS, maskLower)
            # TODO: do CDS masking here....
            if sense == "F":
                entrySeq = sequence
            else:
                entrySeq = complement(sequence, upstream)

    except IOError:
        print("Couldn't find  %s" % (geneID,))

    return entrySeq


def retrieveDownstream(geneID, downstream, maskCDS=False, maskLower=False, boundToNextGene=False, db="", version="1"):
    """ retrieve sequence 3' of CDS of length downstream for a given a gene identifier

        For forward ("F") genes, the read starts at offset stop - 1. For other
        genes, the region lies left of start, is clipped at the chromosome
        start, and is reverse-complemented. With boundToNextGene, the length
        is first limited by the Genome neighbour-distance helpers. Returns ""
        for unknown genes.
    """
    entrySeq = ""
    genome = geneID[0]
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    if not aGenome.checkGene(geneID):
        return entrySeq

    (chrom, start, stop, length, sense) = aGenome.geneInfo(geneID, version)
    if sense == "F":
        if boundToNextGene:
            downstream = aGenome.rightGeneDistance(geneID, downstream, version)

        # NOTE(review): starting at offset stop - 1 with length downstream + 1
        # appears to include the gene's last base; kept as-is. Confirm intent.
        seqStart = stop - 1
        seqLength = downstream + 1
    else:
        if boundToNextGene:
            downstream = aGenome.leftGeneDistance(geneID, downstream, version)

        # Clip at the chromosome start. The old fallback read from 0 through
        # stop, which pulled in the whole gene body.
        seqStart = max(start - downstream, 0)
        seqLength = start - seqStart

    sequence = aGenome.sequence(chrom, seqStart, seqLength, maskCDS, maskLower)
    # TODO: do CDS masking here
    if sense == "F":
        entrySeq = sequence
    else:
        entrySeq = complement(sequence, seqLength)

    return entrySeq


def retrieveSeq(geneID, upstream, cds, downstream, geneDB="", maskLower = False, boundToNextGene = False, version="1"):
    """ retrieve upstream, all or none of the cds, and downstream of a geneID

        cds controls the coding sequence:
          * 0 leaves it out;
          * 2 includes it and passes maskCDS=True to all three retrievals;
          * any other positive value includes it with maskCDS=False.
        cds may be an int or a numeric string. The pieces are concatenated in
        order upstream + CDS + downstream. A warning is printed if the result
        is empty.
    """
    # Normalise once. In Python 2 a str always compares greater than 0, so
    # the raw `cds > 0` test would wrongly include the CDS for cds="0".
    cds = int(cds)
    maskCDS = (cds == 2)

    pieces = []
    if upstream > 0:
        pieces.append(retrieveUpstream(geneID, upstream, maskCDS, maskLower, boundToNextGene, geneDB, version))

    if cds > 0:
        pieces.append(retrieveCDS(geneID, maskCDS, maskLower, geneDB, version))

    if downstream > 0:
        pieces.append(retrieveDownstream(geneID, downstream, maskCDS, maskLower, boundToNextGene, geneDB, version))

    geneSeq = "".join(pieces)
    if len(geneSeq) == 0:
        print("retrieveSeq Warning: retrieved null sequence for %s: %s (splice form %s) from geneDB %s" % (geneID[0], geneID[1], version, geneDB))

    return geneSeq


def retrieveAll(genome, genes, upstream, downstream, outputFilePath):
    """ retrieve upstream and downstream sequences for a list of genes in a genome and save them to a file.

        Each gene is written as a "> gene" header line followed by its
        sequence. The CDS itself is not included (cds=0 in retrieveSeq).
    """
    # `with` guarantees the file is closed even if a retrieval raises.
    with open(outputFilePath, "w") as outFile:
        for gene in genes:
            print("Processing  %s" % (gene,))
            outFile.write("> %s \n" % (gene))
            geneID = (genome, gene)
            outFile.write("%s\n" % retrieveSeq(geneID, upstream, 0, downstream))


def fasta(geneID, seq):
    """ format seq as a FASTA record whose header line is "> genome-locus".

        geneID[0] and geneID[1] supply the two halves of the header.
    """
    genomeName, locus = geneID[0], geneID[1]
    recordLines = ["> %s-%s" % (genomeName, locus), "%s" % (seq,)]
    return "\n".join(recordLines) + "\n"


def loadGOInfo(genome, db=""):
    """ load GO for a given genome

        Stores aGenome.allGOInfo() in the module-level goDict the first time
        it is called for a genome; later calls are no-ops.
    """
    if genome not in goDict:
        # Only open the genome database when the GO table isn't loaded yet.
        aGenome = Genome(genome, dbFile=chooseDB(genome, db))
        goDict[genome] = aGenome.allGOInfo()


def getGOInfo(geneID, db=""):
    """ retrieve GO info for geneID

        geneID is a (genome, locus) pair. Returns [] if the lookup raises.
    """
    (genome, locus) = geneID
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    try:
        return aGenome.goInfo(geneID)
    except Exception:
        # Best-effort lookup. Unlike a bare except, this lets
        # KeyboardInterrupt/SystemExit through.
        return []


def getGOIDCount(genome, GOID, db=""):
    """ retrieve count of genes with a particular GOID.

        Returns [] (not 0) if the lookup raises; kept for backward
        compatibility with existing callers.
    """
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    try:
        return aGenome.getGOIDCount(GOID)
    except Exception:
        # Best-effort lookup; KeyboardInterrupt/SystemExit are not swallowed.
        return []


def allGOTerms(genome, db=""):
    """ return all GO Terms.

        Returns [] if the lookup raises.
    """
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    try:
        # NOTE(review): confirm `allGOterms` (lower-case "t") matches the
        # Genome method name; an AttributeError would be hidden below.
        return aGenome.allGOterms()
    except Exception:
        return []


def getAllGOInfo(genome, db=""):
    """ return all GO Info.

        Returns [] if the lookup raises.
    """
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    try:
        # Same Genome method loadGOInfo() uses. The former `allGoInfo`
        # spelling raised AttributeError, which was silently turned into [].
        return aGenome.allGOInfo()
    except Exception:
        return []


def loadAnnotInfo(genome, db=""):
    """ load Annotations for a given genome

        Stores aGenome.allAnnotInfo() in the module-level annotDict the first
        time it is called for a genome; later calls are no-ops.
    """
    if genome not in annotDict:
        # Only open the genome database when the annotations aren't loaded yet.
        aGenome = Genome(genome, dbFile=chooseDB(genome, db))
        annotDict[genome] = aGenome.allAnnotInfo()


def getAnnotInfo(geneID, db=""):
    """ retrieve Annotations for a given geneID

        geneID is a (genome, locus) pair. Returns [] if the lookup raises.
    """
    (genome, locus) = geneID
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    try:
        return aGenome.annotInfo(geneID)
    except Exception:
        # Best-effort lookup; KeyboardInterrupt/SystemExit are not swallowed.
        return []


def getAllAnnotInfo(genome, db=""):
    """ return all Annotation Info.

        Returns [] if the lookup raises.
    """
    aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    try:
        return aGenome.allAnnotInfo()
    except Exception:
        # Best-effort lookup; KeyboardInterrupt/SystemExit are not swallowed.
        return []


def sanitize(inSeq, windowSize=16):
    """ make sure that very simple repeats are out of the sequence.

        The sequence is upper-cased. Then every window of windowSize bases,
        including the final one, is soft-masked (lower-cased) when either:
          * a single mononucleotide occurs more than windowSize - 2 times, or
          * one of AC, AG, AT, CT, GT, TA, TC, TG, GA, CA (every
            heterodinucleotide except GC/CG) occurs at least
            windowSize // 2 - 1 times, counted without overlap.
        Sequences shorter than windowSize are only upper-cased.
    """
    seqlen = len(inSeq)
    upperSeq = inSeq.upper()
    outSeq = list(upperSeq)
    winmin2 = windowSize - 2
    # Floor division keeps the integer threshold under Python 3 as well.
    winhalf = windowSize // 2 - 1
    monos = ("A", "C", "G", "T")
    dinucs = ("AC", "AG", "AT", "CT", "GT", "TA", "TC", "TG", "GA", "CA")
    # The + 1 makes sure the last full window (ending at the final base) is
    # examined too.
    for pos in range(seqlen - windowSize + 1):
        window = upperSeq[pos:pos + windowSize]
        monoRepeat = any(window.count(nt) > winmin2 for nt in monos)
        # GA and CA now use the same >= threshold as the other dinucleotides.
        diRepeat = any(window.count(di) >= winhalf for di in dinucs)
        if monoRepeat or diRepeat:
            for index in range(pos, pos + windowSize):
                outSeq[index] = outSeq[index].lower()

    return "".join(outSeq)


def featuresIntersecting(genome, posList, radius, ftype, name="", chrom="", version="", db="", extendGen="", replaceMod=False):
    """ map positions to the features of type ftype that lie within radius of them.

        posList holds (chromosome, position) pairs. name, chrom and version
        filter the features fetched through Genome.getFeatures. When extendGen
        is non-empty, the genome is opened with inRAM=True and first extended
        via extendFeatures(extendGen, replace=replaceMod). Returns a dict from
        (chromosome, position) to the list of nearby feature tuples. Positions
        with no nearby feature are omitted.
    """
    if extendGen == "":
        aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    else:
        aGenome = Genome(genome, dbFile=chooseDB(genome, db), inRAM=True)
        aGenome.extendFeatures(extendGen, replace = replaceMod)

    features = aGenome.getFeatures(ftype, name, chrom, version)
    resultDict = {}
    if len(posList) < 1 or len(features) < 1:
        return resultDict

    knownChroms = features.keys()
    for (posChrom, pos) in posList:
        if posChrom not in knownChroms:
            continue

        # Keep features whose [start, stop] overlaps [pos - radius, pos + radius].
        nearby = [(fName, fVersion, fChrom, fStart, fStop, fOrient, fType)
                  for (fName, fVersion, fChrom, fStart, fStop, fOrient, fType) in features[posChrom]
                  if not ((pos + radius) < fStart or (pos - radius) > fStop)]
        if nearby:
            resultDict[(posChrom, pos)] = nearby

    return resultDict


def genesIntersecting(genome, posList, name="", chrom="", version="", db="", flank=0, extendGen="", replaceMod=False):
    """ map positions to the genes that contain them, allowing flank extra bases.

        posList holds (chromosome, position) pairs. A gene matches a position
        when start - flank <= position <= stop + flank. Genes are fetched
        through Genome.getallGeneInfo(name, chrom, version). When extendGen is
        non-empty, the genome is opened with inRAM=True and first extended via
        extendFeatures(extendGen, replace=replaceMod). Returns a dict from
        (chromosome, position) to a list of
        (name, "noversion", chromosome, start, stop, orientation) tuples.
        Positions with no matching gene are omitted.
    """
    if extendGen == "":
        aGenome = Genome(genome, dbFile=chooseDB(genome, db))
    else:
        aGenome = Genome(genome, dbFile=chooseDB(genome, db), inRAM=True)
        aGenome.extendFeatures(extendGen, replace = replaceMod)

    genes = aGenome.getallGeneInfo(name, chrom, version)
    resultDict = {}
    if len(posList) < 1 or len(genes) < 1:
        return resultDict

    knownChroms = genes.keys()
    for (posChrom, pos) in posList:
        if posChrom not in knownChroms:
            continue

        hits = [(gName, "noversion", gChrom, gStart, gStop, gOrient)
                for (gName, gChrom, gStart, gStop, gOrient) in genes[posChrom]
                if gStart - flank <= pos <= gStop + flank]
        if hits:
            resultDict[(posChrom, pos)] = hits

    return resultDict