###########################################################################
#                                                                         #
# C O P Y R I G H T   N O T I C E                                         #
#  Copyright (c) 2003-10 by:                                              #
#    * California Institute of Technology                                 #
#                                                                         #
#    All Rights Reserved.                                                 #
#                                                                         #
# Permission is hereby granted, free of charge, to any person             #
# obtaining a copy of this software and associated documentation files    #
# (the "Software"), to deal in the Software without restriction,          #
# including without limitation the rights to use, copy, modify, merge,    #
# publish, distribute, sublicense, and/or sell copies of the Software,    #
# and to permit persons to whom the Software is furnished to do so,       #
# subject to the following conditions:                                    #
#                                                                         #
# The above copyright notice and this permission notice shall be          #
# included in all copies or substantial portions of the Software.         #
#                                                                         #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,         #
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF      #
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND                   #
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS     #
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN      #
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN       #
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE        #
# SOFTWARE.                                                               #
###########################################################################
#
import string
import copy
from math import log

# Codon -> one-letter amino-acid lookup table used by getAA().
# Rows are grouped by the first two bases of the codon. "*" marks a stop
# codon, "---" is a complete gap and "NNN" is an unknown codon.
AA = {
    "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
    "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
    "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
    "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W",
    "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
    "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
    "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
    "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
    "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
    "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
    "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
    "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
    "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
    "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
    "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
    "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
    "---": "-",
    "NNN": "X",
}


def getAA(codon):
    """ returns the one-letter AA symbol corresponding to codon.

        The codon is upper-cased and every "U" is replaced by "T" before
        the lookup in the AA table, so RNA and lower-case codons are
        accepted. Returns "X" for any codon that is not in the table,
        "-" for a complete gap ("---") and "*" for a stop codon.
    """
    # str methods replace the deprecated string.replace() helper.
    normalized = codon.upper().replace("U", "T")
    return AA.get(normalized, "X")


def translate(mRNA, frame=1):
    """ translate a sequence into protein based on frame (1, 2, 3 only).

        mRNA  -- nucleotide sequence; each codon is translated with getAA().
        frame -- 1-based reading frame. Values below 1 are treated as 1 and
                 values above 3 are treated as 3.

        Returns the protein as a string of one-letter symbols. Trailing
        bases that do not make up a full codon are ignored.
    """
    if frame < 1:
        frame = 1
    elif frame > 3:
        frame = 3

    # The stray debugging print of every codon has been removed; the
    # function now only returns the translation.
    residues = []
    for nucPos in range(frame - 1, len(mRNA) - 2, 3):
        residues.append(getAA(mRNA[nucPos:nucPos + 3]))

    return "".join(residues)


def iCodon(codon):
    """ returns the number of possible synonymous changes at
        each site of the codon, e.g [1, 0, 3] for CTG (Leucine).

        Returns [-1, -1, -1] when getAA() reports the codon as unknown
        ("X") or as a complete gap ("-").
    """
    origAA = getAA(codon)
    if origAA == "X" or origAA == "-":
        return [-1, -1, -1]

    # getAA() accepts lower-case and RNA codons, so normalise the same way
    # here; otherwise the comparison against "ACGT" below would not match.
    codon = codon.upper().replace("U", "T")

    isynonyms = [0, 0, 0]
    for pos, site in enumerate(codon):
        for nt in "ACGT":
            if nt == site:
                continue

            # Substitute a single base and check whether the AA is unchanged.
            newcodon = codon[:pos] + nt + codon[pos + 1:]
            if getAA(newcodon) == origAA:
                isynonyms[pos] += 1

    return isynonyms


