###########################################################################
#                                                                         #
# C O P Y R I G H T   N O T I C E                                         #
#  Copyright (c) 2003-10 by:                                              #
#    * California Institute of Technology                                 #
#                                                                         #
#    All Rights Reserved.                                                 #
#                                                                         #
# Permission is hereby granted, free of charge, to any person             #
# obtaining a copy of this software and associated documentation files    #
# (the "Software"), to deal in the Software without restriction,          #
# including without limitation the rights to use, copy, modify, merge,    #
# publish, distribute, sublicense, and/or sell copies of the Software,    #
# and to permit persons to whom the Software is furnished to do so,       #
# subject to the following conditions:                                    #
#                                                                         #
# The above copyright notice and this permission notice shall be          #
# included in all copies or substantial portions of the Software.         #
#                                                                         #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,         #
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF      #
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND                   #
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS     #
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN      #
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN       #
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE        #
# SOFTWARE.                                                               #
###########################################################################
#
# motif.py - defines the motif object and its methods in cistematic
from string import upper, lower
from math import log, exp
from copy import deepcopy
from cistematic.core import complement
from cistematic.cisstat.score import pearsonCorrelation
import re, os, tempfile

# Root of the cistematic installation; saveLogo() looks for helper programs
# (weblogo) underneath it. Falls back to /proj/genome when CISTEMATIC_ROOT is
# unset or empty.
cisRoot = os.environ.get("CISTEMATIC_ROOT") or "/proj/genome"

# Scratch directory for temporary files; falls back to /tmp when
# CISTEMATIC_TEMP is unset or empty.
cisTemp = os.environ.get("CISTEMATIC_TEMP") or "/tmp"

# Make cisTemp the process-wide default directory of the tempfile module.
tempfile.tempdir = cisTemp

# Optional C extension with faster scanning routines. Only an import failure
# means "not available"; any other error is a genuine bug and should surface.
try:
    import cistematic.core._motif as _motif
    hasMotifExtension = True
except ImportError:
    hasMotifExtension = False

# Row index of each nucleotide inside a PWM column (A, C, G, T order).
matrixRow = {"A": 0,
             "C": 1,
             "G": 2,
             "T": 3
}

# PWM column for every IUPAC symbol, in matrixRow order: a single base is
# certain, a two-base code splits evenly, a three-base code splits roughly in
# thirds, and N is uniform.
symbolToMatrix = {"A": [1.0, 0.0, 0.0, 0.0],
                  "C": [0.0, 1.0, 0.0, 0.0],
                  "G": [0.0, 0.0, 1.0, 0.0],
                  "T": [0.0, 0.0, 0.0, 1.0],
                  "W": [0.5, 0.0, 0.0, 0.5],
                  "S": [0.0, 0.5, 0.5, 0.0],
                  "R": [0.5, 0.0, 0.5, 0.0],
                  "Y": [0.0, 0.5, 0.0, 0.5],
                  "M": [0.5, 0.5, 0.0, 0.0],
                  "K": [0.0, 0.0, 0.5, 0.5],
                  "H": [0.33, 0.33, 0.0, 0.34],
                  "B": [0.0, 0.33, 0.33, 0.34],
                  "V": [0.33, 0.33, 0.34, 0.0],
                  "D": [0.33, 0.0, 0.33, 0.34],
                  "N": [0.25, 0.25, 0.25, 0.25]
}

# Regular-expression fragment for every IUPAC symbol. "|" separates the
# alternatives produced by Motif.seqMismatches() and must stay a regex
# alternation (it used to map to the literal "I", which no DNA sequence
# contains, so mismatch patterns could never match).
reMAP = {"A": "A",
         "C": "C",
         "G": "G",
         "T": "T",
         "N": "[ACGT]",
         "W": "[AT]",
         "S": "[GC]",
         "R": "[AG]",
         "Y": "[CT]",
         "M": "[AC]",
         "K": "[GT]",
         "H": "[ACT]",
         "B": "[CGT]",
         "V": "[ACG]",
         "D": "[AGT]",
         "|": "|"
}

# For each nucleotide, the multi-base IUPAC codes (N excluded) that include it.
motifDict = {"A": ["W", "R", "M", "H", "V", "D"],
             "T": ["W", "Y", "K", "H", "B", "D"],
             "C": ["S", "Y", "M", "H", "B", "V"],
             "G": ["S", "R", "K", "B", "V", "D"]
}


