########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Benjamin J. Bornstein, Diane Trout, Lucas Scharenbroich
# Last Modified: 30-Nov-2001, 15:30
#

"""
Usage: DiagEM.py parameter_filename input_filename output_filename

 Wrapper for diagonal em algorithm

 Depends on the following environment variables:
   DIAGEM_COMMAND   (e.g., /proj/cluster_gazing2/bin/diagem)

 Brief Algorithm Description:

            Performs EM segmentation of an array of feature 
            vectors.  The algorithm is from Bishop's "Neural
            Networks for Pattern Recognition", page 65.  This 
            particular EM algorithm fits Gaussians to the data.
            Each element of the feature vector is assumed to
            be independent (i.e. independent channels). 

 Required Parameters:  (note: the list enclosed in the brackets shows the
                              possible values each parameter can take )

         k               = <x>

                           x is the number of clusters to find

         num_iterations  = <x>

                           Where x is the number of iterations to perform over
                           the data set

         distance_metric = [correlation, correlation_centered, euclidean]

                           The correlation metric is actually Euclidean
                           distance on the data set mapped to the surface
                           of a hypersphere.  This approximates the
                           correlation metric.

         init_method     = [church_means, random_means, random_point,
                            random_range, random_sample, file]
         

 Optional / Dependent Parameters:

         k_strict        = ['true', 'false']

                           Turns on/off k strict behavior, which means that
                           if the exact number of k clusters is not found,
                           i.e. there are collapsed clusters, then do not
                           return _any_ results.  Collapsed clusters tend to happen
                           more often with the euclidean metric than the
                           correlation metric which can return singleton
                           clusters

         seed            = <x> (optional)

                           Where x is the number used to seed the random
                           number generator.  This parameter allows runs
                           of the algorithm to be deterministic.  If the
                           parameter is omitted, it will be initialized
                           to 42

         samples         = <x> (depends on init_method)

                           If the random_sample initialization method is
                           chosen, then this parameter defines how many points
                           to sample for each mean.  It must be >0 and <rows.

         means_file      = "file name" (depends on init_method)

                           If the file initialization method is chosen, this
                           parameters specifies the file to load the means
                           from.

         annealing       = ['on', 'off']

                           Turns on annealing.  If not specified, assumed to
                           be 'off'

         initial_temp    = <x>

                           Starting temperature to run the annealer at.

         schedule        = <x>

                           Temperature schedule.  The initial_temp is
                           multiplied by this number every step.  Needs to be
                           in the range (0.0, 1.0), but should be in the
                           high 0.90s.
                                    
         em_type         = ['scalar', 'diagonal']

                           Restricts the freedom of the covariance matrix
                           calculations.  Assumed to be 'diagonal'.

 Deprecated Parameters:

         Not needed (set to constant values)

            test_fraction
            train_fraction

         Superseded by the parameter 'k'

            min_clusters   
            max_clusters

         Only applicable to mccv run which are now handled by the MCCV.py
         wrapper

            stepsize
            seed
            crossvalidation_runs
            crossvalidation_samples
"""

import os
import re
import sys
import string
import Numeric
import tempfile
import types

from compClust.config import config
from compClust.util   import Verify
from compClust.util   import Usage
from compClust.util   import WrapperUtil
from compClust.util.TimeStampedPrintStream import TimeStampedPrintStream

from compClust.mlx.datasets import Dataset
from compClust.mlx.labelings import Labeling, ClusteredLabeling
from compClust.mlx.models import MixtureOfDiagonalGaussians
import compClust.mlx.ML_Algorithm as ML_Algorithm

from compClust.util.WrapperParameterDescription import WrapperParameterDescription as WPD

import compClust.mlx.wrapper  

MESSAGE_STREAM = TimeStampedPrintStream("%Y-%b-%d %H:%m: DiagEM: ")


#
# DiagEM Parameter Definitions
#

# Descriptions - Parameters
# Long-form help text for the property declarations in Parameters below.
# Each string is passed as the ``doc`` argument of a property definition.
# NOTE(review): the first two mention "kmeans"/"num_restarts" and appear to
# have been copied from the KMeans wrapper -- confirm the wording is still
# accurate for DiagEM.
k_strict_desc = """If \"true\", kmeans will treat k as a strict parameter.  That is,
 if k clusters could not be found, (after an optional
 num_restarts, in the case of randomly initialized means) no
 result will be reported."""