def calcMutationTypes(codonList1, codonList2, verbose=False):
    """ returns the number of (synonymous, nonsynonymous)
        mutations in informative codons.

        codonList1, codonList2 -- codon lists compared position by position;
                                  codonList2 must be at least as long as
                                  codonList1.
        verbose                -- when True, prints how multi-site
                                  differences were classified.

        Counting rules, per codon pair:
          * identical codons contribute nothing;
          * same AA: every differing site is synonymous, except for Serine,
            where one synonymous change is counted plus one nonsynonymous
            change if the first site differs;
          * different AA, one differing site: one nonsynonymous change;
          * different AA, several differing sites: the middle site is taken
            as nonsynonymous, the last and first sites are each tested by
            substituting only that site into codon1, and whatever is left
            is counted as nonsynonymous.
    """
    synonymous = 0.0
    nonsynonymous = 0.0
    for pos in range(len(codonList1)):
        codon1 = codonList1[pos]
        codon2 = codonList2[pos]
        diffList = [isite for isite in range(len(codon1))
                    if codon1[isite] != codon2[isite]]

        # Identical codons carry no mutation. Without this guard identical
        # Serine codons were counted as one synonymous change.
        if not diffList:
            continue

        aa1 = getAA(codon1)
        aa2 = getAA(codon2)
        if aa1 == aa2:
            if aa1 == "S":
                if codon1[0] != codon2[0]:
                    if verbose:
                        print("nonsynonymous Serines")

                    nonsynonymous += 1

                synonymous += 1
            else:
                synonymous += len(diffList)

        else:
            if len(diffList) == 1:
                nonsynonymous += 1
            else:
                if verbose:
                    print("parsimonious mutation path estimator - assuming 1 parameter")
                    print("diffList = %s\t%s (%s) \t%s (%s)" % (diffList, codon1, aa1, codon2, aa2))

                if 1 in diffList:
                    if verbose:
                        print("middle site is nonsynonymous")

                    nonsynonymous += 1
                    diffList.remove(1)

                if 2 in diffList:
                    # Swap in the last base of codon2. The previous code
                    # reused codon1[2], which rebuilt codon1 unchanged and
                    # so always reported the site as synonymous.
                    codon3 = codon1[:-1] + codon2[2]
                    aa3 = getAA(codon3)
                    if aa3 == aa1:
                        if verbose:
                            print("last site is synonymous")

                        synonymous += 1
                        diffList.remove(2)

                if 0 in diffList:
                    codon3 = codon2[0] + codon1[1:]
                    aa3 = getAA(codon3)
                    if aa3 == aa1:
                        if verbose:
                            print("first site is synonymous")

                        synonymous += 1
                        diffList.remove(0)

                if len(diffList) > 0:
                    if verbose:
                        print("%s must be non-synonymous" % str(diffList))

                    nonsynonymous += len(diffList)

    return (synonymous, nonsynonymous)


def calcSubstitutionSites(codonList):
    """ returns the (synonymous, nonsynonymous) site totals for a list of
        codons, which should already be filtered for gaps and X's.

        Each codon adds one third of its iCodon() counts to the synonymous
        total and one third of the remaining 9 - counts single-base changes
        to the nonsynonymous total. Codons for which iCodon() reports a
        negative count are skipped.
    """
    synTotal = 0.0
    nonsynTotal = 0.0
    for codon in codonList:
        (first, second, third) = iCodon(codon)
        if first >= 0:
            synTotal += (first + second + third) / 3.0
            nonsynTotal += (9.0 - first - second - third) / 3.0

    return (synTotal, nonsynTotal)


def calcSubstitutionsPerSite(codonList1, codonList2):
    """ returns a triplet with the number of differences found at the
        first, second and third codon site when the two codon lists are
        compared position by position.

        The lists are expected to hold comparable or informative sites;
        codonList2 is indexed with the positions of codonList1.
    """
    perSite = [0.0, 0.0, 0.0]
    for index, first in enumerate(codonList1):
        second = codonList2[index]
        for site in (0, 1, 2):
            if first[site] != second[site]:
                perSite[site] += 1

    return tuple(perSite)


def calcKs(Ms, Ns):
    """ returns Ks, i.e. the ratio Ms / Ns corrected with Jukes and
        Cantor's formula: Ks = -3/4 * ln(1 - 4/3 * Ms / Ns).

        Raises ZeroDivisionError when Ns is 0 and ValueError when the
        argument of the logarithm is not positive.
    """
    logArgument = 1 - ((4.0 / 3.0) * Ms / Ns)
    return -0.75 * log(logArgument)


def calcKa(Ma, Na):
    """ returns Ka, i.e. the ratio Ma / Na corrected with Jukes and
        Cantor's formula: Ka = -3/4 * ln(1 - 4/3 * Ma / Na).

        Raises ZeroDivisionError when Na is 0 and ValueError when the
        argument of the logarithm is not positive.
    """
    scaledProportion = (4.0 / 3.0) * Ma / Na
    return -0.75 * log(1 - scaledProportion)


def printCDSdict(cdsDict, printAA=True):
    """ Prints every locus in a cdsDict by handing each key to
        printCDSlocus(). Pass printAA=False to turn off AA translation.
    """
    for locusName in cdsDict:
        printCDSlocus(cdsDict, locusName, printAA)


def printCDSlocus(cdsDict, locus, printAA=True):
    """ Prints the codons of one locus of cdsDict as tab-prefixed rows.

        Every codon row starts with the locus name. The matching AA row
        starts with the same number of blanks and holds the getAA()
        symbol of each codon. A row is closed once its codon text is
        longer than 69 characters. Pass printAA=False to print the codon
        rows only.
    """
    codonPrefix = locus + "\t"
    aminoPrefix = " " * len(locus) + "\t"
    codonRows = []
    aminoRows = []
    codonRow = codonPrefix
    aminoRow = aminoPrefix
    for codon in cdsDict[locus]:
        if len(codonRow) <= 69:
            # Still room on the current row: keep appending.
            codonRow += codon + " "
            aminoRow += " " + getAA(codon) + "  "
            continue

        # Row is full: this codon closes it and a fresh row is started.
        codonRows.append(codonRow + codon + "\n")
        codonRow = codonPrefix
        aminoRows.append(aminoRow + " " + getAA(codon) + " \n")
        aminoRow = aminoPrefix

    codonRows.append(codonRow + "\n")
    aminoRows.append(aminoRow + " \n")
    for codonText, aminoText in zip(codonRows, aminoRows):
        print(codonText)
        if printAA:
            print(aminoText)


