##Sarah Aerni
##Created:  June 25, 2005
##Modified: July 11, 2005
##Motif Finder using Gibbs sampling
##convergence is measured when the difference in the sequences is negligible

import random
import math
import sys
import copy
from time import time

from math import ceil

# Module-level background model and consensus-score lookup.
Frequency = dict()
ConsensusScore = {False: 1, True: 2}

# Nucleotide <-> column-index lookups (columns are ordered A, C, G, T).
NTDIndices = dict((base, column) for column, base in enumerate("ACGT"))
IndexNTDs = dict(enumerate("ACGT"))
# Character inserted between/over sequence regions (see MotifFinder).
INSERTION_N = "N"
# Run parameters shared by MotifFinder and GibbsRunner through module
# globals; they are assigned inside MotifFinder.
global minSize, maxSize, thresholdPercent
global sequences, sizeOfMotif, numOfMismatches
global maxIterations, maxLoopsWithoutImprovement

"""
markov size - 1 = markov model
example:
if you want a markov 1 you want to check 2 ntds total
markov_size = 2
MARKOV_SIZE = 2

AlignmentScore
input:     sequencesToCompare  List of sequences whose substrings will be
                               aligned
           sizeOfMotif         integer size of the motif being found
                               (length of the subsequence that will be aligned
                               from each sequence in the above list, i.e. the
                               window size)
           startLocs           start locations of the motifs to be aligned to
                               each other
           CompareList         the indices of sequencesToCompare to be aligned
           numSeqs             the number of sequences - 1 from
                               sequencesToCompare to be aligned; the indices
                               into sequencesToCompare are stored in the first
                               numSeqs + 1 entries of CompareList.
           Frequency           markov model being used to calculate background
           originalSeqs        contain unmasked sequences used for checking
                               the markov score
output:    integer             Score indicating the consensus score of these
                               sequences
           2D-List             contains the PSFM
           2D-List             contains the log markov scores

The score is calculated as an ungapped consensus between the elements listed
in CompareList. The consensus score is calculated by taking the largest count
of identical elements in each column and adding these counts up across all
columns.
"""


def AlignmentScore(sequencesToCompare, sizeOfMotif, startLocs, CompareList, numSeqs):
    """ AlignmentScore
        input:     sequencesToCompare  list of sequences (strings over A,C,G,T)
                   sizeOfMotif         window length taken from each sequence
                   startLocs           window start for every sequence, indexed
                                       like sequencesToCompare; -1 marks a
                                       sequence without a window, which is
                                       left out of the counts
                   CompareList         indices into sequencesToCompare
                   numSeqs             number of CompareList entries to use,
                                       minus one (entries 0..numSeqs are read)
        output:    (TotalScore, PWM, maxScores)
                   TotalScore          sum over all columns of the largest
                                       count found in the column
                   PWM                 one [A, C, G, T] count list per column
                   maxScores           the largest count of every column
        raises:    ValueError          if a selected window contains a
                                       character that has no PWM column (for
                                       example the masking character N)

        builds the ungapped count matrix of the selected windows and scores
        it by its per-column consensus
    """
    # Resolve the participating windows once instead of once per column.
    windows = []
    for j in range(numSeqs + 1):
        SequenceIndex = CompareList[j]
        CurrSeq = sequencesToCompare[SequenceIndex]
        start = startLocs[SequenceIndex]
        if start != -1:
            windows.append((SequenceIndex, CurrSeq, start))

    TotalScore = 0
    PWM = []
    maxScores = []
    for i in range(sizeOfMotif):
        PWMi = [0.0, 0.0, 0.0, 0.0]
        for SequenceIndex, CurrSeq, start in windows:
            letter = CurrSeq[start + i]
            if letter not in NTDIndices:
                # A window must never overlap a masked or unknown position;
                # report exactly which one did instead of a bare KeyError.
                raise ValueError(
                    "motif window %r of sequence %i (start %i) contains %r, "
                    "which has no PWM column"
                    % (CurrSeq[start:start + sizeOfMotif], SequenceIndex,
                       start, letter))
            PWMi[NTDIndices[letter]] += 1.0

        maxHere = max(PWMi)
        TotalScore += maxHere
        maxScores.append(maxHere)
        PWM.append(PWMi)

    return (TotalScore, PWM, maxScores)