# NOTE(review): max_starts_desc and num_mean_samples_desc are defined here
# but not referenced by any property in this file.
max_starts_desc = 'The maximum number of restarts in the case of collapsed clusters (valid only for randomly initialized means).'
num_mean_samples_desc = """If init_means = \"random_sample\", this parameter indicates the
 number of datapoints to sample (without replacement) when
 estimating initial means.
"""
# Help for the 'seed' property.
seed_desc = 'The seed to use for the pseudo-random number generator (valid only for randomly initialized means).'
# Help for the 'samples' property (used when init_method = random_sample).
samples_desc = """If the random_sample initialization method is
chosen, then this parameter defines how many points
to sample for each mean.  It must be >0 and <rows.
"""
# Help for the experimental annealing 'schedule' property.
schedule_desc="""Temperature schedule.  The initial_temp is
multiplied by this number every step.  Needs to be
in the range (0.0, 1.0), but should be in the
high 0.90s.
"""
# Help for the 'em_type' property.
em_type_desc="""Restricts the freedom of the covariance matrix
calculations.  Assumed to be 'diagonal'.
"""

import compClust.util.WrapperParameters as wp

class Parameters(wp.WrapperParameters):
  """Parameter schema for the DiagEM wrapper.

  Declares the required, optional, experimental and internal parameters
  accepted by the DiagEM algorithm (see the module docstring for the
  meaning of each).  The internal parameters are temporary filenames
  filled in by the wrapper itself while it runs.
  """

  _params = [
    # required parameters
    wp.IntProperty('k', 2, min=1,
                   doc='The number of clusters, k, to find.',
                   priority=wp.Priority.REQUIRED),
    wp.ComboProperty('distance_metric', 'euclidean',
                     ['euclidean', 'correlation', 'correlation_centered'],
                     doc='Distance metric of "correlation" or "euclidean"',
                     priority=wp.Priority.REQUIRED),
    wp.ComboProperty('init_method', 'church_means',
                     ['church_means', 'random_means', 'random_point',
                      'random_sample', 'random_range'],
                     doc='Initialization method of DiagEM',
                     priority=wp.Priority.REQUIRED),
    wp.IntProperty('num_iterations', 100, min=0,
                   doc='The number of iterations.',
                   priority=wp.Priority.REQUIRED),
    # optional parameters
    # FIXME: should this be a boolean property?
    wp.ComboProperty('k_strict', 'false',
                     ['true', 'false'],
                     doc=k_strict_desc,
                     priority=wp.Priority.OPTIONAL),
    wp.IntProperty('seed', 42,
                   doc=seed_desc,
                   priority=wp.Priority.OPTIONAL,),
    wp.IntProperty('samples', 0, min=0,
                   doc=samples_desc,
                   priority=wp.Priority.OPTIONAL),
    wp.ComboProperty('em_type', 'diagonal',
                     ['scalar', 'diagonal'],
                     doc=em_type_desc,
                     priority=wp.Priority.OPTIONAL),
    # experimental parameters
    wp.ComboProperty('annealing', 'off', ['on', 'off'],
                     doc='Turns on/off simulated annealing',
                     priority=wp.Priority.EXPERIMENTAL),
    wp.FloatProperty('initial_temp', 10.0,
                     doc='Starting temperature to run the annealer at.',
                     priority=wp.Priority.EXPERIMENTAL),
    wp.FloatProperty('schedule', 0.9,
                     doc=schedule_desc,
                     priority=wp.Priority.EXPERIMENTAL),
    # internal parameters -- temporary filenames managed by the wrapper.
    # BUGFIX: removed a duplicate 'clusteringInputFilename' entry that
    # appeared twice in this list.
    wp.StrProperty('clusteringInputFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringOutputFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringMeansFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringVarianceFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringWeightsFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringProbsFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringInternalFilename', priority=wp.Priority.INTERNAL),
    ]
                   

#
# Wrapper class for the DiagEM (diagonal expectation minimization) algorithm
#

