########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

# Public API of this module: the names bound by "from <module> import *".
# The model classes are re-exported from their submodules below; the rest are
# the parameter-estimation helpers defined in this file.
__all__ = ["MixtureOfGaussians", "MixtureOfFullGaussians",
           "MixtureOfDiagonalGaussians","DistanceFromMean",
           "compute_model_weights", "compute_model_covariances_weights",
           "compute_model_means", "estimateParameters",
           "constructMixtureOfGaussiansFromLabeling",
           "constructMixtureOfDiagonalGaussiansFromLabeling",
           "constructMixtureOfFullGaussiansFromLabeling"]
  
from MixtureOfGaussians import MixtureOfGaussians
from MixtureOfFullGaussians import MixtureOfFullGaussians
from MixtureOfDiagonalGaussians import MixtureOfDiagonalGaussians
from DistanceFromMean import DistanceFromMean

from compClust.mlx.interfaces import IDataset
from compClust.mlx.interfaces import ILabeling

import Numeric
try:
  import compClust.mlx.DA.cDA as DA
except ImportError, e:
  print "falling back to the numerically less accurate python DA module"
  import compClust.mlx.DA.DA as DA


def compute_model_weights(dataset, labels):
  """
  Return the weight of each class relative to the whole dataset.

  A class's weight is the number of rows carrying its label divided by the
  total number of rows in the dataset.  The result is a one dimensional
  Numeric Float array ordered like labels.getLabels().
  """

  label_list = labels.getLabels()
  total      = dataset.getNumRows()

  # One row count per label, kept in label order.
  sizes = Numeric.array(
    [len(labels.getRowsByLabel(lbl)) for lbl in label_list],
    Numeric.Float)

  # Normalise in place by the dataset's total row count.
  Numeric.divide(sizes, total, sizes)
  return sizes

def compute_model_covariances_weights(dataset, labels, means):
  """
  Estimate per-class covariances and weights from a dataset, a hard
  labeling and the class means.

  Every row is tagged with the integer position of its label within
  labels.getLabels(); rows that carry no label stay tagged None.  The
  estimation itself is delegated to DA.covar_weights_estimate(), whose
  result is returned unchanged (estimateParameters() unpacks it as a
  (covariances, weights) pair).
  """

  data      = dataset.getData()
  row_class = [None] * dataset.getNumRows()

  label_list = labels.getLabels()
  for idx in range(len(label_list)):
    for row in labels.getRowsByLabel(label_list[idx]):
      row_class[row] = idx

  return DA.covar_weights_estimate(data, means, row_class)


def compute_model_means(dataset, labels):
  """
  Given a dataset and a labeling, compute the mean of each cluster.

  Returns a two dimensional Numeric Float array with one row per class, in
  the order given by labels.getLabels(), and one column per dataset column.

  NOTE(review): a label with no rows produces a zero count, so its row is
  divided by zero; the outcome is whatever Numeric does for 0/0.
  """

  classes     = labels.getLabels()
  num_classes = len(classes)
  data        = dataset.getData()
  cols        = dataset.getNumCols()

  means        = Numeric.zeros((num_classes, cols), Numeric.Float)
  # Column vector so the final division broadcasts across every column.
  class_counts = Numeric.zeros((num_classes, 1),    Numeric.Float)

  for i in range(num_classes):
    rows = labels.getRowsByLabel(classes[i])
    for row in rows:
      means[i] += data[row]
    class_counts[i] = len(rows)

  return means / class_counts

def estimateParameters(dataset, labels):
  """
  Fully estimate the Mixture of Gaussians parameters for a dataset with a
  given hard partitioning.  The results are returned as a 4-tuple:
  (k, means, covariances, weights).

  Returns None if either dataset or labels is None.  Raises ValueError if
  dataset is not an IDataset instance or labels is not an ILabeling
  instance.

  See the documentation for compute_model_means() and
  compute_model_covariances_weights() for details of the returned data.
  """

  if dataset is None or labels is None:
    return None

  if not isinstance(dataset, IDataset):
    raise ValueError("dataset paramemter must be a subclass of IDataset")
  if not isinstance(labels, ILabeling):
    raise ValueError("labels paramemter must be a subclass of ILabeling")

  k     = len(labels.getLabels())
  means = compute_model_means(dataset, labels)
  covariances, weights = compute_model_covariances_weights(dataset, labels,
                                                           means)

  #
  # Issue: singleton clusters have covariance matrices of zeros, which are
  # singular.
  #
  # Solutions:
  #   1. Ignore such clusters
  #   2. Set the covariance matrix to I
  #   3. Set the covariance matrix to I*epsilon ( implemented )
  #
  # The weight of a singleton cluster is 1/num_points, so any weight within
  # half of 1/num_points of that value is treated as a singleton.
  #

  numPoints = dataset.getNumRows()
  numDims   = dataset.getNumCols()
  badVal    = 1.0 / numPoints
  epsilon   = 1e-20

  for i in range(k):
    if abs(weights[i] - badVal) < (badVal / 2.0):
      # Replace with an explicit numDims x numDims identity scaled by
      # epsilon.  (Assigning a scalar here would fill every entry with the
      # same value, giving a singular matrix rather than I*epsilon.)
      # Assumes covariances[i] is a numDims x numDims matrix, as described
      # by compute_model_covariances_weights() -- confirm for DA output.
      covariances[i] = Numeric.identity(numDims) * epsilon

  return (k, means, covariances, weights)


def constructMixtureOfGaussiansFromLabeling(dataset, labels):
  """
  Build a MixtureOfGaussians whose parameters (k, means, covariances,
  weights) are estimated by estimateParameters() from the given dataset
  and hard labeling.
  """
  k, mu, sigma, w = estimateParameters(dataset, labels)
  model = MixtureOfGaussians(k, mu, sigma, w)
  return model


def constructMixtureOfDiagonalGaussiansFromLabeling(dataset, labels):
  """
  Build a MixtureOfDiagonalGaussians whose parameters (k, means,
  covariances, weights) are estimated by estimateParameters() from the
  given dataset and hard labeling.
  """
  k, mu, sigma, w = estimateParameters(dataset, labels)
  model = MixtureOfDiagonalGaussians(k, mu, sigma, w)
  return model


def constructMixtureOfFullGaussiansFromLabeling(dataset, labels):
  """
  Build a MixtureOfFullGaussians whose parameters (k, means, covariances,
  weights) are estimated by estimateParameters() from the given dataset
  and hard labeling.
  """
  k, mu, sigma, w = estimateParameters(dataset, labels)
  model = MixtureOfFullGaussians(k, mu, sigma, w)
  return model