def MarkovFreq (prefix, actualNTD, Frequency,MARKOV_WINDOW):
    """ MarkovFreq
        input:     prefix         string placed in front of each candidate NTD
                   actualNTD      the NTD whose share of the total is wanted
                   Frequency      Markov model handed through to M_Score
                   MARKOV_WINDOW  word length handed through to M_Score
        output:    float          M_Score of prefix+actualNTD divided by the
                                  sum of the M_Scores of prefix+A, prefix+C,
                                  prefix+G and prefix+T

        Words missing from the model are not treated as an error here
        (M_Score is called with check=False). The numerator stays 0.0 when
        actualNTD is not one of A, C, G, T.
    """
    alphabet = ("A", "C", "G", "T")
    scores = [M_Score(prefix + candidate, Frequency, False, MARKOV_WINDOW)
              for candidate in alphabet]

    numerator = 0.0
    denominator = 0.0
    for candidate, value in zip(alphabet, scores):
        if candidate == actualNTD:
            numerator = value
        denominator += value

    return numerator / denominator


def revComp (sequence):
    """ revComp
        input:     sequence    DNA sequence (A, C, G, T, N in either case)
        output:    string      upper-case reverse complement of the input

        complements every character (N stays N) and reverses the result;
        any other character raises KeyError
    """
    complement = {"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"}
    # Complement in reading order first, then flip the whole list.
    swapped = [complement[base.upper()] for base in sequence]
    swapped.reverse()
    return "".join(swapped)


def Markov(sequences, IncludeRC, MARKOV):
    """ Markov
        input:     sequences   list of sequences used to build the background
                   IncludeRC   when true, the reverse complement of every word
                               is counted as well
                   MARKOV      order of the model; words of length MARKOV + 1
                               are counted
        output:    dictionary mapping every observed (upper-cased) word to
                   its fraction of all counted words (plain proportions, no
                   logarithm is applied)

        builds a background model by sliding a window of MARKOV + 1 characters
        over every sequence. The proportions are normalised by the number of
        words that were actually tallied, so they sum to 1 whether or not
        reverse complements are included, and sequences shorter than the
        window simply contribute nothing.
    """
    MARKOV_WINDOW = MARKOV + 1
    WordDict = {}
    totalWindows = 0

    for seq in sequences:
        for index in range(len(seq) - MARKOV_WINDOW + 1):
            subseq = seq[index:index + MARKOV_WINDOW].upper()
            WordDict[subseq] = WordDict.get(subseq, 0.0) + 1.0
            totalWindows += 1
            if IncludeRC:
                RC = revComp(subseq)
                WordDict[RC] = WordDict.get(RC, 0.0) + 1.0
                totalWindows += 1

    # Only values are replaced here, so iterating the keys stays safe.
    for key in WordDict:
        WordDict[key] = 1.0 * WordDict[key] / totalWindows

    return WordDict


def Average_M (sequence, Model, l, MARKOV_WINDOW):
    """ Average_M
        input:     sequence       list of sequences to scan
                   Model          dictionary handed through to M_Score
                   l              length of the words being averaged over
                   MARKOV_WINDOW  word length of the model
        output:    float          mean M_Score over all scanned windows

        Every window consists of an l-mer together with the MARKOV_WINDOW - 1
        characters preceding it. M_Score is called with check=True, so a
        word missing from the model ends the program. The number of windows
        scanned is printed before returning.
    """
    context = MARKOV_WINDOW - 1
    totalSum = 0.0
    totalWords = 0.0
    for seq in sequence:
        lastStart = len(seq) - l
        for start in range(context, lastStart + 1):
            window = seq[start - context:start + l]
            totalSum += M_Score(window, Model, True, MARKOV_WINDOW)
            totalWords += 1.0

    average = totalSum / totalWords
    print(totalWords)
    return average