class DiagEM(ML_Algorithm.ML_Algorithm):
  """Wrapper around the external DiagEM (diagonal EM) clustering binary.

  The wrapper writes the dataset plus a parameter header to a temporary
  input file, invokes the external program named by config.diagem_command,
  and reads the resulting labeling and model files back into compClust
  objects (a ClusteredLabeling and a MixtureOfDiagonalGaussians).
  """

  def __init__(self, dataset = None, parameters = None):
    """Create a DiagEM wrapper for the given dataset and parameter set."""
    self.setMessageStream( MESSAGE_STREAM )
    self.dataset    = dataset
    self.parameters = Parameters(parameters)
    self.model      = None
    self.labeling   = None

    # Remember the process-wide default temporary directory so run() can
    # restore it after redirecting tempfile.tempdir to a scratch directory.
    self.default_tempdir = tempfile.gettempdir()

  def copy(self):
    """Return a shallow copy sharing the dataset, labeling and model."""
    newObj = DiagEM(self.dataset, self.parameters)
    newObj.labeling = self.labeling
    newObj.model = self.model
    return newObj

  def getLabeling(self):
    """Return the ClusteredLabeling from the last run (None before run)."""
    return self.labeling

  def getModel(self):
    """Return the mixture model from the last run (None before run)."""
    return self.model

  def getTransformedDataset(self, dataset):
    """Return the dataset as transformed internally by diagem, or None.

    For the non-euclidean metrics diagem maps the data onto the surface
    of a hypersphere; this method runs diagem for zero iterations with
    k = 1 purely to obtain that internal (transformed) representation.
    The wrapper's state is saved and restored around the throw-away run.
    Returns None when the metric is plain "euclidean".
    """

    xform = None

    if self.parameters[ "distance_metric" ] != "euclidean":

      #
      # Save the old state
      #
      old_dataset  = self.getDataset()
      old_labeling = self.getLabeling()
      old_model    = self.getModel()
      old_param    = self.parameters.copy()

      self.setDataset(dataset)

      # One cluster and zero iterations makes the run a no-op apart from
      # writing out the internal (transformed) data file.
      self.parameters.k = 1
      self.parameters.num_iterations = 0

      self.__execute()

      file     = self.parameters["clusteringInternalFilename"]
      internal = self.readDiagemOutput(file)
      xform    = Dataset(internal)

      # Clean up the scratch directory and any labeling the run attached.
      self.__removeDir( WrapperUtil.tempfile.tempdir )
      if self.labeling is not None:
        dataset.removeLabeling(self.labeling)

      #
      # Restore
      #
      self.setDataset(old_dataset)
      self.labeling = old_labeling
      self.model    = old_model
      self.parameters = old_param

    return xform

  def __execute(self):
    """
    __execute(self)

    Builds the diagem input file, derives all of the output filenames,
    records them in self.parameters and executes the external binary.
    """

    #
    # Create a temporary working directory for diagem's input and output
    # files
    #

    WrapperUtil.tempfile.tempdir = WrapperUtil.create_temporary_directory("DiagEM_")

    #
    # Prep the data file and store it in a temporary location, also construct
    # all the base+extension filenames we'll need
    #

    inputFilename    = WrapperUtil.tempfile.mktemp(".tmp")
    outputFilename   = WrapperUtil.tempfile.mktemp(".out")

    base = os.path.splitext(outputFilename)[0]

    internalFilename = base + ".internal"
    meansFilename    = base + ".means"
    varianceFilename = base + ".variances"
    weightsFilename  = base + ".weights"
    probsFilename    = base + ".probs"

    self.parameters[ "clusteringInputFilename"   ] = inputFilename
    self.parameters[ "clusteringOutputFilename"  ] = outputFilename
    self.parameters[ "clusteringMeansFilename"   ] = meansFilename
    self.parameters[ "clusteringVarianceFilename"] = varianceFilename
    self.parameters[ "clusteringWeightsFilename" ] = weightsFilename
    self.parameters[ "clusteringProbsFilename"   ] = probsFilename
    self.parameters[ "clusteringInternalFilename"] = internalFilename

    #
    # Construct the data file for DiagEM and store it in a temporary
    # location
    #

    self.createClusteringInputFile()

    #
    # Create the diagem command-line and launch it.
    # NOTE(review): os.system() with unquoted paths -- assumes
    # config.diagem_command and the temp paths contain no shell
    # metacharacters.
    #

    commandLine = self.createClusteringCommandLine()
    os.system(commandLine)

  def __removeDir(self, dir):
    """
    __removeDir(dir)

    Removes all the files in a directory and the directory itself.
    Assumes the directory contains only files (no subdirectories), which
    holds for the scratch directories this wrapper creates.
    """

    files = os.listdir( dir )
    for file in files:
      os.remove(os.path.join(dir, file ))
    os.rmdir ( dir )

  def run(self):
    """run()

    Prepares the inputs to the clustering algorithm (DiagEM), runs it,
    and unpacks the results into self.model and self.labeling.  Returns
    compClust.mlx.wrapper.WRAPPER_STATUS_DONE.
    """

    #
    # run the algorithm
    #

    self.__execute()

    #
    # unpack clustering output files into a Model
    #

    self.createModel()

    #
    # Load in the full cluster-membership probabilities and attach them
    # to the dataset as a row labeling of tuples.
    #

    file  = self.parameters[ "clusteringProbsFilename" ]
    probs = self.readDiagemOutput(file)
    probLabeling = Labeling(self.dataset, "Probabilities")
    probLabeling.labelRows(map(tuple, probs))

    #
    # Load the clustering results and produce a labeling, also check that the
    # k-strict condition was not violated.  If it was, this is an invalid run
    #

    outputFilename = self.parameters[ "clusteringOutputFilename" ]

    if os.access( outputFilename, os.F_OK):
      stream = open(outputFilename, "r")
      text   = map(string.strip, stream.readlines())
      stream.close()

      self.labeling = ClusteredLabeling(self.dataset, self.__class__, self.parameters)
      self.labeling.labelRows(text)

      #
      # If k_strict is requested and fewer than k clusters came back
      # (collapsed clusters), discard the entire result.
      #

      if self.parameters.has_key( "k_strict" ) and \
         self.parameters[ "k_strict" ] == "true":

        if not self.checkKStrict(self.labeling):
          self.model    = None
          self.labeling = None

          MESSAGE_STREAM.write("k_strict failure!\n")

    #
    # Cleanup temporary files and directory.
    #

    self.__removeDir( WrapperUtil.tempfile.tempdir )

    # return tempfile.tempdir to its default tempdir
    tempfile.tempdir = self.default_tempdir

    return compClust.mlx.wrapper.WRAPPER_STATUS_DONE

  def checkKStrict(self, labeling):
    """checkKStrict(labeling)

    Ensure that the number of clusters returned by the algorithm actually
    matches the number we asked for.  Returns true if the desired k is
    the returned k.
    """

    #
    # Since the wrappers only apply a single set of labels, we can do it
    # like this.
    #

    return len(labeling.getLabels()) == self.parameters["k"]

  def createClusteringInputFile(self):
    """createClusteringInputFile()

    Write the diagem input file: a "key value" header built from the
    current parameters, a "begin data" marker, then the dataset itself.
    """

    params      = self.parameters
    destFile    = open(params[ "clusteringInputFilename" ],'w')

    #
    # The dataset is in memory, so use that to create the proper file on
    # disk
    #

    numRows    = self.dataset.getNumRows()
    numColumns = self.dataset.getNumCols()

    #
    # create the DiagEM header and write that to the file first
    #

    header      = []

    # Number of rows and columns in the dataset
    header.append( "rows %d" % numRows    )
    header.append( "cols %d" % numColumns )

    #
    # set the algorithm type; annealing requires a starting temperature
    # and a cooling schedule
    #

    if params.has_key("annealing") and params[ "annealing" ] == "on":
      header.append( "algorithm em_annealing" )
      header.append( "init_temp %s"          % params[ "initial_temp" ]    )
      header.append( "schedule_scalar %s"    % params[ "schedule" ]        )
    else:
      header.append( "algorithm em" )

    #
    # write out all of the parameters
    #

    header.append( "num_clusters %s"         % params[ "k" ]               )
    header.append( "max_em_iterations %s"    % params[ "num_iterations" ]  )
    header.append( "random_seed %s"          % params[ "seed" ]            )
    header.append( "distance_metric %s"      % params[ "distance_metric" ] )
    header.append( "init_method %s"          % params[ "init_method"]      )
    header.append( "output_internal 1"                                     )
    header.append( "output_model 1"                                        )

    # Optional / dependent parameters are emitted only when present.
    if params.has_key( "samples" ):
      header.append( "num_samples %s"        % params[ "samples"]          )

    if params.has_key( "means_file" ):
      header.append( "means_file %s"         % params[ "means_file"]       )

    if params.has_key( "em_type" ):
      header.append( "em_type %s"            % params[ "em_type"]          )

    if params.has_key( "fast" ):
      header.append( "fast %s"               % params[ "fast" ]            )

    #
    # write out the indicator for the start of data and a trailing newline
    # to separate 'begin data' from the actual data
    #

    header.append( "begin data" )
    header.append( "" )

    destFile.write( string.join(header, '\n') )

    #
    # Append the dataset itself to the header...
    #

    self.dataset.writeDataset( destFile )

    #
    # ...and done
    #

    destFile.close()

  def createClusteringCommandLine(self):
    """commandLine = createClusteringCommandLine()

    Construct the command line needed to run the DIAGEM command on the
    current dataset: "<diagem_command> <input file> <output file>".
    """

    command = []

    command.append(config.diagem_command)
    command.append(self.parameters[ 'clusteringInputFilename'  ])
    command.append(self.parameters[ 'clusteringOutputFilename' ])

    commandLine = string.join(command, " ")

    return commandLine

  def readDiagemOutput(self, file):
    """
    Numeric.array = readDiagemOutput(file)

    Reads in an output file produced by DiagEM and returns the numeric
    values in a Numeric array.  Patterned after the __castDataset() method
    in mlx.Dataset, but special cased.  Returns an empty array (and logs a
    message) when the file is not readable.
    """

    data = []

    if os.access(file, os.R_OK) == 1:

      stream = open(file, "r")

      #
      # Each line is a tab-separated row of floats.
      #

      for line in stream.readlines():
        data.append( map( float, string.split( line, "\t" )))

      # BUGFIX: the close() used to sit inside the loop body, so an empty
      # file leaked the handle and non-empty files re-closed it per line.
      # Close exactly once, after the loop.
      stream.close()

    else:

      MESSAGE_STREAM.write("Output file " + str(file) + " does not exist\n")

    #
    # Construct the Numeric array from the list of lists, data.
    #

    return Numeric.array(data)

  def createModel(self):
    """createModel()

    Read the model parameter files into a set of Numeric arrays and create
    a Mixture of Diagonal Gaussians model from them.  Also read in the
    internal data used by the algorithm so fitness comparisons are valid.
    Sets self.model to None when no clusters were produced.
    """

    #
    # load in the means found
    #

    file  = self.parameters["clusteringMeansFilename"]
    means = self.readDiagemOutput(file)

    #
    # load in the weights
    #

    file    = self.parameters["clusteringWeightsFilename"]
    weights = Numeric.ravel(self.readDiagemOutput(file))

    #
    # Determine the k found (one weight per cluster)
    #

    k = weights.shape[0]

    if k > 0:

      #
      # load in the variances
      #
      # Since these are simply the diagonals of full covariance matrices,
      # they need to be expanded to full (diagonal) matrices
      #

      file    = self.parameters["clusteringVarianceFilename"]
      v       = self.readDiagemOutput(file)
      fullvar = Numeric.zeros((k, v.shape[1], v.shape[1]), Numeric.Float)

      for i in range(k):
        for j in range(v.shape[1]):
          fullvar[i,j,j] = v[i,j]

      #
      # Build the model (Mixture of Diagonal Gaussians)
      #

      self.model = MixtureOfDiagonalGaussians(k, means, fullvar, weights)

    else:

      self.model = None

  def validate(self):
    """validate()

    Ensures that all parameters and environment variables necessary
    to run the clustering algorithm (DiagEM) are defined.  Returns a
    truthy value when validation passes.
    """

    parameterNames =   [ "k",
                         "num_iterations",
                         "distance_metric",
                         "init_method"
                        ]

    fail = 0

    # NOTE(review): Verify.parameters_exist() appears to return a truthy
    # value (e.g. the missing names) when parameters are absent -- confirm
    # in compClust.util.Verify.
    if Verify.parameters_exist( parameterNames, self.parameters ):
      fail = 1

    #
    # Check dependencies: random_sample requires a valid 'samples' count,
    # file initialization requires a 'means_file'
    #

    if self.parameters["init_method"] == "random_sample":
      if Verify.parameters_exist([ "samples" ], self.parameters):
        fail = 1
      else:
        if self.parameters[ "samples" ] < 1 or \
           self.parameters[ "samples" ] > self.dataset.getNumRows():
          MESSAGE_STREAM.write("samples is out of range\n")
          fail = 1

    if self.parameters["init_method"] == "file":
      if Verify.parameters_exist([ "means_file" ], self.parameters):
        fail = 1

    #
    # Explicitly check for a seed, if one does not exist provide a default
    # value
    #

    if not self.parameters.has_key( "seed" ):
      self.parameters[ "seed" ] = 42

    #
    # See if annealing is turned on, if it is, then a temp and schedule
    # must be provided
    #

    if self.parameters.setdefault("annealing", "off") == "on":
      if Verify.parameters_exist( ["initial_temp", "schedule"],
                                  self.parameters):
        fail = 1

    #
    # Fail if the command cannot be executed
    #

    if Verify.fs_objects_have_permissions( config.diagem_command ,
                                           os.X_OK ) == 0:
      MESSAGE_STREAM.write("%s is not executable." % (config.diagem_command))
      fail = 1

    return not fail

if (__name__ == "__main__"):
  from compClust.mlx.wrapper import Launcher

  # Command-line entry point.  Launcher parses
  # "DiagEM.py parameter_filename input_filename output_filename"
  # (see the module docstring) and drives the DiagEM instance.
  Launcher.main(sys.argv, DiagEM())
