#!/usr/bin/env python
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

import os
import inspect
import string
import sys
import types
import tempfile
import unittest

from compClust.config import config
from compClust.util import WrapperUtil
from compClust.util import Verify
from compClust.util.LoadExample import LoadCho


import compClust.mlx.wrapper
from compClust.mlx.wrapper import KMeans

# Labels recorded from a previous clustering run, kept so we can see if the
# algorithm ever changes.  test_kmeans_python compares the labeling it gets
# for the LoadCho() dataset against this list position by position; each
# entry is a one-element list holding the label as a string.
cho_labels = [['4'], ['2'], ['5'], ['5'], ['3'], ['3'], ['2'], ['1'], ['1'], ['1'], ['3'], ['1'], ['4'], ['1'], ['1'], ['1'], ['2'], ['1'], ['4'], ['4'], ['4'], ['5'], ['3'],['3'], ['5'], ['4'], ['3'], ['3'], ['1'], ['2'], ['1'], ['4'], ['4'], ['1'], ['4'], ['4'], ['5'], ['5'], ['5'], ['5'], ['5'], ['5'], ['1'], ['5'], ['5'], ['4'], ['2'], ['1'], ['3'], ['3'], ['2'], ['2'], ['2'], ['1'], ['4'], ['5'], ['1'], ['1'], ['1'], ['1'], ['3'], ['1'], ['5'], ['4'], ['4'], ['3'], ['5'], ['4'], ['3'], ['4'], ['4'], ['4'], ['5'], ['5'], ['5'], ['5'], ['5'], ['5'], ['5'], ['2'], ['3'], ['5'], ['3'], ['5'], ['4'], ['3'], ['1'], ['2'], ['1'], ['1'], ['1'], ['1'], ['5'], ['1'], ['1'], ['1'], ['1'], ['4'], ['5'], ['1'], ['1'], ['1'], ['5'], ['4'],['3'], ['4'], ['4'], ['5'], ['3'], ['2'], ['5'], ['1'], ['4'], ['1'], ['1'], ['5'], ['5'], ['5'], ['5'], ['1'], ['5'], ['1'], ['3'], ['5'], ['5'], ['3'], ['5'], ['3'], ['1'], ['5'], ['2'], ['4'], ['5'], ['2'], ['1'], ['4'], ['5'], ['4'], ['5'], ['2'], ['2'], ['5'], ['4'], ['2'], ['2'], ['3'], ['1'], ['2'], ['3'], ['4'], ['2'], ['5'], ['2'], ['4'], ['2'], ['4'], ['4'], ['4'], ['2'], ['4'], ['4'], ['1'], ['5'], ['4'], ['2'], ['2'], ['3'], ['3'], ['5'], ['1'], ['4'], ['2'], ['1'], ['5'], ['2'], ['3'], ['1'], ['1'], ['2'], ['1'], ['3'], ['4'], ['1'], ['5'], ['4'],['3'], ['3'], ['5'], ['4'], ['1'], ['1'], ['5'], ['1'], ['1'], ['1'], ['5'], ['1'], ['2'], ['1'], ['1'], ['1'], ['4'], ['5'], ['4'], ['5'], ['4'], ['4'], ['4'], ['2'], ['2'], ['4'], ['3'], ['3'], ['3'], ['1'], ['1'], ['5'], ['5'], ['1'], ['1'], ['1'], ['1'], ['5'], ['3'], ['1'], ['1'], ['3'], ['2'], ['3'], ['5'], ['1'], ['1'], ['1'], ['1'], ['4'], ['5'], ['3'], ['5'], ['1'], ['5'], ['1'], ['4'], ['5'], ['4'], ['3'], ['5'], ['5'], ['2'], ['5'], ['3'], ['2'], ['5'], ['1'], ['5'], ['5'], ['5'], ['4'], ['4'], ['4'], ['4'], ['4'], ['3'], ['3'], ['5'], ['2'], ['5'],['5'], ['5'], ['3'], ['4'], ['1'], ['5'], ['2'], ['4'], ['5'], ['3'], ['5'], ['3'], ['2'], ['2'], ['3'], ['5'], ['2'], ['1'], 
['4'], ['4'], ['2'], ['3'], ['3'], ['3'], ['3'], ['1'], ['2'], ['3'], ['1'], ['2'], ['4'], ['1'], ['4'], ['3'], ['5'], ['3'], ['5'], ['5'], ['3'], ['3'], ['3'], ['5'], ['5'], ['5'], ['5'], ['1'], ['4'], ['1'], ['3'], ['3'], ['4'], ['4'], ['3'], ['2'], ['2'], ['1'], ['4'], ['4'], ['1'], ['5'], ['4'], ['2'], ['3'], ['2'], ['2'], ['5'], ['3'], ['2'], ['1'], ['3'], ['2'], ['5'], ['2'], ['4'], ['4'], ['1'], ['3'], ['5'], ['1'], ['2'], ['4'],['4'], ['1'], ['5'], ['2'], ['4'], ['4'], ['3'], ['1'], ['3'], ['2'], ['3'], ['5'], ['5'], ['4'], ['1'], ['4'], ['3'], ['1'], ['5'], ['1'], ['4'], ['2'], ['5'], ['1'], ['5'], ['4'], ['4'], ['3'], ['4'], ['1'], ['2'], ['5'], ['1'], ['2'], ['1'], ['5'], ['3']]