def getComparableSites(cdsDict, loci=()):
    """ given a cdsDict and a list of loci, returns a new cdsDict with positions with
        gaps or X codons in any one sequence deleted in all sequences.

        cdsDict -- mapping of locus name to its list of codons.
        loci    -- names of the loci to compare; the default selects none
                   and yields an empty dictionary.

        The input lists are not modified; the result holds new lists. A
        KeyError is raised for a locus that is missing from cdsDict.
        Positions past the end of a shorter sequence are simply absent
        from that sequence's result.
    """
    locusNames = list(loci)
    sequences = [cdsDict[locus] for locus in locusNames]

    # Collect every position where at least one locus has a gap or an
    # untranslatable codon. A set keeps the membership test cheap.
    dropped = set()
    for sequence in sequences:
        for pos, codon in enumerate(sequence):
            if getAA(codon) in ("X", "-"):
                dropped.add(pos)

    newDict = {}
    for locus, sequence in zip(locusNames, sequences):
        newDict[locus] = [codon for pos, codon in enumerate(sequence)
                          if pos not in dropped]

    return newDict


def getInformativeSites(cdsDict, loci=()):
    """ given a cdsDict of comparable sites and a list of loci, returns a new
        cdsDict with positions with only codons that differ in one or more
        sequences and which are therefore (possibly) informative.

        cdsDict -- mapping of locus name to its list of codons.
        loci    -- names of the loci to compare. The first locus is the
                   reference that every other locus is compared against.
                   The default selects none and yields an empty dictionary.

        The input lists are not modified; the result holds new lists. A
        KeyError is raised for a locus that is missing from cdsDict and an
        IndexError if another sequence is shorter than the reference.
    """
    locusNames = list(loci)
    sequences = [cdsDict[locus] for locus in locusNames]

    # With no loci there is no reference sequence to index.
    if not sequences:
        return {}

    reference = sequences[0]
    others = sequences[1:]

    # A position is uninformative when every other locus carries the same
    # codon as the reference at that position.
    dropped = set()
    for pos, refcodon in enumerate(reference):
        informative = False
        for sequence in others:
            if sequence[pos] != refcodon:
                informative = True

        if not informative:
            dropped.add(pos)

    newDict = {}
    for locus, sequence in zip(locusNames, sequences):
        newDict[locus] = [codon for pos, codon in enumerate(sequence)
                          if pos not in dropped]

    return newDict


def buildCDSdict(cdsFileName):
    """ imports a set of *ALIGNED* sequences in a fasta-format file and splits
        each sequence into its individual codons. Returns a Dictionary of the
        sequences with the fasta ID as the key and the codons in a list.

        The fasta ID is the text directly after ">" up to the first blank,
        or the second blank-separated field when ">" stands alone.
        Blank lines are skipped and line endings ("\\n" or "\\r\\n") are
        stripped, so a last line without a trailing newline keeps its
        final base.

        Each sequence line is read in chunks of three characters:
          * "---" is stored as a gap codon;
          * a chunk that mixes bases and "-" starts an out-of-frame run: a
            "---" placeholder is stored and the bases are collected until
            three of them are available, which are then stored as a codon;
          * any other chunk is stored as is.
        The out-of-frame state is reset at the start of every line.
    """
    cdsDict = {}
    locus = ""
    # The with-block closes the file even if reading fails part-way.
    with open(cdsFileName, "r") as cdsfile:
        for line in cdsfile:
            line = line.rstrip("\r\n")
            if not line:
                continue

            if line[0] == ">":
                fields = line.split(" ")
                # fields[0] always starts with ">" here; it is longer than
                # one character when the ID follows ">" without a blank.
                if len(fields[0]) > 1:
                    locus = fields[0][1:]
                else:
                    locus = fields[1]

                cdsDict[locus] = []
                continue

            codons = cdsDict[locus]
            partialCodon = ""
            inFrame = True
            for pos in range(0, len(line), 3):
                codon = line[pos:pos + 3]
                if codon == "---":
                    codons.append("---")
                elif "-" in codon and inFrame:
                    inFrame = False
                    partialCodon = codon.replace("-", "")
                    codons.append("---")
                elif not inFrame:
                    partialCodon += codon.replace("-", "")
                    if len(partialCodon) == 3:
                        codons.append(partialCodon)
                        inFrame = True
                        partialCodon = ""

                    if len(partialCodon) > 3:
                        codons.append(partialCodon[:3])
                        partialCodon = partialCodon[3:]
                else:
                    codons.append(codon)

    return cdsDict