class Motif:
    """ The Motif class is the heart of cistematic. It captures both the consensus, 
        PWM, and Markov(1) versions of the motif, as well as provides methods for scanning 
        sequences for matches.
    """
    motifSeq = ""
    motifPWM = []
    reversePWM = []
    motifMarkov1 = []
    revMarkov1 = []
    tagID = ""
    sequences = []
    annotations = []
    strictConsensus = ""
    threshold = 1.0
    info = ""


    def __init__(self, tagID="", motif="", PWM=None, seqs=None, thresh=0.0, info="", motifFile="", seqfile=""):
        """ initialize a Motif object either with (A) a tagID identifier and either of 
            (i) a IUPAC consensus or (ii) a PWM or (iii) example sequences, or (B) from 
            a motif file in Cistematic motif format.

            Precedence: values read from motifFile replace the motif, PWM, seqs,
            thresh and info arguments; a PWM with more than one column overrides
            the consensus; seqfile replaces seqs; more than one sequence
            overrides both. An explicit tagID wins over the file's, and
            "default" is used when neither is given. thresh is applied through
            setThreshold().
        """
        # Give every instance its own containers. The class-level defaults are
        # single lists shared by all Motif objects, so any in-place mutation
        # (e.g. appendToPWM() on a fresh instance) would leak across instances.
        self.motifPWM = []
        self.reversePWM = []
        self.motifMarkov1 = []
        self.revMarkov1 = []
        self.sequences = []
        self.annotations = []

        if PWM is None:
            PWM = []

        if seqs is None:
            seqs = []

        fileTagID = ""
        if motifFile != "":
            (fileTagID, motif, PWM, seqs, thresh, info) = self.readMotif(motifFile)

        if len(tagID) > 0:
            self.setTagID(tagID)
        elif len(fileTagID) > 0:
            self.setTagID(fileTagID)
        else:
            self.setTagID("default")

        if motif != "":
            self.setMotif(motif)

        # PWM overrules the motif
        if len(PWM) > 1:
            self.setPWM(PWM)

        # a seqfile can be used to create a list of sequences
        if len(seqfile) > 0:
            seqs = self.readSeqFile(seqfile)

        # Sequences overrules the PWM and motif
        if len(seqs) > 1:
            self.setSequences(seqs)

        # setThreshold() recomputes the threshold from scratch, so a single call
        # with the requested value is all that is needed
        self.setThreshold(float(thresh))
        if len(info) > 0:
            self.setInfo(info)


    def __len__(self):
        """ returns the length of the motif 
        """
        return len(self.motifPWM)


    def readMotif(self, motifFile):
        """ read a motif in cistematic motif format.

            Each relevant line is tab-separated: a field type (tagid, motif,
            acgt, sequence, threshold, info; case-insensitive) followed by its
            value(s). Lines shorter than 4 characters or starting with "#" are
            skipped; unknown field types are reported and skipped.

            Returns (tagID, motif, PWM, seqs, threshold, info). Fields missing
            from the file keep their defaults ("", "", [], [], 0.0, "").
        """
        # tagID needs a default too: a file without a "tagid" line used to end
        # in UnboundLocalError at the return statement
        tagID = ""
        motif = ""
        PWM = []
        seqs = []
        threshold = 0.0
        info = ""
        infile = open(motifFile, "r")
        try:
            for line in infile:
                if len(line) < 4 or line[0] == "#":
                    continue

                fields = line.strip().split("\t")
                fieldType = fields[0].lower()
                if fieldType not in ["tagid", "motif", "acgt", "sequence", "threshold", "info"]:
                    print("could not process line %s" % line)
                    continue

                if len(fields) < 2:
                    continue

                if fieldType == "motif":
                    motif = fields[1]
                elif fieldType == "acgt":
                    PWM.append([float(fields[1]), float(fields[2]), float(fields[3]), float(fields[4])])
                elif fieldType == "sequence":
                    seqs.append(fields[1])
                elif fieldType == "threshold":
                    threshold = float(fields[1])
                elif fieldType == "info":
                    info = fields[1].strip()
                elif fieldType == "tagid":
                    tagID = fields[1]
        finally:
            infile.close()

        return (tagID, motif, PWM, seqs, threshold, info)


    def readSeqFile(self, seqFile):
        """ read founder sequences from seqFile, one per line.

            Each line is stripped and upper-cased. Blank lines are skipped: an
            empty "sequence" would count towards the number of sequences when
            the PWM is built, and as the first entry it would set the motif
            length to zero.
        """
        seqs = []
        infile = open(seqFile, "r")
        try:
            for line in infile:
                seq = line.strip().upper()
                if seq:
                    seqs.append(seq)
        finally:
            infile.close()

        return seqs


    def saveMotif(self, motFile):
        """ Save motif in cistematic motif format.
        """
        outfile = open(motFile, "w")
        outfile.write("tagid\t%s\n" % self.tagID)
        outfile.write("info\t%s\n" % self.info.strip())
        outfile.write("threshold\t%s\n" % str(self.threshold))
        outfile.write("motif\t%s\n" % self.buildConsensus())
        for col in self.motifPWM:
            outfile.write("acgt\t%f\t%f\t%f\t%f\n" % (col[0], col[1], col[2], col[3]))

        for seq in self.sequences:
            outfile.write("sequence\t%s\n" % seq)

        outfile.close()


    def setTagID(self, tag):
        """ set motif identifier.
        """
        self.tagID = tag


    def setInfo(self, info):
        """ set motif info string.
        """
        self.info = info


    def setThreshold(self, threshold):
        """ set the match threshold.

            The ceiling is the better of the forward and reverse scores that
            scoreMotif() gives the strict consensus. A positive threshold below
            that ceiling replaces it; otherwise the ceiling itself is kept.
        """
        forwardScore, reverseScore = self.scoreMotif(self.buildStrictConsensus())
        if reverseScore < forwardScore:
            ceiling = forwardScore
        else:
            ceiling = reverseScore

        self.threshold = ceiling
        if 0 < threshold < ceiling:
            self.threshold = threshold


    def setMotif(self, motif):
        """ define the motif from an IUPAC consensus string.

            Records the consensus, derives the PWM from it, then refreshes the
            reverse PWM and the strict (A/C/G/T only) consensus.
        """
        self.motifSeq = motif
        self.motifPWM = self.buildPWM(self.motifSeq)
        self.buildReversePWM()
        strict = self.buildStrictConsensus()
        self.strictConsensus = strict


    def setSequences(self, seqs):
        """ replace the founder sequences and rebuild the models derived from
            them: first the PWM (calculatePWMfromSeqs()), then the Markov(1)
            model (calculateMarkov1()).
        """
        self.sequences = seqs
        for rebuild in (self.calculatePWMfromSeqs, self.calculateMarkov1):
            rebuild()


    def setPWM(self, PWM):
        """ install PWM as the motif matrix and refresh everything derived from
            it: the reverse PWM, the IUPAC consensus (motifSeq) and the strict
            A/C/G/T consensus.
        """
        self.motifPWM = PWM
        self.buildReversePWM()
        consensus = self.buildConsensus()
        self.motifSeq = consensus
        strict = self.buildStrictConsensus()
        self.strictConsensus = strict


    def appendToPWM(self, col):
        """ add a column to the end of the motif's PWM.

            The reverse PWM and both consensus strings are rebuilt (via setPWM())
            so that scoring stays consistent. Previously they went stale, and
            scoreMotif() then raised IndexError on the new position.
        """
        # Build a new list rather than appending in place: before any PWM has
        # been set, self.motifPWM can be the class-level list shared by every
        # Motif instance.
        extendedPWM = list(self.motifPWM)
        extendedPWM.append(col)
        self.setPWM(extendedPWM)


    def buildPWM(self, motif):
        """ return the PWM for an IUPAC consensus string (case-insensitive):
            one symbolToMatrix column per letter. A character that is not a key
            of symbolToMatrix raises KeyError.
        """
        # Copy each column. symbolToMatrix holds shared module-level lists, and
        # returning them directly would let any in-place edit of this PWM (e.g.
        # through getPWM()) silently change the table for every other motif.
        return [list(symbolToMatrix[letter]) for letter in motif.upper()]


    def buildReversePWM(self):
        """ build, store and return the reverse-complement PWM.

            Columns are taken last-to-first, and within each column the
            probabilities are re-ordered so that A<->T and C<->G swap places.
            The result is made of fresh lists, independent of self.motifPWM.
        """
        complementOrder = [matrixRow[nt] for nt in ("T", "G", "C", "A")]
        self.reversePWM = [[column[row] for row in complementOrder]
                           for column in reversed(self.motifPWM)]
        return self.reversePWM


    def getPWM(self):
        """ returns the PWM for the motif 
        """
        return self.motifPWM


    def printPWM(self):
        """ print the PWM as four tab-separated rows (A, C, G, T; values rounded
            to 4 places), followed by a "Cons:" row with the IUPAC consensus.
        """
        consensus = self.buildConsensus()
        consLine = "Cons:" + "".join(["\t" + letter for letter in consensus])

        rows = []
        for nt in ("A", "C", "G", "T"):
            rowIndex = matrixRow[nt]
            rows.append("".join(["%s\t" % str(round(column[rowIndex], 4)) for column in self.motifPWM]))

        print("A:\t%s\nC:\t%s\nG:\t%s\nT:\t%s\n" % tuple(rows))
        print("%s\n" % consLine)


    def saveLogo(self, outfilename, height=-1, width=-1):
        """ save a logo of the motif as outfilename (a ".png"/".PNG" suffix is
            stripped before it is handed to seqlogo), built from the founder
            sequences. Requires the weblogo 2.8.2 package under
            cisRoot/programs/weblogo.

            height/width are passed to seqlogo when positive; the width
            defaults to the motif length. Prints a diagnostic when there are no
            founder sequences, or when seqlogo cannot be started or exits with
            a non-zero status.
        """
        import subprocess

        logoPath = "%s/programs/weblogo/seqlogo" % cisRoot
        if outfilename[-4:] in [".png", ".PNG"]:
            outfilename = outfilename[:-4]

        if len(self.sequences) < 1:
            print("cannot run logo representation without founder sequences")
            return

        # mkstemp creates the file atomically, unlike the race-prone mktemp();
        # its directory follows the tempfile module default (tempfile.tempdir)
        (fd, seqfilename) = tempfile.mkstemp(suffix=".tmp")
        try:
            seqfile = os.fdopen(fd, "w")
            try:
                for sequence in self.sequences:
                    seqfile.write("%s\n" % sequence)
            finally:
                seqfile.close()

            # an argument list instead of a shell string: paths containing
            # spaces or shell metacharacters are passed through intact
            args = [logoPath, "-f", seqfilename, "-F", "PNG", "-c"]
            if height > 0:
                args += ["-h", "%d" % height]

            if width > 0:
                args += ["-w", "%d" % width]
            else:
                args += ["-w", "%d" % len(self.motifPWM)]

            args += ["-a", "-Y", "-n", "-M", "-k", "1", "-o", outfilename]

            try:
                status = subprocess.call(args)
            except OSError:
                # seqlogo missing or not executable
                status = -1

            if status != 0:
                print("failed to make logo: expecting weblogo 2.8.2 package in %s" % logoPath)
                print("also check if ghostscript package is correctly installed.")
        finally:
            os.remove(seqfilename)


    def getSymbol(self, col):
        """ helper for buildConsensus(): reduce one PWM column to an IUPAC letter.

            The first rule whose probability exceeds 0.9 wins: a single base,
            then the best two-base code (R, Y, W, S, M, K), then the best
            three-base code (B, D, H, V). Otherwise the column is "N".
        """
        for base in ("A", "C", "G", "T"):
            if col[matrixRow[base]] > 0.9:
                return base

        a = col[matrixRow["A"]]
        c = col[matrixRow["C"]]
        g = col[matrixRow["G"]]
        t = col[matrixRow["T"]]

        # ties within a tier resolve to the earlier entry of the list
        pairs = [("R", a + g), ("Y", t + c), ("W", a + t),
                 ("S", c + g), ("M", a + c), ("K", t + g)]
        symbol, weight = self.getBestSymbol(pairs)
        if weight > 0.9:
            return symbol

        triples = [("B", c + g + t), ("D", a + g + t),
                   ("H", a + c + t), ("V", a + c + g)]
        symbol, weight = self.getBestSymbol(triples)
        if weight > 0.9:
            return symbol

        return "N"


    def getBestSymbol(self, symbolProbabilityList):
        """ return the (symbol, probability) pair with the highest probability.

            Ties go to the earliest pair; an empty list raises IndexError.
        """
        bestIndex = 0
        for index in range(1, len(symbolProbabilityList)):
            if symbolProbabilityList[index][1] > symbolProbabilityList[bestIndex][1]:
                bestIndex = index

        return symbolProbabilityList[bestIndex]


    def buildConsensus(self):
        """ returns the best consensus using the IUPAC alphabet.
        """
        consensus = ""
        for col in self.motifPWM:
            consensus += self.getSymbol(col)

        return consensus


    def buildStrictConsensus(self):
        """ returns the best consensus using solely nucleotides.
        """
        consensus = ""
        for col in self.motifPWM:
            mRow = []
            for nt in ("A", "C", "G", "T"):
                mRow.append((col[matrixRow[nt]], nt))

            mRow.sort()
            consensus += mRow[3][1]

        return consensus


    def bestConsensusScore(self):
        """ returns the best consensus score possible.
        """
        score = 0.0
        for col in self.motifPWM:
            mRow = []
            for nt in ["A", "C", "G", "T"]:
                mRow.append((col[matrixRow[nt]], nt))

            mRow.sort()
            score += mRow[3][0]

        return score


    def expected(self, length, background=[0.25, 0.25, 0.25, 0.25], numMismatches=0):
        """ returns the expected number of matches to the consensus in a sequence of a given length and background. 
        """
        expectedNum = length * self.consensusProbability(background, numMismatches)
        return expectedNum


    def consensusProbability(self, background=[0.25, 0.25, 0.25, 0.25], numMismatches=0):
        """ returns the probability of the consensus given the background.
        """
        prob = 0
        motifs = []
        if numMismatches> 0:
            seqs = self.seqMismatches(self.buildConsensus().upper(), numMismatches)
            motifs = seqs.split("|")
        else:
            motifs.append(self.buildConsensus())

        for theCons in motifs:
            motProb = 0
            for NT in theCons:
                currentProb = 0.0
                if NT in ("A", "W", "R", "M", "H", "V", "D", "N"):
                    currentProb += background[matrixRow["A"]]

                if NT in ("C", "S", "Y", "M", "H", "B", "V", "N"):
                    currentProb += background[matrixRow["C"]]

                if NT in ("G", "S", "R", "K", "B", "V", "D", "N"):
                    currentProb += background[matrixRow["G"]]

                if NT in ("T", "W", "Y", "K", "H", "B", "D", "N"):
                    currentProb += background[matrixRow["T"]]

                motProb = motProb + log(currentProb)

            prob += exp(motProb)

        return prob


    def pwmProbability(self, background):
        """ returns probability of the PWM.
        """
        prob = 1.0
        for row in self.motifPWM:
            currentProb = 0.0
            for NT in ["A", "C", "G", "T"]:
                currentProb += row[matrixRow[NT]] * background[matrixRow[NT]]

            prob = prob * currentProb

        return prob


    def revComp(self):
        """ reverse complement of this motif's IUPAC consensus, computed by
            cistematic.core.complement over the full motif length.
        """
        consensus = self.buildConsensus()
        return complement(consensus, len(self.motifPWM))


    def numberOfN(self):
        """ returns numbers of effective Ns in motif.
        """
        index = 0
        for col in self.motifPWM:
            if self.getSymbol(col) == "N":
                index += 1

        return index


    def buildMismatchSeq(self, rootSeq, tailSeq, numMismatches):
        """ helper function called from seqMismatches().
        """
        finalSeq = ""
        tailLen = len(tailSeq)
        if tailLen < 1 or numMismatches < 1:
            return rootSeq + tailSeq

        for pos in range(tailLen - numMismatches + 1):
            newRootSeq = rootSeq
            newRootSeq += tailSeq[:pos]
            newRootSeq += "N"
            finalSeq += self.buildMismatchSeq(newRootSeq, tailSeq[pos + 1:], numMismatches - 1)
            finalSeq += "|"

        return finalSeq[:-1]


    def seqMismatches(self, seq, numMismatches):
        """ Returns list of sequences that will be used by initializeMismatchRE().
        """
        return self.buildMismatchSeq("", seq, numMismatches)


    def probMatchPWM(self, PWM, NT, colIndex):
        """ returns the probability of seeing that particular nucleotide according to the PSFM.

            For A/C/G/T this is the single entry of column colIndex; N scores
            1.0; any other IUPAC code scores the summed probability of the
            bases it covers (per motifDict), and an unrecognized character
            scores 0.0.
        """
        if NT in ("A", "T", "C", "G"):
            return PWM[colIndex][matrixRow[NT]]

        if NT == "N":
            return 1.0

        currentProb = 0.0
        for motifNucleotide in ("A", "T", "C", "G"):
            if NT in motifDict[motifNucleotide]:
                # index by the covered base: NT itself (e.g. "W") is not a key of
                # matrixRow, which used to raise KeyError here
                currentProb += PWM[colIndex][matrixRow[motifNucleotide]]

        return currentProb


    def psfmOdds(self, PWM, NT, colIndex, background=[0.25, 0.25, 0.25, 0.25]):
        """ calculates the odds of nucleotide NT coming from position colIndex in thePWM
            as opposed to coming from the background.
        """

        currentProb = self.probMatchPWM(PWM, NT, colIndex)
        backgroundProb = self.getBackgroundProbability(self, NT, background)

        try:
            odds = currentProb / backgroundProb
        except ZeroDivisionError:
            odds = 1.0

        return odds


    def getBackgroundProbability(self, NT, background=[0.25, 0.25, 0.25, 0.25]):
        """ background probability of symbol NT (background is in A, C, G, T order).

            A/C/G/T give their own entry; N gives 1.0; any other IUPAC code
            gives the summed background of the bases it covers (per motifDict),
            and an unrecognized character gives 0.0.
        """
        if NT in ("A", "T", "C", "G"):
            return background[matrixRow[NT]]

        if NT == "N":
            return 1.0

        backgroundProb = 0.0
        for motifNucleotide in ("A", "T", "C", "G"):
            if NT in motifDict[motifNucleotide]:
                # index by the covered base: NT itself (e.g. "W") is not a key of
                # matrixRow, which used to raise KeyError here
                backgroundProb += background[matrixRow[motifNucleotide]]

        return backgroundProb


    def ntMatch(self, motifNT, seqNT):
        """ returns True if sequence character seqNT matches motif symbol motifNT.

            Identical symbols match, N on either side matches anything, and a
            multi-base IUPAC motif symbol matches each base it covers. Any other
            sequence character (an IUPAC code, a lower-case letter, a gap)
            simply does not match, instead of raising KeyError.
        """
        if motifNT == seqNT:
            return True

        if seqNT == "N" or motifNT == "N":
            return True

        # motifDict only has keys for A/C/G/T
        return motifNT in motifDict.get(seqNT, ())


    def scoreMotif(self, aSeq, diff=1000):
        """ score the first len(motif) characters of aSeq against the PWM.

            Returns (forward, reverse): the sums of probMatchPWM() over the
            forward and reverse PWMs. If both scores fall more than diff below
            the strict consensus's score, returns (-1, -1); if aSeq is shorter
            than the motif, returns (0.0, 0.0).
        """
        span = len(self.motifPWM)
        if len(aSeq) < span:
            return (0.0, 0.0)

        lookup = self.probMatchPWM
        fwdPWM = self.motifPWM
        bwdPWM = self.reversePWM
        window = upper(aSeq)
        fwdScore = 0.0
        bwdScore = 0.0
        ideal = 0.0
        for pos in range(span):
            nt = window[pos]
            fwdScore += lookup(fwdPWM, nt, pos)
            bwdScore += lookup(bwdPWM, nt, pos)
            ideal += lookup(fwdPWM, self.strictConsensus[pos], pos)

        if ideal > fwdScore + diff and ideal > bwdScore + diff:
            return (-1, -1)

        return (fwdScore, bwdScore)


    def scoreMotifLogOdds(self, aSeq, background=[0.25, 0.25, 0.25, 0.25]):
        """ calculates the log2-odds score of the first len(motif) characters of
            aSeq using the PSFM, given the markov(0) background.

            Returns (forward, reverse), scored against the PWM and the reverse
            PWM; (0.0, 0.0) if aSeq is shorter than the motif. A position whose
            odds are not positive (log undefined) contributes log2(0.01).
        """
        motLength = len(self.motifPWM)
        if len(aSeq) < motLength:
            return (0.0, 0.0)

        odds = self.psfmOdds
        motPWM = self.motifPWM
        revPWM = self.reversePWM
        theSeq = upper(aSeq)
        penalty = log(0.01, 2)
        forwardCons = 0.0
        reverseCons = 0.0
        # only math-domain errors get the penalty; anything else (e.g. a bad
        # PWM index) is a real bug and propagates instead of being hidden
        for index in range(motLength):
            currentNT = theSeq[index]
            try:
                forwardCons += log(odds(motPWM, currentNT, index, background), 2)
            except ValueError:
                forwardCons += penalty

            try:
                reverseCons += log(odds(revPWM, currentNT, index, background), 2)
            except ValueError:
                reverseCons += penalty

        return (forwardCons, reverseCons)


    def bestLogOddsScore(self, background=[0.25, 0.25, 0.25, 0.25]):
        """ calculates the best possible log-odds score using the PSFM given the markov(0) background.
        """
        motLength = len(self.motifPWM)
        odds = self.psfmOdds
        motPWM = self.motifPWM
        bestLogOdds = 0.0
        for index in range(motLength):
            bestLogOdds += log(odds(motPWM,self.strictConsensus[index], index, background), 2)

        return bestLogOdds


    def locateConsensus(self, aSeq):
        """ returns a list of (position, "F"|"R") for windows of aSeq that match
            the IUPAC consensus ("F") or its reverse complement ("R") exactly.

            The forward match takes precedence. After a hit the scan jumps a
            full motif length, so hits do not overlap. A window that cannot be
            compared (an unexpected character, or a window shortened by strip())
            is reported and skipped.
        """
        cons = self.buildConsensus()
        revComp = self.revComp()
        motLength = len(cons)
        # an empty motif would "match" at every position without ever
        # advancing the scan, looping forever
        if motLength == 0 or len(aSeq) < motLength:
            return []

        theSeq = upper(aSeq)
        Position = []
        pos = 0
        seqLength = len(theSeq)
        while pos <= (seqLength - motLength):
            subSeq = theSeq[pos:pos + motLength].strip()
            # Both flags are reset for every window, so a failed comparison can
            # never reuse a reverse-complement result from an earlier window.
            forwardMot = False
            revCompMot = False
            try:
                forwardMot = all(self.ntMatch(cons[index], subSeq[index]) for index in range(motLength))
                if not forwardMot:
                    revCompMot = all(self.ntMatch(revComp[index], subSeq[index]) for index in range(motLength))
            except (KeyError, IndexError):
                print("chocked at pos %d" % pos)

            if forwardMot:
                Position.append((pos, "F"))
                pos += motLength
            elif revCompMot:
                Position.append((pos, "R"))
                pos += motLength
            else:
                pos += 1

        return Position


    def compareConsensus(self, aSeq):
        """ return the first len(consensus) characters of aSeq, upper-cased,
            with every character that fails to match the consensus lower-cased.

            Both the consensus and its reverse complement are tried; the
            orientation with fewer mismatches wins (forward on a tie).
            Raises NameError if aSeq is shorter than the consensus.
        """
        cons = self.buildConsensus()
        revComp = self.revComp()
        width = len(cons)
        if len(aSeq) < width:
            raise NameError("Sequence too short")

        theSeq = upper(aSeq)
        forwardChars = []
        backwardChars = []
        forwardMisses = 0
        backwardMisses = 0
        for index in range(width):
            base = theSeq[index]
            if self.ntMatch(cons[index], base):
                forwardChars.append(base)
            else:
                forwardMisses += 1
                forwardChars.append(lower(base))

            if self.ntMatch(revComp[index], base):
                backwardChars.append(base)
            else:
                backwardMisses += 1
                backwardChars.append(lower(base))

        if backwardMisses < forwardMisses:
            return "".join(backwardChars)

        return "".join(forwardChars)


    def scoreDiffPWM(self, compMotif):
        """ returns a score scaled from 0 (no difference) to 2 (completely different) to  
            quantify the difference between the motif and another motif.
        """
        score = 0.0
        diffPWM = self.getDiffPWM(compMotif)
        for pos in range(len(diffPWM)):
            (adiff, cdiff, gdiff, tdiff) = diffPWM[pos]
            score += abs(adiff) + abs(cdiff) + abs(gdiff) + abs(tdiff)

        score /= float(len(diffPWM))

        return score


    def getDiffPWM(self, compMotif):
        """ subtracts the PWM of compMotif from existing PWM to compare differences. 
            Note that the comparison is only done on the length of the shorter motif.
        """
        diffPWM = []
        compPWM = compMotif.getPWM()
        numBasesToCompare = min(len(compMotif), len(self.motifPWM))

        for pos in range(numBasesToCompare):
            pwmCol = self.motifPWM[pos]
            pwmColComp = compPWM[pos]
            pwmEntry = []
            for NT in range(4):
                pwmEntry.append(pwmCol[NT] - pwmColComp[NT])

            diffPWM.append(pwmEntry)

        return diffPWM


    def initializeRE(self):
        """ compile the exact-match regular expressions used by
            locateConsensusRE() into the module globals forwardRE / backwardRE
            (case-insensitive).

            For a palindromic consensus backwardRE is compiled from "ZZZZZZZ",
            which cannot match a DNA sequence (presumably so palindromic hits
            are not reported twice).
        """
        global forwardRE, backwardRE
        forwardText = self.buildConsensus().upper()
        reverseText = self.revComp().upper()
        forwardPattern = "".join([reMAP[nt] for nt in forwardText])
        if reverseText == forwardText:
            backwardPattern = "ZZZZZZZ"
        else:
            backwardPattern = "".join([reMAP[nt] for nt in reverseText])

        forwardRE = re.compile(forwardPattern, re.IGNORECASE)
        backwardRE = re.compile(backwardPattern, re.IGNORECASE)


    def initializeMismatchRE(self, numMismatches):
        """ compile regular expressions that allow numMismatches mismatching
            positions into the module globals forwardRE / backwardRE used by
            locateConsensusRE() (case-insensitive).

            Every variant of the consensus with numMismatches positions replaced
            by N (seqMismatches()) becomes one branch of a regex alternation.
            For a palindromic consensus backwardRE is compiled from "ZZZZZZZ",
            which cannot match a DNA sequence (presumably so palindromic hits
            are not reported twice).
        """
        global forwardRE
        global backwardRE

        def toPattern(variants):
            # "|" separates the variants and must remain a regex alternation;
            # translating it through reMAP could turn it into a literal
            # character, so no mismatch pattern would ever match
            pieces = []
            for NT in variants:
                if NT == "|":
                    pieces.append("|")
                else:
                    pieces.append(reMAP[NT])
            return "".join(pieces)

        consensus = self.buildConsensus().upper()
        revCompCons = self.revComp().upper()
        reCons = toPattern(self.seqMismatches(consensus, numMismatches))
        if revCompCons != consensus:
            reBackward = toPattern(self.seqMismatches(revCompCons, numMismatches))
        else:
            reBackward = "ZZZZZZZ"

        forwardRE = re.compile(reCons, re.IGNORECASE)
        backwardRE = re.compile(reBackward, re.IGNORECASE)


    def locateConsensusRE(self, sequence):
        """ Returns a list of positions on aSeq that match the consensus exactly. 
            Should be run after either initializeRE() or initializeMismatchRE(numMismatches)
        """
        motLength = len(self.motifPWM)
        position = []
        results = []
        if len(sequence) < motLength:
            return []

        forwardIter = forwardRE.finditer(sequence)
        backwardIter = backwardRE.finditer(sequence)
        for match in forwardIter:
            position.append((match.start(), "F"))

        for match in backwardIter:
            position.append((match.start(), "R"))

        positionLength = len(position)
        if positionLength >= 1:
            position.sort()
            (prevPos, prevSense) = position[0]
            results.append((prevPos, prevSense))

            for index in range(1, positionLength):
                (pos, sense) = position[index]
                if pos >= prevPos + motLength:
                    results.append((pos, sense))
                    (pos, sense) = (prevPos, prevSense)

        return results


    def locateStrictConsensus(self, aSeq, mismatches=0):
        """ positions on aSeq matching the strict (A/C/G/T) consensus or its
            reverse complement within the given number of mismatches.

            Implemented only in the C extension; without it a notice is printed
            and an empty list is returned.
        """
        forwardMer = self.buildStrictConsensus()
        revcompMer = complement(forwardMer, len(forwardMer))
        if not hasMotifExtension:
            print("only supported as part of the C extension for now")
            return []

        return _motif.locateMer(aSeq, forwardMer, revcompMer, mismatches)


    def locateMotif(self, sequence, threshold=90.0, numberN=0):
        """ returns a list of positions on aSeq that match the PWM within a Threshold, 
            given as a percentage of the optimal consensus score. 
            Will call C-extension for greater speed-up if compiled.
        """
        motifLength = len(self.motifPWM)
        sequenceLength = len(sequence)
        threshold /=  100.0
        if threshold < 0.5:
            print "Threshold less than 50% - will abort locateMotif()"
            return []

        maxScore = self.bestConsensusScore()
        maxDiff = maxScore * (1 - threshold)
        if sequenceLength < motifLength:
            return []
        else:
            sequence.strip()

        if hasMotifExtension:
            return  _motif.locateMotif(sequence, self.motifPWM, self.reversePWM, maxScore, maxDiff)

        sequence = upper(sequence)
        positionList = []
        position = 0
        while position <= (sequenceLength - motifLength):
            subSequence = sequence[position: position + motifLength]
            if subSequence.count("N") > numberN:
                position += 1
                continue

            (seqScore, revSeqScore) = self.scoreMotif(subSequence, maxDiff)
            if seqScore >= revSeqScore and seqScore > 1.0:
                positionList.append((position, "F"))
            elif revSeqScore > 1.0:
                positionList.append((position, "R"))

            position += 1

        return positionList


    def locateMarkov1(self, sequence, maxFold=5.0):
        """ returns a list of positions on sequence that match the Markov1 within maxFold.
        """
        motifLength = len(self.motifPWM)
        sequenceLength = len(sequence)
        if maxFold < 1.0:
            print "maxFold less than 1.0 - will abort locateMarkov1()"
            return []

        maxScore = self.bestMarkov1Score() * maxFold
        if sequenceLength < motifLength:
            return []
        else:
            sequence.strip()

        if hasMotifExtension:
            return  _motif.locateMarkov1(sequence, self.motifMarkov1, self.revMarkov1, maxScore)

        sequence = upper(sequence)
        positionList = []
        position = 0
        while position <= (sequenceLength - motifLength):
            subSequence = sequence[position: position + motifLength]
            if subSequence.count("N") > 0:
                position += 1
                continue

            (seqScore, revSeqScore) = self.scoreMarkov1(subSequence, maxScore)    
            if seqScore <= revSeqScore and seqScore < maxScore:
                positionList.append((position, "F"))
            elif revSeqScore < maxScore:
                positionList.append((position, "R"))

            position += 1

        return positionList


    def calculatePWMfromSeqs(self):
        """ calculate the PWM using a set of non-degenerate instances of the motif.

            Each column of self.motifPWM holds, for every nucleotide (indexed
            via matrixRow), the fraction of self.sequences carrying it at that
            position. The motif length is taken from the first sequence.
            Afterwards the reverse PWM, self.motifSeq (consensus) and
            self.strictConsensus are rebuilt. Does nothing when there are no
            sequences.
        """
        numSeqs = len(self.sequences)
        if numSeqs < 1:
            return

        # the first sequence fixes the motif length
        motifLength = len(self.sequences[0])
        counts = [[0.0, 0.0, 0.0, 0.0] for _column in range(motifLength)]

        for instance in self.sequences:
            for offset, nucleotide in enumerate(instance.upper()):
                counts[offset][matrixRow[nucleotide]] += 1.0

        # turn raw counts into per-position frequencies
        for column in counts:
            for nucleotide in ["A", "C", "G", "T"]:
                column[matrixRow[nucleotide]] /= numSeqs

        self.motifPWM = counts
        self.buildReversePWM()
        self.motifSeq = self.buildConsensus()
        self.strictConsensus = self.buildStrictConsensus()


    def printMarkov1(self):
        """ print the Markov1 PSSM of the form previous NT -> current NT.

            Prints one line per (prior, current) nucleotide pair, holding that
            pair's value at every motif position rounded to 4 decimals and
            tab-separated. The format string ends in a newline, so each line
            is followed by a blank line. An "ERROR" line is printed for a pair
            that cannot be looked up through matrixRow.
        """
        # row[prior][current] accumulates the tab-terminated values of one pair
        row = [["", "", "", ""] for _prior in range(4)]
        for pos in self.motifMarkov1:
            for prior in range(4):
                for current in range(4):
                    row[prior][current] += str(round(pos[prior][current], 4)) + "\t"

        for prior in ["A", "C", "G", "T"]:
            for current in ["A", "C", "G", "T"]:
                try:
                    print("%s -> %s\t%s\n" % (prior, current, row[matrixRow[prior]][matrixRow[current]]))
                except (KeyError, IndexError):
                    # nucleotide missing from matrixRow or mapped outside 0-3;
                    # anything else (NameError, KeyboardInterrupt...) propagates
                    print("ERROR: %s %s" % (prior, current))
        print("\n")


    def bestMarkov1Score(self):
        """ returns the best (lowest) markov1 score possible.

            For each motif position, the single largest entry of that
            position's 4x4 table is chosen. Ties are broken by the larger
            (prior, current) letters. Its probability is then taken from
            probMatchMarkov1(), floored at 0.0001, and -log2 of it is added.
            The first position is queried with "N" as the prior.
            Because each position is optimised on its own, the chosen priors
            need not chain into one real sequence. The result may therefore be
            lower than any score an actual sequence can reach.
        """
        matchMarkov1 = self.probMatchMarkov1
        score = 0.0
        for index, column in enumerate(self.motifMarkov1):
            bestProb, bestPrior, bestCurrent = max(
                (column[matrixRow[prior]][matrixRow[current]], prior, current)
                for prior in ["A", "C", "G", "T"]
                for current in ["A", "C", "G", "T"])
            if index == 0:
                # nothing precedes the first position
                bestPrior = "N"

            prob = max(matchMarkov1(self.motifMarkov1, bestPrior, bestCurrent, index), 0.0001)
            if prob > 0.0:
                score -= log(prob, 2.0)

        return score


    def worstMarkov1Score(self):
        """ returns the worst (highest) markov1 score possible.

            This is the penalty for the 0.0001 probability floor, -log2(0.0001),
            charged for len(self.motifMarkov1) - 1 positions.
            NOTE(review): scoreMarkov1() charges all len(self.motifMarkov1)
            positions, so a forward score can exceed this by one penalty;
            confirm the "- 1" is intended.
        """
        penaltyPerPosition = -log(0.0001, 2.0)
        return penaltyPerPosition * (len(self.motifMarkov1) - 1)


    def calculateMarkov1(self, pseudoCount=1.0):
        """ calculate the Markov1 PSSM using a set of non-degenerate instances of the motif.
        adds a pseudoCount for unseen combinations.

        Each position of self.motifMarkov1 is a 4x4 table indexed
        [prior][current], with nucleotides mapped through matrixRow. Every cell
        starts at pseudoCount. Each sequence then adds 1.0 to the cell for its
        (previous, current) pair. At the first position there is no previous
        nucleotide, so 0.25 is added to the current column of every prior row
        instead. All cells are finally divided by
        len(self.sequences) + pseudoCount. The motif length is taken from the
        first sequence. The reverse table is then rebuilt via
        buildReverseMarkov1(pseudoCount).

        Returns [] (leaving self.motifMarkov1 empty) when there are no
        sequences or len(self.sequences) + pseudoCount < 2; otherwise None.
        """
        self.motifMarkov1 = []
        # float() keeps the normalisation a true division on Python 2 even
        # when an int pseudoCount is passed.
        numSeqs = len(self.sequences) + float(pseudoCount)

        # "not self.sequences" guards self.sequences[0] when a large
        # pseudoCount alone pushes numSeqs past 2.
        if numSeqs < 2 or not self.sequences:
            return []

        # using length of first sequence as the length of the motif
        length = len(self.sequences[0])
        self.motifMarkov1 = [[[pseudoCount] * 4 for _prior in range(4)]
                             for _column in range(length)]

        for seq in self.sequences:
            prior = -1
            for index, NT in enumerate(seq.upper()):
                current = matrixRow[NT]
                if index == 0:
                    # no previous nucleotide: spread the count over all priors
                    for priorNT in range(4):
                        self.motifMarkov1[0][priorNT][current] += 0.25
                else:
                    self.motifMarkov1[index][prior][current] += 1.0

                prior = current

        for column in self.motifMarkov1:
            for priorRow in column:
                for current in range(4):
                    priorRow[current] /= numSeqs

        self.buildReverseMarkov1(pseudoCount)


    def buildReverseMarkov1(self, pseudoCount=1.0):
        """ calculate the Reverse Markov1 PSSM using a set of non-degenerate instances of the motif.

            self.revMarkov1 is built like the forward table: 4x4 [prior][current]
            cells per position, seeded with pseudoCount, and normalised by
            len(self.sequences) + pseudoCount. The difference is that it counts
            complement(seq.upper(), length) for each sequence. That is presumably
            the reverse complement; confirm in complement().
            The motif length is taken from the first sequence.

            Returns [] (leaving self.revMarkov1 empty) when there are no
            sequences or len(self.sequences) + pseudoCount < 2; otherwise None.
        """
        self.revMarkov1 = []
        # float() keeps the normalisation a true division on Python 2 even
        # when an int pseudoCount is passed.
        numSeqs = len(self.sequences) + float(pseudoCount)

        # "not self.sequences" guards self.sequences[0] when a large
        # pseudoCount alone pushes numSeqs past 2.
        if numSeqs < 2 or not self.sequences:
            return []

        # using length of first sequence as the length of the motif
        length = len(self.sequences[0])
        self.revMarkov1 = [[[pseudoCount] * 4 for _prior in range(4)]
                           for _column in range(length)]

        for aSeq in self.sequences:
            prior = -1
            for index, NT in enumerate(complement(aSeq.upper(), length)):
                current = matrixRow[NT]
                if index == 0:
                    # no previous nucleotide: spread the count over all priors
                    for priorNT in range(4):
                        self.revMarkov1[0][priorNT][current] += 0.25
                else:
                    self.revMarkov1[index][prior][current] += 1.0

                prior = current

        for column in self.revMarkov1:
            for priorRow in column:
                for current in range(4):
                    priorRow[current] /= numSeqs


    def scoreMarkov1(self, aSeq, maxScore=10000000.):
        """ calculate the matching score of aSeq against the Markov1 models.

            Returns (forwardScore, reverseScore). For each of the first
            len(self.motifMarkov1) characters of upper(aSeq), the probability
            from probMatchMarkov1() is looked up against self.motifMarkov1
            (forward) and self.revMarkov1 (reverse). It is floored, and -log2 of
            it is added, so lower scores mean better matches. The first
            character is scored with "N" as its previous nucleotide.

            Scoring stops early, returning the partial sums, once both scores
            exceed maxScore. Returns [] when aSeq is shorter than the motif.
        """
        motLength = len(self.motifMarkov1)
        matchMarkov1 = self.probMatchMarkov1
        if len(aSeq) < motLength:
            return []

        window = upper(aSeq)[:motLength]
        # NOTE(review): the reverse floor (0.002) is higher than the forward
        # floor (0.0001), so reverse mismatches cost less; confirm intended.
        forwardFloor, reverseFloor = 0.0001, 0.002
        seqProb = revSeqProb = 0.0
        previousNT = "N"
        for index, currentNT in enumerate(window):
            forwardProb = max(matchMarkov1(self.motifMarkov1, previousNT, currentNT, index), forwardFloor)
            if forwardProb > 0.0:
                seqProb -= log(forwardProb, 2.0)

            reverseProb = max(matchMarkov1(self.revMarkov1, previousNT, currentNT, index), reverseFloor)
            if reverseProb > 0.0:
                revSeqProb -= log(reverseProb, 2.0)

            # both strands already worse than the cutoff: no need to continue
            if seqProb > maxScore and revSeqProb > maxScore:
                break

            previousNT = currentNT

        return (seqProb, revSeqProb)


    def probMatchMarkov1(self, theMarkov1, previousNT, NT, colIndex):
        """ returns the likelihood of seeing NT given previousNT at this position of the motif.

            theMarkov1 is a Markov1 table such as self.motifMarkov1 or
            self.revMarkov1, indexed [colIndex][prior][current] via matrixRow.

            - NT == "N" matches any nucleotide and returns 1.0.
            - NT in A/C/G/T selects that nucleotide.
            - Any other NT is treated as a degenerate symbol. It selects every
              nucleotide X with NT in motifDict[X], and their entries are
              summed (0.0 if none match).
            - If previousNT is not a key of matrixRow, the selected entries are
              also summed over all four priors. scoreMarkov1() passes "N" for
              the first position, presumably for this reason.
        """
        if NT == "N":
            return 1.0

        if NT in ["A", "C", "G", "T"]:
            currentRows = [matrixRow[NT]]
        else:
            # index by each covered nucleotide: the 4x4 table has no entry for
            # the degenerate symbol itself.
            currentRows = [matrixRow[motifNucleotide]
                           for motifNucleotide in ["A", "T", "C", "G"]
                           if NT in motifDict[motifNucleotide]]

        try:
            priorRows = [matrixRow[previousNT]]
        except KeyError:
            # unknown previous nucleotide: marginalise over every prior
            priorRows = range(4)

        column = theMarkov1[colIndex]
        currentProb = 0.0
        for prior in priorRows:
            for current in currentRows:
                currentProb += column[prior][current]

        return currentProb


    def isSane(self, minLen=7, stretchLen=6):
        """ check for motif sanity. Returns False when any of these holds:

            - the consensus (buildConsensus()) is shorter than minLen;
            - N's make up half or more of the consensus;
            - the consensus uses only A/T, A/G, A/C, G/T or C/T. This also
              rejects single-nucleotide motifs; G/C-only motifs are allowed;
            - the strict consensus (buildStrictConsensus()) contains a run of
              stretchLen copies of one nucleotide, or ceil(stretchLen / 2)
              copies of the dinucleotide AG, AC, AT, CT or GT. 'GC' repeats
              are exempt.

            stretchLen is first capped at minLen - 1. Returns True otherwise.
        """
        stretchLen = int(stretchLen)
        minLen = int(minLen)
        stretchLen = min(stretchLen, minLen - 1)

        cons = self.buildConsensus()
        motifLen = float(len(cons))
        if motifLen < minLen:
            return False

        if cons.count("N") >= 0.5 * motifLen:
            return False

        counts = dict((nucleotide, cons.count(nucleotide)) for nucleotide in "ACGT")
        # G/C is intentionally not in this list so GC-rich motifs pass
        for first, second in [("A", "T"), ("A", "G"), ("A", "C"), ("G", "T"), ("T", "C")]:
            if counts[first] + counts[second] == motifLen:
                return False

        strict = self.buildStrictConsensus()
        # ceil(stretchLen / 2). Floor division keeps this an int on Python 3,
        # where "/" would give a float and break string repetition.
        repeatCount = (stretchLen + 1) // 2
        repeatSequences = [nucleotide * stretchLen for nucleotide in "AGCT"]
        repeatSequences += [dinucleotide * repeatCount
                            for dinucleotide in ["AG", "AC", "AT", "CT", "GT"]]

        for testSequence in repeatSequences:
            if testSequence in strict:
                return False

        return True