class KMeansTestCases(unittest.TestCase):
  """Regression tests for the KMeans clustering wrapper.

  test_kmeans_python runs the KMeans class in-process on the dataset
  returned by LoadCho() and compares the labeling with cho_labels.  The
  testKMeans*TreeReasonableK cases launch ./KMeans.py as a child process
  on a synthetic dataset stored next to this file and compare the file it
  writes with a stored reference result.
  """

  def setUp(self):
    """Create the temporary directory and file names used by a test run.

    The working directory is switched to the compClust.mlx.wrapper package
    directory and tempfile.tempdir is pointed at a fresh scratch directory;
    tearDown restores both.
    """
    source = os.path.realpath(inspect.getsourcefile(KMeansTestCases))
    self.datadir = os.path.split(source)[0]
    # NOTE(review): 'KMlust.py' looks like a typo and self.executable is not
    # used anywhere in this class -- confirm before relying on it.
    self.executable = " ".join([sys.executable,
                                os.path.join(self.datadir, '..', 'KMlust.py')])

    self.original_dir = os.getcwd()
    # FIXME: there's a small chance that __path__ could end up being
    # FIXME: something other than a single item list. Check the python
    # FIXME: docs about fiddling with a package path.
    os.chdir(compClust.mlx.wrapper.__path__[0])
    self.temp_dir_name = WrapperUtil.create_temporary_directory("kmtst")

    # Redirect tempfile so the names generated below live in the scratch
    # directory; the previous setting is kept for tearDown.
    self.orig_tempdir = tempfile.tempdir
    tempfile.tempdir = self.temp_dir_name
    self.param_filename = tempfile.mktemp("parameter_file")
    self.param_stream = open(self.param_filename, "w")
    self.result_filename = tempfile.mktemp("result_file")

  def tearDown(self):
    """Remove the files and directory made for the test and restore state.

    The working directory is restored even when removing the temporary
    files fails, so one broken test cannot change the directory the
    following tests start from.
    """
    tempfile.tempdir = self.orig_tempdir
    try:
      # FIXME: should this delete files that setUp did not create?
      # The result file only exists if the test actually produced one.
      try:
        os.remove(self.result_filename)
      except OSError:
        pass

      # Closing twice is harmless: the command line tests already close
      # the stream before launching the child process.
      self.param_stream.close()
      os.remove(self.param_filename)
      os.rmdir(self.temp_dir_name)
    finally:
      os.chdir(self.original_dir)

  def getParameters(self, k=5):
    """parameters = self.getParameters(k)

    Return the entries of getParametersDictionary(k) as a list of
    "name = value" strings, with string values wrapped in single quotes.
    The command line tests write these, one per line, to the parameter
    file handed to KMeans.py.
    """
    # types.StringTypes covers str and unicode on Python 2; fall back to
    # plain str on interpreters that do not provide it.
    string_types = getattr(types, 'StringTypes', (str,))

    params = []
    for name, value in self.getParametersDictionary(k).items():
      if isinstance(value, string_types):
        params.append("%s = '%s'" % (str(name), str(value)))
      else:
        params.append("%s = %s" % (str(name), str(value)))
    return params

  def getParametersDictionary(self, k=5):
    """Return the KMeans parameters used by these tests as a dictionary.

    k is stored under the 'k' key; the remaining settings are fixed.
    """
    return {'k': k,
            'num_iterations': 100,
            'distance_metric': 'euclidean',
            'init_means': 'church'}

  def test_kmeans_python(self):
    """Run KMeans in-process on the LoadCho() dataset and compare the
    resulting labeling with the stored cho_labels.
    """
    dataset = LoadCho()
    kmeans_parameters = self.getParametersDictionary()

    algorithm = KMeans(dataset, kmeans_parameters)
    algorithm.run()

    # The model itself is not inspected; the call is kept so that a
    # failure inside getModel() still shows up as a test error.
    algorithm.getModel()
    labeling = algorithm.getLabeling()

    # The run was configured with k=5, so five labels are expected.
    self.assertEqual(len(labeling.getLabels()), 5)
    cluster_labels = labeling.getAllRowLabels()
    self.assertEqual(len(cluster_labels), len(cho_labels))
    for row in range(len(cluster_labels)):
      self.assertEqual(cluster_labels[row][0], cho_labels[row][0],
                       "row %d: got label %r, expected %r"
                       % (row, cluster_labels[row][0], cho_labels[row][0]))

  def _runCommandLineCase(self, data_basename, reference_basename,
                          failure_message):
    """Run ./KMeans.py on a stored dataset and check the file it writes.

    data_basename and reference_basename name files in self.datadir: the
    input dataset and the expected result.  The default parameters are
    written to the parameter file, KMeans.py is run through the shell, and
    the result file it was asked to write is compared with the reference
    using "cmp -s".  The test fails with failure_message when cmp exits
    with a non-zero status.
    """
    data_filename = os.path.join(self.datadir, data_basename)
    reference_filename = os.path.join(self.datadir, reference_basename)

    self.param_stream.write("\n".join(self.getParameters()))
    # Close (and so flush) the parameter file before the child reads it.
    self.param_stream.close()

    # "./KMeans.py" is looked up relative to the directory setUp changed
    # into.
    argv = [sys.executable, "./KMeans.py", self.param_filename,
            data_filename, self.result_filename, "> /dev/null"]
    os.system(" ".join(argv))

    # Compare against the stored reference file.  repr() wraps each path
    # in quotes so that, for ordinary paths, the shell hands it to cmp as
    # a single argument.
    status = os.system("cmp -s %s %s" % (repr(reference_filename),
                                         repr(self.result_filename)))
    self.assertEqual(os.WEXITSTATUS(status), 0, failure_message)

  def testKMeansSmallTreeReasonableK(self):
    """Test kmeans using a small tree search for a number of clusters
    that is noticeably below the number of data points.

    """
    self._runCommandLineCase("synth_t_05c0_p_0075_d_03_v_0d1.txt",
                             'kmeans.small',
                             "small tree failed")

  def testKMeansMediumTreeReasonableK(self):
    """Test kmeans using a medium tree search for a number of clusters
    that is noticeably below the number of data points.

    """
    self._runCommandLineCase("synth_t_05c0_p_0750_d_03_v_0d1.txt",
                             'kmeans.medium',
                             "medium tree failed")

  def testKMeansLargeTreeReasonableK(self):
    """Test kmeans using a large tree search for a number of clusters
    that is noticeably below the number of data points.

    """
    self._runCommandLineCase("synth_t_05c0_p_7500_d_03_v_0d1.txt",
                             'kmeans.large',
                             "large tree failed")
    
def suite(**kw):
  """Build the unittest suite for the KMeans wrapper.

  Only test_kmeans_python is registered, and only when the path named by
  config.kmeans_command exists; otherwise a notice is printed and an
  empty suite is returned.  Keyword arguments are accepted but not used.
  """
  print("These tests will take several minutes to run...")

  tests = unittest.TestSuite()
  if not os.path.exists(config.kmeans_command):
    print("KMEANS_COMMAND is not available, kmeans tests skipped")
    return tests

  # The command line driven cases (testKMeansSmallTreeReasonableK,
  # testKMeansMediumTreeReasonableK, testKMeansLargeTreeReasonableK) are
  # currently left out of the suite.
  tests.addTest(KMeansTestCases("test_kmeans_python"))
  return tests

# When executed as a script, hand control to unittest and have it run the
# tests returned by suite() above.
if __name__ == "__main__":
  unittest.main(defaultTest="suite")