def M_Score (sequence, Model, check, MARKOV_WINDOW):
    """ M_Score
        input:     sequence       string to be scored
                   Model          dictionary mapping words of length
                                  MARKOV_WINDOW to their proportion
                   check          Boolean; when true a word missing from the
                                  model ends the program, otherwise the word
                                  is skipped
                   MARKOV_WINDOW  length of the words looked up in Model
        output:    float          sum of -ln(Model[word]) over every window
                                  of MARKOV_WINDOW characters in sequence

        When check is true and a word is missing, the explanation is handed
        to sys.exit, which writes it to stderr and ends the program with a
        non-zero exit status so that callers and shells can detect the
        failure.
    """
    PVal = 0.0
    for j in range(len(sequence) - MARKOV_WINDOW + 1):
        word = sequence[j:j + MARKOV_WINDOW]
        if word not in Model:
            if check:
                sys.exit(
                    "The Markov Model is inadequate for your input sequences\n"
                    " %s is not contained in model provided\n"
                    "Please revise your provided model or consider using "
                    "Background Modelling provided" % (word,))

            continue

        PVal += -math.log(Model[word])

    return PVal


def LogOdds(PWM, LogsM, sequence, Frequency, MARKOV_WINDOW):
    """ LogOdds
        input:     PWM            count matrix of the sequences already in the
                                  motif (one [A, C, G, T] list per column)
                   LogsM          matrix of the same layout accumulating the
                                  model values of those sequences
                   sequence       the candidate: MARKOV_WINDOW - 1 leading
                                  characters followed by the motif window
                   Frequency      dictionary giving the model value per word
                   MARKOV_WINDOW  length of the words looked up in Frequency
        output:    (Score, PWMout, LogsMout) where PWMout/LogsMout are deep
                   copies of the inputs with the candidate added

        the equation used is as follows:
        S(j = 1 to sizeOfMotif (S(i = [A,C,G,T]) f_ij * ln(S(Prob each path))))
    """
    Score = 0
    PWMout = copy.deepcopy(PWM)
    LogsMout = copy.deepcopy(LogsM)
    # every column of the PWM adds up to the number of sequences it holds;
    # one more is added for the candidate sequence
    firstColumn = PWM[0]
    totalSeqs = firstColumn[0] + firstColumn[1] + firstColumn[2] + firstColumn[3] + 1
    offset = MARKOV_WINDOW - 1
    for j in range(len(PWMout)):
        counts = PWMout[j]
        observed = sequence[j + offset]
        for ntd in "ACGT":
            column = NTDIndices[ntd]
            if ntd == observed:
                counts[column] += 1.0
                LogsMout[j][column] += Frequency[sequence[j:j + MARKOV_WINDOW]]

            if counts[column] > 0:
                Score += counts[column]/totalSeqs*math.log(counts[column]/(totalSeqs*LogsMout[j][column]/counts[column]),math.e)

    return Score, PWMout, LogsMout


def convert2motif(sequences, size):
    """ convert2motif
        input:     sequences   list of the aligned sequences making up the motif
                   size        number of columns to convert
        output:    string      one ambiguity symbol per column

        Each column is reduced to the set of nucleotides occurring in it and
        that set is written as a single symbol (for example A and G -> R, all
        four -> N). Characters are compared case-insensitively, and any
        character other than A, C or G is counted as T.
    """
    # column composition -> symbol
    SymbolDict = {
        'A': 'A', 'C': 'C', 'G': 'G', 'T': 'T',
        'AC': 'M', 'AG': 'R', 'AT': 'W', 'CG': 'S', 'CT': 'Y', 'GT': 'K',
        'ACG': 'V', 'ACT': 'H', 'AGT': 'D', 'CGT': 'B',
        'ACGT': 'N',
    }
    symbols = []
    for column in range(size):
        present = set()
        for seq in sequences:
            letter = seq[column].upper()
            if letter not in ("A", "C", "G"):
                letter = "T"
            present.add(letter)

        # spell the composition in A, C, G, T order so it matches a key above
        characterCode = "".join([ntd for ntd in "ACGT" if ntd in present])
        symbols.append(SymbolDict[characterCode])

    return "".join(symbols)


