########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
# Mixture of Gaussians Model
#

import Numeric
import MLab
import LinearAlgebra
import random
import math

try:
  import compClust.mlx.DA.cDA as DA
except ImportError, e:
  print "falling back to the numerically less accurate python DA module"
  import compClust.mlx.DA.DA as DA

from compClust.mlx.interfaces import IModel

class MixtureOfGaussians(IModel):

  def __init__(self, k, means, covariances, weights=None):
    """
    Creates a new Mixture of Gaussians (MoG) Model, containing k Gaussian
    clusters in d-dimensional space.

    Where:

      - means is a k-by-d matrix containing the means (one per row) of
        the k-Gaussian models.

      - covariances is a set of k d-by-d covariance matrices (i.e. a
        k-by-d-by-d matrix) containing the covariances of the
        k-Gaussian models.

      - weights is a 1-by-k matrix of weights, one for each model.
        Optional.  If not specified, each of the k-Guassian models
        will have equal weight.  That is, each element of weights will
        default to 1 / k.
    """

    self.setParameters(k, means, covariances, weights)
    
  def getDiagLogLikelihood(self, data):
    """
    Returns the log-likelihood of the given data under the current
    model using a diagonal covariance matrix.  Data is a matrix of
    numbers whose dimensionality (number of columns) must agree with
    that of the model.
    """

    if (data.shape[1] != self.d):
      raise ValueError, "Data matrix must have d=%d columns." % (self.d)

    diag_covariances = []
    for covar in self.covariances:
      s = (covar.shape[0], covar.shape[1])
      diag_covariances.append( MLab.eye(s[0], s[1]) * covar)

    diag_covariances = Numeric.asarray(diag_covariances)
            
    return DA.mixture_likelihood( data, self.means, diag_covariances,
                                  self.weights )
  
  def getLogLikelihood(self, data):
    """
    Returns the log-likelihood of the given data under the current
    model.  Data is a matrix of numbers whose dimensionality (number
    of columns) must agree with that of the model.
    """

    if (data.shape[1] != self.d):
      raise ValueError, "Data matrix must have d=%d columns." % (self.d)
    
    return DA.mixture_likelihood( data, self.means, self.covariances,
                                  self.weights )


  def evaluateFitness(self, data):
    """
    Return the fitness of the model given a paricular set of data.
    
    First attempt to generate the full covariance fitness score, if that
    fails, try to generate a diagonal covariance, if that fails, then the
    log-likelihood is set to -MAX_FLOAT =~ -1e38.
    """
    
    try:
      fitness = self.getLogLikelihood(data)
    except ValueError, e:
      print "WARNING: value exception computing full likelihood function: ", e
      print 
      try:
        fitness = self.getDiagLogLikelihood(data)
      except ValueError, e:
        print "WARNING: value exception computing diag likelihood function: ",e
        fitness = -1e38;

    return fitness

    
  def setParameters(self, k, means, covariances, weights):
    """
    Sets the parameters for this model.  See class constructor documentation
    for more information.
    """

    if (k != means.shape[0]):
      raise ValueError, "Means matrix must have k=%d rows." % (k)

    if (k != covariances.shape[0]):
      raise ValueError, "There must be k=%d covariance matrices." % (k)
    
    if (means.shape[1] != covariances.shape[2]):
      raise ValueError, \
            "Means and covariance matrix number of columns do not match."

    if (covariances.shape[1] != covariances.shape[2]):
      raise ValueError, "Covariance matrices are not square."

    if (weights is None):
      weights = Numeric.array( k * [1.0 / k] )

    if (k != weights.shape[0]):
      raise ValueError, "Weights matrix must have k=%d columns." % (k)

    self.k           = k
    self.d           = means.shape[1]
    self.means       = means
    self.covariances = covariances
    self.weights     = weights

  def __pdf(self, x, m, c):

    d = float(len(x))
    
    log_2pi_d = d * math.log(2.0 * math.pi)
    log_det   = LinearAlgebra.determinant(c)
    diff      = x - m

    tmp       = Numeric.matrixmultiply(diff, LinearAlgebra.inverse(c))
    tmp       = Numeric.matrixmultiply(tmp,  diff)

    return -0.5 * (log_2pi_d + log_det + tmp)
    
  def classify1(self, data):

    labs = []
    for datum in data:

      probs = map(lambda x : self.__pdf(datum, self.means[x],
                                        self.covariances[x]), range(self.k))

      print probs,
      print max(probs),
      k = probs.index(max(probs))
      labs.append(k)
      print k
      
    return labs

  def classify2(self, data):

    labs = []
    for i in range(len(data)):

      probs = map(lambda x : self.__pdf(data[i], self.means[x],
                                        self.covariances[x]), range(self.k))

      tmp   = max(probs)
      map(lambda x : x - tmp, probs)
      probs = map(math.exp, probs)
      denom = Numeric.sum(probs)
      probs = map(lambda x : x / denom, probs)
      
      v = random.random()

      k = 0
      sum = 0.0
      while v > (sum + probs[k]):
        sum += probs[k]
        k   += 1
      
      labs.append(k)

    return labs
          
  def __repr__(self):
    return "MixtureOfGaussians(k=%d, d=%d)" % (self.k, self.d)
