#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

import filecmp
import inspect
import os
import string
import sys
import tempfile
import unittest

from compClust.config import config
from compClust.mlx.labelings import Labeling
from compClust.mlx.wrapper import KMeans, MCCV
from compClust.mlx.wrapper import Launcher
from compClust.score.ConfusionMatrix2 import ConfusionMatrix
from compClust.util import WrapperUtil
from compClust.util.LoadExample import LoadCho

import compClust.mlx.wrapper

# Previously recorded KMeans cluster labels ('1'..'4') used by
# test_mccv_wrapper as the reference clustering for the LoadCho() dataset.
# The test asserts one entry per labeled row, so these are presumably in
# dataset row order -- confirm if the dataset loader changes.
historical_kmeans_labels = [['4'], ['2'], ['1'], ['3'], ['3'], ['3'], ['2'], ['1'], ['1'], ['1'], ['3'], ['1'], ['4'], ['1'], ['1'], ['1'], ['2'], ['1'], ['4'], ['4'], ['4'], ['3'], ['3'],['3'], ['3'], ['4'], ['3'], ['3'], ['1'], ['2'], ['1'], ['4'], ['4'], ['1'], ['4'], ['4'], ['4'], ['3'], ['3'], ['1'], ['3'], ['3'], ['1'], ['3'], ['3'], ['4'], ['2'], ['1'], ['3'], ['3'], ['2'], ['2'], ['2'], ['1'], ['4'], ['3'], ['1'], ['1'], ['1'], ['1'], ['3'], ['1'], ['3'], ['4'], ['4'], ['3'], ['3'], ['4'], ['3'], ['4'], ['4'], ['4'], ['3'], ['3'], ['1'], ['3'], ['3'], ['1'], ['3'], ['2'], ['3'], ['3'], ['3'], ['3'], ['4'], ['3'], ['1'], ['2'], ['1'], ['1'], ['1'], ['1'], ['3'], ['1'], ['1'], ['1'], ['1'], ['4'], ['3'], ['1'], ['1'], ['1'], ['3'], ['4'],['3'], ['4'], ['4'], ['4'], ['3'], ['2'], ['3'], ['1'], ['4'], ['1'], ['1'], ['1'], ['1'], ['1'], ['4'], ['1'], ['1'], ['1'], ['3'], ['3'], ['1'], ['3'], ['3'], ['3'], ['1'], ['1'], ['2'], ['4'], ['3'], ['2'], ['1'], ['4'], ['3'], ['4'], ['3'], ['2'], ['2'], ['3'], ['4'], ['2'], ['2'], ['3'], ['1'], ['2'], ['3'], ['4'], ['2'], ['3'], ['2'], ['4'], ['2'], ['4'], ['4'], ['4'], ['2'], ['4'], ['4'], ['1'], ['3'], ['4'], ['2'], ['2'], ['3'], ['3'], ['1'], ['1'], ['4'], ['2'], ['1'], ['3'], ['2'], ['3'], ['1'], ['1'], ['2'], ['1'], ['3'], ['4'], ['1'], ['3'], ['4'],['3'], ['3'], ['3'], ['4'], ['1'], ['1'], ['3'], ['1'], ['1'], ['1'], ['1'], ['1'], ['2'], ['1'], ['1'], ['1'], ['4'], ['3'], ['4'], ['3'], ['4'], ['4'], ['4'], ['2'], ['2'], ['4'], ['3'], ['3'], ['3'], ['1'], ['1'], ['3'], ['3'], ['1'], ['1'], ['1'], ['1'], ['3'], ['3'], ['1'], ['1'], ['3'], ['2'], ['3'], ['3'], ['1'], ['1'], ['1'], ['1'], ['4'], ['3'], ['3'], ['1'], ['1'], ['3'], ['1'], ['4'], ['3'], ['4'], ['3'], ['3'], ['3'], ['2'], ['3'], ['3'], ['2'], ['3'], ['1'], ['3'], ['3'], ['3'], ['4'], ['4'], ['4'], ['4'], ['4'], ['3'], ['3'], ['3'], ['2'], ['3'],['3'], ['3'], ['3'], ['4'], ['1'], ['3'], ['2'], ['4'], ['3'], ['3'], ['3'], ['3'], ['2'], ['2'], ['3'], ['3'], 
['2'], ['1'], ['4'], ['4'], ['2'], ['3'], ['3'], ['3'], ['3'], ['1'], ['2'], ['3'], ['1'], ['2'], ['4'], ['1'], ['4'], ['3'], ['3'], ['3'], ['3'], ['3'], ['3'], ['3'], ['3'], ['3'], ['3'], ['3'], ['3'], ['1'], ['4'], ['1'], ['3'], ['3'], ['4'], ['4'], ['3'], ['2'], ['2'], ['2'], ['4'], ['4'], ['1'], ['3'], ['4'], ['2'], ['3'], ['2'], ['2'], ['3'], ['3'], ['2'], ['1'], ['3'], ['2'], ['3'], ['2'], ['4'], ['4'], ['1'], ['3'], ['3'], ['1'], ['2'], ['4'],['4'], ['1'], ['3'], ['2'], ['4'], ['4'], ['3'], ['1'], ['3'], ['2'], ['3'], ['3'], ['3'], ['4'], ['1'], ['4'], ['3'], ['1'], ['3'], ['1'], ['4'], ['2'], ['3'], ['1'], ['3'], ['4'], ['4'], ['3'], ['4'], ['1'], ['2'], ['3'], ['1'], ['2'], ['1'], ['3'], ['3']]
# Flatten the one-element lists into a flat list of label strings,
# e.g. [['4'], ['2'], ...] -> ['4', '2', ...], as passed to labelRows().
historical_kmeans_labels = [ x[0] for x in historical_kmeans_labels]

