#!/usr/bin/env python
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

import inspect
import os
import string
import sys
import tempfile
import types
import unittest

from compClust.config import config
from compClust.util import WrapperUtil
from compClust.util import Verify

import compClust.mlx.wrapper

class DiagEMTestCases(unittest.TestCase):

  def setUp(self):
    """Create temporary directory and file handles for test runs of diagem.
    """
    source = os.path.realpath(inspect.getsourcefile(DiagEMTestCases))
    self.datadir = os.path.split(source)[0]
    self.executable=string.join([sys.executable,
                                 os.path.join(self.datadir,'..','DiagEM.py')])
      
    self.original_dir    = os.getcwd()
    os.chdir(compClust.mlx.wrapper.__path__[0])
    
    self.temp_dir_name   = WrapperUtil.create_temporary_directory("ditst")

    self.orig_tempdir    = tempfile.tempdir
    tempfile.tempdir     = self.temp_dir_name
    self.param_filename  = tempfile.mktemp("parameter_file")
    self.param_stream    = open(self.param_filename, "w")
    self.result_filename = tempfile.mktemp("result_file")

  def tearDown(self):
    """Clean up after ourselves.
    """
    tempfile.tempdir = self.orig_tempdir
    # FIXME: should this delete files that setUp did not create?
    try:
        os.remove(self.result_filename)
    except OSError, e:
        pass
    
    self.param_stream.close()
    os.remove(self.param_filename)
    os.rmdir(self.temp_dir_name)

    os.chdir(self.original_dir)

  def getParameters(self, k=5):
    """parameters_dict = self.getParameters(k)

    Return a reasonable list of parameters for the given algorithm.
    """

    # Note the default k for DiagEM is 4 while all the other algorithms
    # it's 5. The reason is k_strict fails for k=5 on the 75 point
    # dataset
    #
    # No longer fails (7/3/02)
    params = []
    for k,v in self.getParametersDictionary(k).items():
      if type(v) in types.StringTypes:
        params.append("%s = '%s'" % (str(k), str(v)))
      else:
        params.append("%s = %s" % (str(k), str(v)))
    return params

  def getParametersDictionary(self, k=5):
    return {'k': k,
            'k_strict': 'false',
            'num_iterations': 100,
            'seed': 1234,
            'distance_metric': 'euclidean',
            'init_method': 'church_means'}
  
    
  def test_diagem_wrapper(self):
    """Run a DiagEM clustering using the wrapper just to make sure the class
    is behaving correctly.
    """
    from compClust.mlx.datasets import Dataset
    from compClust.mlx.wrapper import DiagEM

    parameters = self.getParametersDictionary()
    
    data_filename = os.path.join(self.datadir,
                                 "synth_t_05c0_p_0075_d_03_v_0d1.txt")
    dataset = Dataset(data_filename)
    start_labelings_count = len(dataset.getLabelings())

    algorithm = DiagEM(dataset, parameters)
    algorithm.validate()
    algorithm.run()

    model = algorithm.getModel()
    labeling = algorithm.getLabeling()

    end_labelings_count = len(dataset.getLabelings())
    self.failUnless(end_labelings_count == start_labelings_count + 2,
                    "labelings start = %d, end = %d" % (start_labelings_count,
                                                        end_labelings_count))
    
  def test_diagem_small_tree_reasonable_k(self):
    """Test diagem using a small tree search for a number of clusters
    that is noticably below the number of data points.

    """
    parameters = self.getParameters()
    
    data_filename = os.path.join(self.datadir,
                                 "synth_t_05c0_p_0075_d_03_v_0d1.txt")
    result_filename = os.path.join(self.datadir, 'diagem.small')
                        
    self.param_stream.write(string.join(parameters, "\n"))
    self.param_stream.close()

    argv=[self.executable, self.param_filename, data_filename, self.result_filename, " > /dev/null" ]

    os.system(string.join(argv, " "))
    result = os.system("cmp -s %s "%(self.result_filename) + `self.result_filename`)
    result = os.WEXITSTATUS(result)

    self.failIf(result != 0, "small tree failed %s" %(argv))

  def test_diagem_medium_tree_reasonable_k(self):
    """Test diagem using a medium tree search for a number of clusters
    that is noticably below the number of data points.

    """
    parameters = self.getParameters()
    
    data_filename = os.path.join(self.datadir,
                                 "synth_t_05c0_p_0750_d_03_v_0d1.txt")
    result_filename = os.path.join(self.datadir, 'diagem.medium')
                        
    self.param_stream.write(string.join(parameters, "\n"))
    self.param_stream.close()

    argv=[self.executable, self.param_filename, data_filename, self.result_filename, " > /dev/null" ]
    
    os.system(string.join(argv, " "))
    result = os.system("cmp -s %s "%(self.result_filename) + `self.result_filename`)
    result = os.WEXITSTATUS(result)

    self.failIf(result != 0, "medium tree failed")

  def test_diagem_large_tree_reasonable_k(self):
    """Test diagem using a large tree search for a number of clusters
    that is noticably below the number of data points.

    """
    parameters = self.getParameters()
    
    data_filename = os.path.join(self.datadir,
                                 "synth_t_05c0_p_7500_d_03_v_0d1.txt")
    result_filename = os.path.join(self.datadir, 'diagem.large')
                        
    self.param_stream.write(string.join(parameters, "\n"))
    self.param_stream.close()

    argv=[self.executable, self.param_filename, data_filename, self.result_filename, "> /dev/null" ]
    
    os.system(string.join(argv, " "))
    result = os.system("cmp -s %s "%(self.result_filename) + `self.result_filename`)
    result = os.WEXITSTATUS(result)

    self.failIf(result != 0, "large tree failed")
    
def suite(**kw):
  """Build the DiagEM test suite.

  Keyword arguments are accepted and ignored. The suite is left empty
  (with a message printed) when the path in config.diagem_command does
  not exist.
  """
  print("These tests will take several minutes to run...")

  tests = unittest.TestSuite()
  if not os.path.exists(config.diagem_command):
    print("DIAGEM_COMMAND is not available, diagem tests skipped")
    return tests

  tests.addTest(DiagEMTestCases("test_diagem_wrapper"))
  # Currently disabled:
  #tests.addTest(DiagEMTestCases("test_diagem_small_tree_reasonable_k"))
  #tests.addTest(DiagEMTestCases("test_diagem_medium_tree_reasonable_k"))
  #tests.addTest(DiagEMTestCases("test_diagem_large_tree_reasonable_k"))
  return tests

# When executed as a script, run the tests assembled by suite() above
# rather than every test_* method found in the module.
if __name__ == "__main__":
  unittest.main(defaultTest="suite")


