#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
Test suite for the Model module.
"""

import Numeric
import RandomArray
import unittest
import os

from compClust.mlx.datasets import Dataset
from compClust.mlx.labelings import Labeling
from compClust.mlx.models import *

import compClust.mlx

class ModelTestCases(unittest.TestCase):
  """Regression tests for the compClust.mlx.models module.

  Every check runs against the synthetic three-cluster dataset built by
  createData() in setUp().  The expected values are historical (presumably
  recorded from an earlier run).  They are compared with
  failUnlessAlmostEqual (7 places by default, 5 for fitness values).
  """

  # Seed pair handed to RandomArray.seed() in setUp().  The historical
  # numbers below were produced from data generated with this seed.
  SEED = (60, 60)

  # Per-class mean vectors of the setUp() dataset.
  HISTORICAL_MEANS = [[ 2.92682856, 3.00498105, 3.13383733,],
                      [-0.33332572, 0.07445805,-0.34063035,],
                      [-2.7419303 ,-3.39541307,-2.90612711,],]

  # Per-class 3x3 covariance matrices of the setUp() dataset.
  HISTORICAL_COVARIANCES = [[[ 1.95345055, 1.22612964, 0.56807521,],
                             [ 1.22612964, 1.82283955, 0.91938557,],
                             [ 0.56807521, 0.91938557, 0.91231437,],],
                            [[ 0.60899358,-0.1276676 ,-0.02673082,],
                             [-0.1276676 , 1.22658305,-0.01629416,],
                             [-0.02673082,-0.01629416, 0.98311735,],],
                            [[ 0.86666259,-0.08303237,-0.08434633,],
                             [-0.08303237, 0.51963669, 0.06355446,],
                             [-0.08434633, 0.06355446, 0.08424411,],],]

  # Each cluster contributes the same number of samples, hence 1/3 each.
  HISTORICAL_WEIGHTS = [ 0.33333333, 0.33333333, 0.33333333,]

  def setUp(self):
    """Build the synthetic dataset and label each row with its cluster."""
    data, partitioning = self.createData(self.SEED[0], self.SEED[1])
    self.dataset = Dataset(data)

    self.labels = Labeling(self.dataset)
    self.labels.labelRows(partitioning)

  def createData(self, seed1, seed2):
    """Sample 20 rows from each of three 3-D Gaussian clusters.

    seed1, seed2 -- seeds passed to RandomArray.seed() before sampling.

    Returns (data, partitioning).  data is a list of sample rows emitted
    cluster by cluster.  partitioning[i] is the index (0, 1 or 2) of the
    cluster that produced data[i].
    """
    num_of_samples = 20 # how many samples for each cluster to make

    covars = [Numeric.asarray([[1, .5, .1],[.5,1,.5],[.1,.5,1]]),
              Numeric.identity(3),
              Numeric.asarray([[1,0,0],[0,.5,0],[0,0,.1]])]
    means = [ Numeric.asarray([3.0,3.0,3.0],'d'),
              Numeric.asarray([0,0,0],'d'),
              Numeric.asarray([-3,-3,-3], 'd')]

    data = []
    partitioning = []
    RandomArray.seed(seed1, seed2)
    for i in xrange(len(covars)):
      for j in xrange(num_of_samples):
        data.append(RandomArray.multivariate_normal(means[i], covars[i]))
        partitioning.append(i)

    return data, partitioning

  def compare_covariances(self, m1, m2, msg="covariance failed at %d %d %d"):
    """Assert two lists of covariance matrices are element-wise almost equal.

    msg must contain three %d slots: (class, row, col).
    """
    self.failUnless(len(m1) == len(m2),"different number of covariant classes")
    for k in xrange(len(m1)): # choose which class to operate on
      self.failUnless(len(m1[k]) == len(m2[k]),
        "different # of rows for covariance matricies for class %d" % (k))
      for row in xrange(len(m1[k])):
        # Compare the column counts of this row.  Checking len(m1[k]) again
        # here would only repeat the row-count check above.
        self.failUnless(len(m1[k][row]) == len(m2[k][row]),
        "different # of cols for covariance matricies for class %d %d" % (k,
                                                                          row))
        for col in xrange(len(m1[k][row])):
          self.failUnlessAlmostEqual(m1[k][row][col],
                                     m2[k][row][col],
                                     msg=msg%(k,row,col))

  def compare_means(self, m1, m2, msg="means failed at %d %d "):
    """Assert two lists of mean vectors are element-wise almost equal.

    msg must contain two %d slots: (class, col).
    """
    self.failUnless(len(m1) == len(m2), "different number of mean classes")
    for k in xrange(len(m1)): # choose which class to operate on
      self.failUnless(len(m1[k]) == len(m2[k]),
                      "mean vectors were different lengths in class %d" % (k))
      for col in xrange(len(m1[k])):
        self.failUnlessAlmostEqual(m1[k][col], m2[k][col], msg=msg%(k,col))

  def compare_weights(self, w1, w2, msg="weights differed at %d"):
    """Assert two weight vectors are element-wise almost equal.

    msg must contain one %d slot: the weight index.
    """
    self.failUnless(len(w1) == len(w2), "weights matrices were different lengths")
    for k in xrange(len(w1)):
      self.failUnlessAlmostEqual(w1[k], w2[k], msg=msg%(k))

  def checkComputeModelWeights(self):
    """Smoke test: compute_model_weights must run on the setUp() labeling."""
    compute_model_weights(self.labels)
    # FIXME: need actual tests on the returned weights.

  def checkComputeModelMeans(self):
    """compute_model_means must reproduce the historical per-class means."""
    computed_means = compute_model_means(self.dataset, self.labels)
    # compare_means checks every column of every class mean.
    self.compare_means(computed_means, self.HISTORICAL_MEANS,
                       "failed check on class %d column %d")

  def checkComputeModelCovarianceAndWeights(self):
    """Covariances and weights must match the historical values."""
    means = compute_model_means(self.dataset, self.labels)
    covariances,weights = compute_model_covariances_weights(self.dataset,
                                                            self.labels,
                                                            means)

    for w in weights:
      self.failUnlessAlmostEqual(w, 0.33333333, msg="weights failed")

    error_msg="ComputeModelCovariances failed on index %s %s %s"
    self.compare_covariances(covariances, self.HISTORICAL_COVARIANCES,
                             error_msg)

  def checkDistanceFromMeanModel(self):
    """DistanceFromMean fitness on isotropic clouds of growing variance.

    For each variance in 2, 4, ..., 16, draws 250 samples around
    means[0] with covariance I3 * variance.  It then compares the model's
    fitness with a historical table.  No seed is set here: sampling
    continues the RandomArray stream seeded in setUp(), so the table
    depends on that seed and on the draws made there.
    """
    I3 = Numeric.array([[1,0,0],[0,1,0],[0,0,1]])
    variance_range = range(2,17,2)

    means = [Numeric.array([1,2,3])]
    kmeans_model = DistanceFromMean(means)

    historical_fit_table = [ 0.16870969, 0.08638339, 0.05257422,
                             0.04417582, 0.03531564, 0.02822827,
                             0.02455117, 0.02111854,]

    for v in xrange(len(variance_range)):
      variance = variance_range[v]
      data = []
      for i in xrange(250):
        data.append(RandomArray.multivariate_normal(means[0], I3 * variance))

      fitness = kmeans_model.evaluateFitness(data)
      self.failUnlessAlmostEqual(fitness, historical_fit_table[v],
                                 msg="fitness failed for entry %d" %(v))

  def checkConstructMoGModelFromLabeling(self):
    """A mixture of Gaussians built from the labeling matches history."""
    historical_fit = -277.469945404

    m = constructMixtureOfGaussiansFromLabeling(self.dataset, self.labels)

    self.failUnless(m.k == 3, "Wrong number of classes")
    self.failUnless(m.d == 3, "Wrong number of dimensions")
    self.compare_covariances(m.covariances, self.HISTORICAL_COVARIANCES)
    self.compare_means(m.means, self.HISTORICAL_MEANS)
    self.compare_weights(m.weights, self.HISTORICAL_WEIGHTS)
    fit = m.evaluateFitness(self.dataset.getData())
    self.failUnlessAlmostEqual(fit, historical_fit, places=5,
                               msg="Fitness failed %f %f" %(fit,
                                                            historical_fit))

  def checkMoDGModel(self):
    """A mixture of diagonal Gaussians built from the labeling matches history.

    The stored covariances are compared against the same full matrices as
    the other models.  The expected fitness differs (-292.48 vs -277.47).
    """
    historical_fit = -292.480729002

    d   = constructMixtureOfDiagonalGaussiansFromLabeling(self.dataset, self.labels)

    self.failUnless(d.k == 3, "Wrong number of classes")
    self.failUnless(d.d == 3, "Wrong number of dimensions")

    self.compare_covariances(d.covariances, self.HISTORICAL_COVARIANCES,
                             "MoDGModel covariances failed at %d %d %d")
    self.compare_means(d.means, self.HISTORICAL_MEANS)
    self.compare_weights(d.weights, self.HISTORICAL_WEIGHTS)
    fit = d.evaluateFitness(self.dataset.getData())
    self.failUnlessAlmostEqual(fit, historical_fit, places=5,
                               msg="Fitness failed %f %f" %(fit,
                                                            historical_fit))

  def checkMoFGModel(self):
    """A mixture of full Gaussians built from the labeling matches history."""
    historical_fit = -277.469945404

    f = constructMixtureOfFullGaussiansFromLabeling(self.dataset, self.labels)

    self.failUnless(f.k == 3, "Wrong number of classes")
    self.failUnless(f.d == 3, "Wrong number of dimensions")

    self.compare_covariances(f.covariances, self.HISTORICAL_COVARIANCES,
                             "MoFGModel covariances failed at %d %d %d")
    self.compare_means(f.means, self.HISTORICAL_MEANS)
    self.compare_weights(f.weights, self.HISTORICAL_WEIGHTS)

    fit = f.evaluateFitness(self.dataset.getData())

    self.failUnlessAlmostEqual(fit, historical_fit, places=5,
                               msg="Fitness failed %f %f" %(fit, 
                                                            historical_fit))
    
def suite(**kw):
  """Return a TestSuite holding the registered ModelTestCases checks.

  Keyword arguments are accepted but ignored.
  """
  # checkComputeModelWeights is intentionally not registered ("not needed").
  check_names = ("checkComputeModelMeans",
                 "checkComputeModelCovarianceAndWeights",
                 "checkConstructMoGModelFromLabeling",
                 "checkDistanceFromMeanModel",
                 "checkMoDGModel",
                 "checkMoFGModel")
  model_suite = unittest.TestSuite()
  for check_name in check_names:
    model_suite.addTest(ModelTestCases(check_name))
  return model_suite

if __name__ == "__main__":
  # When run as a script, execute the TestSuite returned by suite() above.
  unittest.main(module="__main__", defaultTest="suite")
 