class MCCVTestCases(unittest.TestCase):

  def setUp(self):
    """Create temporary directory and file handles for test runs of mccv.
    """
    source = os.path.realpath(inspect.getsourcefile(MCCVTestCases))
    self.datadir = os.path.split(source)[0]
    self.executable=string.join([sys.executable,
                                 os.path.join(self.datadir,'..','KMeans.py')])
    
    self.original_dir    = os.getcwd()
    os.chdir(compClust.mlx.wrapper.__path__[0])

    self.orig_tempdir  = tempfile.tempdir
    self.temp_dir_name = WrapperUtil.create_temporary_directory("cvtst")

    tempfile.tempdir      = self.temp_dir_name

    self.param_filename   = tempfile.mktemp("parameter_file")
    self.param_stream     = open(self.param_filename, "w")
    self.result_filename  = tempfile.mktemp(".result_file")
    self.fitness_filename = tempfile.mktemp(".fitness")

  def tearDown(self):
    """Clean up after ourselves.
    """
    # FIXME: should this delete files that setUp did not create?
    tempfile.tempdir = self.orig_tempdir
    try:
      os.remove(self.result_filename)
    except OSError, e:
      # result_filename may or may not exist so ignore deletion failures
      pass
    self.param_stream.close()
    if os.path.exists(self.param_filename): os.remove(self.param_filename)
    if os.path.exists(self.fitness_filename): os.remove(self.fitness_filename)
    os.rmdir(self.temp_dir_name)
    os.chdir(self.original_dir)

  def getKMeansParametersDictionary(self, k=5):
    return {"distance_metric":'euclidean',
            "init_means": 'church',
            "num_iterations": 100, 
            "k": 5,}

  def getMCCVParametersDictionary(self, k=5):
    return {"num_trials": 100,
            "seed": 42,
            "parameter_name": 'k',
            "parameter_values": range(2,k)
            }

  def test_mccv_wrapper(self):
    """Run a MCCV clustering using KMeans just to make sure the class
    is behaving correctly.
    """
    dataset = LoadCho()
    mccv_parameters = self.getMCCVParametersDictionary()
    kmeans_parameters = self.getKMeansParametersDictionary()
    
    sub_algorithm = KMeans(dataset, kmeans_parameters)
    algorithm = MCCV(dataset, mccv_parameters, sub_algorithm)
    algorithm.run()
    
    model = algorithm.getModel()
    cluster_labeling = algorithm.getLabeling()
    
    historical_kmeans_labeling = Labeling(dataset)
    historical_kmeans_labeling.labelRows(historical_kmeans_labels)
    # more clusters should be the best fit for this dataset
    self.failUnless(len(cluster_labeling.getLabels()) == 4 )
    cluster_labels = cluster_labeling.getAllRowLabels()
    self.failUnless(len(cluster_labels) == len(historical_kmeans_labels))
    cm = ConfusionMatrix([cluster_labeling, historical_kmeans_labeling])
    self.failUnless(cm.linearAssignment() > 0.95)
    

  def test_kmeans_small_tree_reasonable_k(self):
    """Test mccv using a small tree search for a number of clusters
    that is noticably below the number of data points.

    """
    data_filename = os.path.join(self.datadir,
                                 "synth_t_05c0_p_0075_d_03_v_0d1.txt")
    result_filename = os.path.join(self.datadir, 'mccv_kmeans_fitness.small')

    kmeans_parameters = self.getKMeansParameters()
    mccv_parameters = self.getMCCVParameters()
    mccv_parameters.fitness = self.fitness_filename
                            
    self.param_stream.write(string.join(parameters, "\n"))
    self.param_stream.close()

    argv=[self.executable, self.param_filename, data_filename, \
          self.result_filename, "--MCCV=on" ]
    
    result = Launcher.main(argv, KMeans())
    result = os.system("cmp -s %s "%(self.result_filename) + `self.result_filename`)

    self.failIf(result != 0, "small tree failed")

  def test_kmeans_medium_tree_reasonable_k(self):
    """Test mccv using a small tree search for a number of clusters
    that is noticably below the number of data points.

    """
    data_filename = os.path.join(self.datadir,
                                 "synth_t_05c0_p_0750_d_03_v_0d1.txt")
    result_filename = os.path.join(self.datadir, 'mccv_kmeans_fitness.medium')

    parameters = self.getParameters(k=20)
    parameters.append("mccv_fitness = '" + self.fitness_filename + "'")
        
    self.param_stream.write(string.join(parameters, "\n"))
    self.param_stream.close()

    argv=[self.executable, self.param_filename, data_filename, \
          self.result_filename, "--MCCV=on" ]
    result = Launcher.main(argv, KMeans())
    result = os.system("cmp -s %s "%(self.result_filename) + `self.result_filename`)
    
    self.failIf(result != 0, "medium tree failed")

  def test_kmeans_large_tree_reasonable_k(self):
    """Test mccv using a small tree search for a number of clusters
    that is noticably below the number of data points.

    """
    data_filename = os.path.join(self.datadir,
                                 "synth_t_05c0_p_7500_d_03_v_0d1.txt")
    result_filename = os.path.join(self.datadir, 'mccv_kmeans_fitness.large')

    parameters = self.getParameters(40)
    parameters.append("mccv_fitness = '" + self.fitness_filename + "'")
    
    data_filename    = "test/synth_t_05c0_p_7500_d_03_v_0d1.txt"
                        
    self.param_stream.write(string.join(parameters, "\n"))
    self.param_stream.close()

    argv=[self.executable, self.param_filename, data_filename, \
          self.result_filename, "--MCCV=on" ]
    result = Launcher.main(argv, KMeans())
    result = os.system("cmp -s %s "%(self.result_filename) + `self.result_filename`)

    
    self.failIf(result != 0, "large tree failed")
    
def suite(**kw):
  suite = unittest.TestSuite()
  if os.path.exists(config.kmeans_command):
    suite.addTest(MCCVTestCases("test_mccv_wrapper"))
    #suite.addTest(MCCVTestCases("test_kmeans_small_tree_reasonable_k"))
    #suite.addTest(MCCVTestCases("test_kmeans_medium_tree_reasonable_k"))
    #suite.addTest(MCCVTestCases("test_kmeans_large_tree_reasonable_k"))
  else:
    print "MCCV needs KMEANS_COMMAND, which isn't available, MCCV test skipped"

  return suite

if __name__ == "__main__":
  # Running this file directly: let unittest's runner build the tests from
  # the module-level suite() factory.
  default_test = "suite"
  unittest.main(defaultTest=default_test)