def convert2PSFM (sequences, NumOfSeqs):
    """ convert2PSFM
        input:     sequences   list containing the sequences in A,C,G,T alphabet
                               comprising the motif
                   NumOfSeqs   number every count is divided by
        output:    2Darray     the PSFM; indices 0-3 of each inner list are
                               A,C,G,T respectively

        counts the nucleotides per column with convert2PWM (the motif width is
        the length of the first sequence) and divides every count by NumOfSeqs.
        An empty list of sequences gives an empty PSFM.
    """
    if not sequences:
        # nothing to count, and no first sequence to take the width from
        return []

    PWM = convert2PWM(sequences, len(sequences[0]))
    return [[column[j] / NumOfSeqs for j in (0, 1, 2, 3)] for column in PWM]


def convert2PWM (sequences, size):
    """ convert2PWM
        input:     sequences   list containing the sequences in A,C,G,T alphabet
                               comprising the motif
                   size        number of columns to count
        output:    2Darray     one list of four float counts per column;
                               indices 0-3 are A,C,G,T respectively

        counts, for each of the first size positions, how often every
        nucleotide occurs across the sequences (case-insensitively)
    """
    matrix = []
    for column in range(size):
        counts = [0.0] * 4
        for seq in sequences:
            slot = NTDIndices[seq[column].upper()]
            counts[slot] += 1.0

        matrix.append(counts)

    return matrix


def add2PWM(sequence, PWM):
    """ add2PWM
        input:     sequence    sequence to be added to the PWM
                   PWM         PWM being modified in place

        adds one to the count of the sequence's nucleotide in every column of
        the PWM (case-insensitively); nothing is returned
    """
    for position, column in enumerate(PWM):
        column[NTDIndices[sequence[position].upper()]] += 1.0


def Align2PWM(Motifi,PWM):
    """ Align2PWM
        input:     Motifi      sequence being aligned to the PWM
                   PWM         PWM to which the sequence is being aligned
        output:    float       alignment score against the matrix

        sums, over every column of the PWM, the entry belonging to the
        sequence's nucleotide at that position. Characters are looked up
        as-is, without case conversion.
    """
    total = 0.0
    for position, column in enumerate(PWM):
        total += column[NTDIndices[Motifi[position]]]
    return total


def Factorial(n):
    """ Factorial
        input:     n       integer whose factorial is wanted
        output:    integer product of 2..n; 1 for every n below 2

        calculates the factorial of input number n
    """
    product = 1
    # multiply downwards from n; the range is empty when n < 2
    for factor in range(n, 1, -1):
        product = product * factor
    return product


def Choose(n, k):
    """ Choose
        input:     n        integer for total number of elements in a set
                   k        integer for number of elements to be chosen
        output:    integer  the binomial coefficient nCk

        calculates the number of ways to choose k elements from a set of n.
        Floor division is spelled out so the result stays an exact integer
        regardless of how the plain division operator is defined for
        integers.
    """
    return Factorial(n) // Factorial(n - k) // Factorial(k)

      
# column index -> nucleotide and nucleotide -> column index (A, C, G, T order)
getNTDVals = dict(enumerate("ACGT"))

getNTDIndex = dict((base, index) for index, base in enumerate("ACGT"))