def correlateMotifs(actualMotifA, actualMotifB, maxSlide=1):
    """ Compares two motifs using the "pearson correlation coefficient-like" MSS.
        Will slide a motif up to maxSlide bases compared to the other,
        and reports the best score.

        The longer motif (motA) stays fixed and the shorter one (motB) slides
        along it. Both motB's forward PWM and the PWM returned by
        motB.buildReversePWM() are compared; the latter is presumably the
        reverse complement. Missing columns are filled with the "N" column
        (symbolToMatrix["N"]) so both PWMs have equal length. The score of one
        alignment is the mean pearsonCorrelation() over those columns.
        The best score over all slides and both orientations is returned,
        never less than 0.0. When the compiled _motif extension is available
        the work is delegated to it.
    """
    bestScore = 0.0

    # motA is the longer (or equal-length) motif; motB is the one that slides
    if len(actualMotifA) < len(actualMotifB):
        motA = actualMotifB
        motB = actualMotifA
    else:
        motA = actualMotifA
        motB = actualMotifB

    motApwm = motA.getPWM()
    motBpwm = motB.getPWM()
    # NOTE(review): relies on buildReversePWM() returning the reverse PWM,
    # not only storing it on motB - confirm.
    motCpwm = motB.buildReversePWM()
    if hasMotifExtension:
        return  _motif.correlateMotifs(motApwm, motBpwm, motCpwm, maxSlide)
    else:
        length = len(motA)
        # number of columns by which motB is shorter than motA
        padLength = length - len(motB)
        # one-element list holding the "N" column; multiplied to make padding
        Ncol = [symbolToMatrix["N"]]
        # slides run from -maxSlide (motB starts before motA) up to
        # maxSlide + padLength (motB ends maxSlide columns past motA's end)
        for slide in range(-1 * maxSlide, maxSlide + padLength + 1):
            # NOTE(review): the concatenations below already build new outer
            # lists; the deepcopy looks defensive - check pearsonCorrelation()
            # does not mutate columns before removing it.
            pwmA = deepcopy(motApwm)
            pwmB = deepcopy(motBpwm)
            pwmC = deepcopy(motCpwm)
            if slide < 0:
                # motB starts |slide| columns before motA: pad motA at the
                # front and motB at the back to the same length
                pwmA = Ncol * abs(slide) + pwmA
                pwmB = pwmB + Ncol * (abs(slide) + padLength) 
                pwmC = pwmC + Ncol * (abs(slide) + padLength)
            elif slide > 0 and slide <= maxSlide:
                if padLength > 0:
                    # shift motB right by slide; the spare padLength columns
                    # absorb the shift first, and motA is only extended at the
                    # back by whatever overhang remains
                    if padLength >= slide:
                        adjustedPadLength = padLength - slide
                        adjustedSlide = 0
                    else:
                        adjustedPadLength = 0
                        adjustedSlide = slide - padLength

                    pwmA = pwmA + Ncol * adjustedSlide
                    pwmB = Ncol * slide + pwmB + Ncol * adjustedPadLength
                    pwmC = Ncol * slide + pwmC + Ncol * adjustedPadLength
                else:
                    # equal lengths: shift motB right and extend motA at the back
                    pwmA = pwmA + Ncol * slide
                    pwmB = Ncol * slide + pwmB 
                    pwmC = Ncol * slide + pwmC 
            elif slide > maxSlide:
                # only reachable when padLength > 0: motA gets maxSlide trailing
                # N columns and motB is front/back padded to the same length
                maxDiff = slide - maxSlide
                pwmA = pwmA + Ncol * maxSlide
                pwmB = Ncol * slide  + pwmB + Ncol * (padLength - maxDiff)
                pwmC = Ncol * slide  + pwmC + Ncol * (padLength - maxDiff)
            else:
                # slide == 0: both start together; pad motB at the back
                pwmB = pwmB + Ncol * padLength
                pwmC = pwmC + Ncol * padLength            

            # mean per-column correlation for each orientation of motB
            score1 = 0.0
            score2 = 0.0
            thisLength = len(pwmA)
            for index in range(thisLength):
                score1 += pearsonCorrelation(pwmA[index], pwmB[index])
                score2 += pearsonCorrelation(pwmA[index], pwmC[index])

            score1 = score1 / float(thisLength)
            score2 = score2 / float(thisLength)
            # net effect: bestScore = max(bestScore, score1, score2)
            if score1 < score2 and score2 > bestScore:
                bestScore = score2
            elif score1 > bestScore:
                bestScore = score1

    return bestScore


def MSS(motifA, motifB, maxSlide=1):
    """ Motif similarity score (MSS) between motifA and motifB.

        Alias for correlateMotifs(): the shorter motif is slid up to maxSlide
        bases relative to the other and the best score is returned unchanged.
    """
    similarity = correlateMotifs(motifA, motifB, maxSlide=maxSlide)
    return similarity


def printMSS(motifList, maxSlide=1):
    """ Prints a matrix of MSS comparisons between different motifs
        in motifList.
    """
    for mot1 in motifList:
        print mot1.tagID,
        for mot2 in motifList:
            val = "\t%1.2f" % correlateMotifs(mot1, mot2, maxSlide)
            print val,

        print ""