def MotifFinder(InputSequences, minS, maxS, numberOfMotifs, Markov_Size,
                UseRC, Frequency, excludeBelowAve, percID,
                maxIter, maxLoopsW_outImprovement):
    """ MotifFinder
        input:     InputSequences   list of DNA sequences to search
                   minS, maxS       smallest / largest motif size; the size of
                                    each motif is drawn at random from this
                                    range
                   numberOfMotifs   number of motifs to look for
                   Markov_Size      only reported in the start-up message
                   UseRC            when true every sequence is extended by an
                                    N and its reverse complement
                   Frequency        not used by this function
                   excludeBelowAve  not used by this function
                   percID           percentage used both for the number of
                                    tolerated mismatches and for the score a
                                    sequence must exceed to stay in the motif
                   maxIter          stored in the global maxIterations
                   maxLoopsW_outImprovement
                                    stored in the global
                                    maxLoopsWithoutImprovement
        output:    list of PSFMs, one per motif found

        For every motif GibbsRunner(100) is run on the module-global
        sequences, the sequences whose chosen window scores above percID
        percent of the best window are turned into a PSFM, and those windows
        are masked with N so later motifs cannot reuse them. The search stops
        early, returning what was found so far, once the shortest sequence is
        shorter than maxS. A motif for which no sequence passes the threshold
        contributes an empty PSFM.
    """
    global minSize
    global maxSize
    global thresholdPercent
    global sequences
    global sizeOfMotif
    global numOfMismatches
    global maxIterations
    global maxLoopsWithoutImprovement
    minSize = minS
    maxSize = maxS
    thresholdPercent = percID
    maxIterations = maxIter
    maxLoopsWithoutImprovement = maxLoopsW_outImprovement

    print("Running Sampler with %i sequences" % (len(InputSequences)))
    print("Finding %i motifs of size %i to %i using markov size %i" % (numberOfMotifs, minSize, maxSize, Markov_Size))

    sequences = [seq.upper() for seq in InputSequences]
    PSFMs = []
    if not sequences:
        # nothing to search in
        return PSFMs

    if UseRC:
        sequences = [seq + INSERTION_N + revComp(seq) for seq in sequences]

    for motifNum in range(numberOfMotifs):

        # to improve speed shrink sequences by replacing strings of Ns by
        # a single N
        sequences = [_collapse_n_runs(seq) for seq in sequences]

        print("MOTIF NUMBER %i" % motifNum)
        shortest = min(len(seq) for seq in sequences)
        # pick motif size randomly within the range provided by the user
        sizeOfMotif = random.randint(minSize, maxSize)
        if shortest < maxSize:
            return PSFMs

        numOfMismatches = sizeOfMotif - ceil(thresholdPercent / 100.0 * sizeOfMotif)
        (PWM, PWMScores, startLocs) = GibbsRunner(100)

        # score every sequence's chosen window against the final matrix
        PWMScores = []
        for SIndex, seq in enumerate(sequences):
            subseq = seq[startLocs[SIndex]:startLocs[SIndex] + sizeOfMotif]
            PWMScores.append(sum(PWM[subIndex][NTDIndices[letter]]
                                 for subIndex, letter in enumerate(subseq)))

        maxScore = max(PWMScores)
        # get rid of all the sequences that do not achieve a certain consensus
        # score defined by the top one
        thresh = thresholdPercent / 100.0 * maxScore
        FinalPWMSeqs = []
        for SIndex in range(len(PWMScores)):
            if PWMScores[SIndex] > thresh:
                FinalPWMSeqs.append(sequences[SIndex][startLocs[SIndex]:startLocs[SIndex] + sizeOfMotif])
            else:
                startLocs[SIndex] = -1

        if FinalPWMSeqs:
            FinalPSFM = convert2PSFM(FinalPWMSeqs, len(FinalPWMSeqs))
        else:
            FinalPSFM = []
        PSFMs.append(FinalPSFM)

        # mask the windows that went into this motif
        mask = INSERTION_N * sizeOfMotif
        for i in range(len(sequences)):
            start = startLocs[i]
            if start == -1:
                continue
            seq = sequences[i]
            seq = seq[:start] + mask + seq[start + sizeOfMotif:]
            if UseRC:
                # the sequence is forward + N + reverse complement, so the
                # same site also occurs mirrored from the other end
                mirror = len(seq) - start - sizeOfMotif
                seq = seq[:mirror] + mask + seq[mirror + sizeOfMotif:]
            sequences[i] = seq

    return PSFMs


def _collapse_n_runs(sequence):
    """Return sequence with every run of one or more N reduced to a single N."""
    pieces = [piece for piece in sequence.split("N") if piece]
    collapsed = "N".join(pieces)
    if sequence.startswith("N"):
        collapsed = "N" + collapsed
    if pieces and sequence.endswith("N"):
        collapsed += "N"
    return collapsed


def GibbsRunner(iterIn):
    """ GibbsRunner
        input:     iterIn      maximum number of random restarts
        output:    (BestPWM, BestScore, BestLocs)
                   BestPWM     count matrix of the best restart
                   BestScore   sum of the per-column maxima of BestPWM
                   BestLocs    window start per sequence for that matrix

        Works on the module globals sequences, sizeOfMotif, numOfMismatches,
        maxIterations and maxLoopsWithoutImprovement. Every restart picks a
        random N-free window in each sequence and then repeatedly holds one
        sequence out, rebuilds the count matrix from the others and re-samples
        the held-out sequence's window with probability proportional to its
        score against that matrix. Progress and timing are printed.
    """
    iterAll = 0
    BestPWM=[]
    BestScore=0
    BestLocs=[]
    global minSize
    global maxSize
    global thresholdPercent
    global sequences
    global sizeOfMotif
    global numOfMismatches
    global maxIterations
    global maxLoopsWithoutImprovement
    #score of a matrix in which every sequence agrees in every column
    maxScore=sizeOfMotif*len(sequences)
    st=time()

    while iterAll < iterIn: 
        #print the seconds elapsed since the previous restart began
        en=time()	
        print "%.03f\t"%(en-st),
        st=time()
        iterAll+=1
        startLocs = [-1 ] * len(sequences)    

        #draw a random window per sequence, redrawing while it contains an N
        #NOTE(review): this never terminates for a sequence that has no
        #N-free window of sizeOfMotif characters
        for i in range(len(sequences)):
            startLocs[i] = random.randint(0,len(sequences[i])-sizeOfMotif)
            while "N" in sequences[i][startLocs[i]:startLocs[i]+sizeOfMotif]:
                startLocs[i] = random.randint(0,len(sequences[i])-sizeOfMotif)

        (TotalScore, PWM,dummy)= AlignmentScore(sequences, sizeOfMotif, startLocs, [i for i in range(len(sequences))], len(sequences)-1)
        PWMScore=[Align2PWM(sequences[i][startLocs[i]:startLocs[i]+sizeOfMotif],PWM) for i in range(len(sequences))]
        print "PWM is right now"
        print PWM
        print "scores for each"
        print PWMScore
        SOi = -1
        #this local shadows the module-level ConsensusScore dictionary and is
        #never reassigned below, so the first half of the loop condition
        #stays false and the loop is bounded by maxIterations and
        #PreviousBestTime
        ConsensusScore = 0
        PreviousBestScore = 0
        PreviousBestTime = -1
        iterations = 0
        while iterations < maxIterations and (ConsensusScore > PreviousBestScore or PreviousBestTime <= maxLoopsWithoutImprovement):
            iterations += 1
            #hold one randomly chosen sequence out (its location becomes -1)
            SOi = random.randint(0,len(sequences)-1)
            #SeqMotifs is filled here but not read again in this function
            SeqMotifs = []
            locs=startLocs[:]
            for i in range(len(sequences)):
                if(SOi == i):
                    locs[i]=-1
                    continue
                SeqMotifs.append(sequences[i][startLocs[i]:startLocs[i]+sizeOfMotif])

            #count matrix of all sequences except the held-out one
            (TotalScore, PWM,maxScores)= AlignmentScore(sequences, sizeOfMotif, locs, [i for i in range(len(sequences))], len(sequences)-1)
            startLocsProb = []
            startLocsI= []
            SOSeq = sequences[SOi]    
            total = 0
            start = 0
            endloc=len(SOSeq)-sizeOfMotif
            #score every candidate window of the held-out sequence
            while(start<=endloc):
                Motif = SOSeq[start:start+sizeOfMotif]
                #a window holding an N is skipped; jump to just past its
                #last N, since every start before that also covers the N
                locOfN=Motif.rfind("N")
                if locOfN>=0:
                    start+=locOfN+1
                    continue
                probAtPosn=0
                j=0
                mmNum=0
                while (j<sizeOfMotif):
                    letterScore=PWM[j][NTDIndices[Motif[j]]]
                    probAtPosn+=letterScore
                    #a column counts as a mismatch when the window's letter
                    #does not have the column's highest count
                    mmNum+=int(letterScore!=maxScores[j])
                    if mmNum>numOfMismatches:
                        probAtPosn=0
                        break
                    j+=1

                #rejected windows (too many mismatches or a zero score) are
                #not candidates
                if probAtPosn == 0:
                    start+=1
                    continue
                startLocsI.append(start)
                startLocsProb.append(probAtPosn)
                total += probAtPosn
                start+=1

            #no acceptable window: try another sequence (iterations was
            #already incremented, PreviousBestTime is left unchanged)
            if len(startLocsProb) == 0:
                continue

            #pick a candidate with probability proportional to its score by
            #walking the cumulative scores up to a uniform draw in [0,total)
            choice = random.random()
            choiceLoc = choice*total
            totalToHere = 0        
            for PrefI in range(len(startLocsProb)):
                if totalToHere+startLocsProb[PrefI] == 0:
                    continue
                if choiceLoc < totalToHere+startLocsProb[PrefI]:
                    break
                totalToHere += startLocsProb[PrefI]

            startLocs[SOi] = startLocsI[PrefI]
            PWMScore[SOi] = startLocsProb[PrefI]
            #put the held-out sequence back into the matrix at its new window
            newMotif=SOSeq[startLocs[SOi]:startLocs[SOi]+sizeOfMotif]
            add2PWM (newMotif, PWM)

            #rescore every sequence's window against the updated matrix;
            #PercentChange is collected but not read again in this function
            NewScores=[]
            PercentChange = []
            for i in range(len(sequences)):
                NewScore = 0
                Motif_i=sequences[i][startLocs[i]:startLocs[i]+sizeOfMotif]
                for j in xrange(sizeOfMotif):
                    NewScore+=PWM[j][NTDIndices[Motif_i[j]]]

                NewScores.append(NewScore)
                PercentChange.append(math.fabs(NewScore - PWMScore[i])/PWMScore[i])

            #track how many rounds have passed since the average score last
            #improved
            TotConsensusScore = sum(NewScores)
            AveConsensusScore = TotConsensusScore/(len(sequences))
            if AveConsensusScore > PreviousBestScore:
                PreviousBestScore = AveConsensusScore
                PreviousBestTime = 0
            else:
                PreviousBestTime += 1

            PWMScore=NewScores[:]

        #keep the restart whose matrix has the highest consensus; stop early
        #once the consensus is perfect
        Consensus=sum([max(PWM[i]) for i in xrange(len(PWM))])
        if Consensus> BestScore:
            BestScore=Consensus
            #slice copies: the outer lists are new, the rows are shared
            BestPWM=PWM[:]
            BestLocs=startLocs[:]
            if BestScore==maxScore:
                break

    print "iterated %i times to find"%iterAll
    print BestPWM

    return(BestPWM,BestScore,BestLocs)


def probFromPSFM(sequence, PSFM):
    """ probFromPSFM
        input:     sequence    string over A,C,G,T
                   PSFM        matrix with one [A, C, G, T] list per position
        output:    product of the PSFM entries selected by the sequence's
                   nucleotide at every position (1 for an empty sequence)
    """
    probability = 1
    for position, base in enumerate(sequence):
        probability *= PSFM[position][getNTDIndex[base]]

    return probability