########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Benjamin J. Bornstein, Lucas Scharenbroich
#
# Original Implementation: April 17, 2001 by Ben Bornstein
# Current Implementation:  Aug 6, 2001 by Lucas Scharenbroich
#

"""
Usage: FullEM.py parameter_filename input_filename output_filename

 Wrapper for fullem algorithm

 Depends on the following environment variables:
   FULLEM_COMMAND   (e.g., /proj/cluster_gazing2/bin/fullem)
"""

import os
import sys
import tempfile
import string
import re
import Numeric

from compClust.config import config
from compClust.util import Verify
from compClust.util import WrapperUtil
from compClust.util.TimeStampedPrintStream import TimeStampedPrintStream

from compClust.mlx.labelings import Labeling
from compClust.mlx.models import MixtureOfGaussians
import compClust.mlx.ML_Algorithm as ML_Algorithm

import compClust.mlx.wrapper

# Prefix for every diagnostic line this wrapper writes.  The format uses
# strftime-style directives: %M is minutes (the original "%m" printed the
# month number where the minutes belong).
MESSAGE_STREAM = TimeStampedPrintStream("%Y-%b-%d %H:%M: FullEM: ")

# Descriptions - Parameters
# Help texts for wrapper parameters.  Only seed_desc and schedule_desc are
# referenced by the Parameters schema in this file; the others appear to be
# kept for other users of this module -- TODO confirm before removing.
max_starts_desc = 'The maximum number of restarts in the case of collapsed clusters (valid only for randomly initialized means).'
num_mean_samples_desc = """If init_means = \"random_sample\", this parameter indicates the
 number of datapoints to sample (without replacement) when
 estimating initial means.
"""
seed_desc = 'The seed to use for the pseudo-random number generator (valid only for randomly initialized means).'
samples_desc = """If the random_sample initialization method is
chosen, then this parameter defines how many points
to sample for each mean.  It must be >0 and <rows.
"""
schedule_desc="""Temperature schedule.  The initial_temp is
multiplied by this number every step.  Needs to be
in the range (0.0, 1.0), but should be in the
high 0.90s.
"""
em_type_desc="""Restricts the freedom of the covariance matrix
calculations.  Assumed to be 'diagonal'.
"""

import compClust.util.WrapperParameters as wp

class Parameters(wp.WrapperParameters):
  """Parameter schema for the FullEM wrapper.

  Required: k, num_iterations.  Optional: seed.  Experimental: annealing,
  initial_temp, schedule.  Internal entries hold file names used by the
  wrapper; FullEM.run() fills in p, clusteringInputFilename and
  clusteringOutputFilename itself.

  NOTE(review): FullEM.createClusteringCommandLine() never forwards
  num_iterations, initial_temp or schedule to fullem -- confirm whether the
  executable is expected to receive them.
  """
  _params = [
    wp.IntProperty('k', 2, min=1,
                   doc='The number of clusters, k, to find.',
                   priority=wp.Priority.REQUIRED),
    wp.IntProperty('num_iterations', 100, min=0,
                   doc='The number of iterations.',
                   priority=wp.Priority.REQUIRED),
    # optional parameters
    # FIXME: should this be a boolean property?
    wp.IntProperty('seed', 42,
                   doc=seed_desc,
                   priority=wp.Priority.OPTIONAL,),
    # experimental parameters
    wp.ComboProperty('annealing', 'off', ['on', 'off'],
                     doc='Turns on/off simulated annealing',
                     priority=wp.Priority.EXPERIMENTAL),
    wp.FloatProperty('initial_temp', 10.0,
                     doc='Starting temperature to run the annealer at.',
                     priority=wp.Priority.EXPERIMENTAL),
    wp.FloatProperty('schedule', 0.9,
                     doc=schedule_desc,
                     priority=wp.Priority.EXPERIMENTAL),
    # internal parameters (each declared exactly once; the original list
    # declared clusteringInputFilename twice)
    wp.StrProperty('p', priority=wp.Priority.INTERNAL, doc="Parameter filename"),
    wp.StrProperty('clusteringInputFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringOutputFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringMeansFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringVarianceFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringWeightsFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringProbsFilename', priority=wp.Priority.INTERNAL),
    wp.StrProperty('clusteringInternalFilename', priority=wp.Priority.INTERNAL),
    ]

#
# 6-13-2001: Bringing fullem to correct API for MCCV
#

class FullEM(ML_Algorithm.ML_Algorithm):
  """Wrapper around the external ``fullem`` clustering program.

  run() writes the dataset into a private temporary directory, launches the
  executable named by config.fullem_command through os.system(), and reads
  back the per-row class assignments (as a Labeling) and the fitted mixture
  parameters (as a MixtureOfGaussians).
  """

  # Optional fullem options that take a value, as (parameter key, flag), in
  # the order they are emitted.  Each value is passed through repr(), which
  # for strings yields a quoted word that the shell unquotes again.
  _OPTIONAL_VALUE_FLAGS = (
    ("p",         "-p"),          # output Gaussian parameters file
    ("l",         "-l"),          # output log file
    ("mindiag",   "-mindiag"),    # presumably a minimum covariance diagonal -- TODO confirm
    ("minwgt",    "-minwgt"),     # minimum class weight
    ("thresh",    "-thresh"),     # convergence threshold
    ("tries",     "-tries"),      # number of optimization attempts
    ("collapse",  "-collapse"),   # okay to collapse clusters
    ("trials",    "-trials"),     # number of cross validation trials
    ("loglev",    "-loglev"),     # log print level
    ("nreconv",   "-nreconv"),    # maximum number of reconvergence attempts
    ("nrestart",  "-nrestart"),   # maximum number of restart attempts
    ("channels",  "-channels"),   # purpose not documented here -- TODO
    ("pca",       "-pca"),        # perform dimensionality reduction
    ("maxdim",    "-maxdim"),     # maximum number of dimensions
    ("probs",     "-probs"),      # output label probabilities
    ("annealing", "-anneal"),     # perform deterministic annealing
    )

  # Optional fullem switches that take no value: if the key is present in
  # the parameters, the flag is emitted.
  _SWITCH_FLAGS = (
    ("xval",   "-xval"),     # perform monte carlo cross validation
    ("nonorm", "-nonorm"),   # don't normalize the data along dimensions
    )

  def __init__(self, dataset = None, parameters = None):
    """Create the wrapper; nothing is run until run() is called."""
    self.setMessageStream(MESSAGE_STREAM)
    self.dataset    = dataset
    self.parameters = parameters
    self.labeling   = None
    self.model      = None

    # run() temporarily points the module-global tempfile.tempdir at its
    # private directory and restores it to this value afterwards.
    self.default_tempdir = tempfile.gettempdir()

  def copy(self):
    """Return a new FullEM sharing this one's dataset, parameters,
    labeling and model objects (a shallow copy)."""
    new_obj = FullEM(self.dataset, self.parameters)
    new_obj.labeling = self.labeling
    new_obj.model = self.model
    return new_obj

  def getLabeling(self):
    """Return the Labeling from the last run(), or None."""
    return self.labeling

  def getModel(self):
    """Return the MixtureOfGaussians from the last run(), or None."""
    return self.model

  def run(self):
    """run()

    Prepares the inputs to the clustering algorithm (fullem), runs it and
    loads the resulting labeling and model.

    Returns compClust.mlx.wrapper.WRAPPER_STATUS_DONE on success, or
    WRAPPER_STATUS_ERROR when fullem did not produce its class-assignment
    file or its model-parameter file (the corresponding attribute is then
    set to None).  Temporary files and the private directory are removed,
    and tempfile.tempdir is restored, even if an exception escapes (for
    example an IOError from load_model_parameters()).
    """
    status = compClust.mlx.wrapper.WRAPPER_STATUS_DONE

    # Private directory holding every file exchanged with fullem.
    temp_dir_name = WrapperUtil.create_temporary_directory("fullem_")

    cluster_output_filename = None
    cluster_input_filename  = None
    cluster_model_file      = None

    # tempfile.mktemp() names files inside tempfile.tempdir, so point it at
    # the private directory; the finally clause restores it.
    tempfile.tempdir = temp_dir_name
    try:
      cluster_output_filename = tempfile.mktemp("cluster_output")
      self.parameters['clusteringOutputFilename'] = cluster_output_filename

      # Write the dataset where fullem can read it.
      cluster_input_filename = \
           WrapperUtil.create_clustering_input_file(self.dataset,
                                                    temp_dir_name)
      self.parameters['clusteringInputFilename'] = cluster_input_filename

      # FIXME: This forces p to be overriden each time we test a new param
      cluster_model_file = tempfile.mktemp("cluster_output")
      self.parameters["p"] = cluster_model_file

      # Launch fullem.  NOTE(review): the exit status is ignored; failure is
      # detected only through the missing output files checked below.
      os.system(self.createClusteringCommandLine())

      # Load the clustering results (one label per line, one line per row).
      if Verify.fs_objects_have_permissions(cluster_output_filename, os.F_OK):
        labeling_stream = open(cluster_output_filename, "r")
        try:
          labeling_text = [line.strip() for line in labeling_stream.readlines()]
        finally:
          labeling_stream.close()
        self.labeling = Labeling(self.dataset)
        self.labeling.labelRows(labeling_text)
      else:
        # No output file: leave the labeling unset and report an error.
        MESSAGE_STREAM.write("No output file produced from C code\n")
        self.labeling = None
        status = compClust.mlx.wrapper.WRAPPER_STATUS_ERROR

      # Load the model, handled the same way as the labeling.
      if Verify.fs_objects_have_permissions(self.parameters["p"], os.F_OK):
        self.model = self.load_model_parameters()
      else:
        MESSAGE_STREAM.write("No model file produced from C code\n")
        self.model = None
        status = compClust.mlx.wrapper.WRAPPER_STATUS_ERROR
    finally:
      self._cleanup(temp_dir_name,
                    [cluster_model_file, cluster_output_filename,
                     cluster_input_filename])
      tempfile.tempdir = self.default_tempdir

    return status

  def _cleanup(self, temp_dir_name, filenames):
    """Remove run()'s temporary files, then its private directory.

    Entries that are None, or files that do not exist, are skipped.  If the
    directory cannot be removed (e.g. fullem left extra files in it), a
    message is written instead of raising, so cleanup never masks run()'s
    result or the exception already propagating.
    """
    for filename in filenames:
      if filename and Verify.fs_objects_have_permissions(filename, os.F_OK):
        os.remove(filename)
    try:
      os.rmdir(temp_dir_name)
    except OSError:
      MESSAGE_STREAM.write("Could not remove temporary directory %s\n"
                           % temp_dir_name)

  def validate(self):
    """validate()

    Checks that the parameters needed to run the clustering algorithm
    (fullem) -- currently "seed" and "k" -- are defined.

    NOTE(review): this returns false when Verify.parameters_exist() returns
    a true value, which suggests that helper reports *missing* names.
    Confirm against compClust.util.Verify.
    """
    parameter_names = ["seed", "k"]
    return not Verify.parameters_exist(parameter_names, self.parameters)

  def createClusteringCommandLine(self):
    """command_line = createClusteringCommandLine()

    Build the shell command that runs fullem on self.dataset with
    self.parameters.  The row/column counts, k, the optional seed and the
    input/output files come first.  They are followed by each optional
    fullem option present in the parameters, in a fixed order.
    """
    command = [config.fullem_command,
               "-rows", repr(self.dataset.getNumRows()),
               "-cols", repr(self.dataset.getNumCols()),
               "-k",    repr(self.parameters["k"])]

    if self.parameters.has_key("seed"):
      command.extend(["-seed", repr(self.parameters["seed"])])

    command.extend(["-i", self.parameters["clusteringInputFilename"],
                    "-c", self.parameters["clusteringOutputFilename"]])

    # ----- optional parameters -----
    for key, flag in self._OPTIONAL_VALUE_FLAGS:
      if self.parameters.has_key(key):
        command.extend([flag, repr(self.parameters[key])])

    for key, flag in self._SWITCH_FLAGS:
      if self.parameters.has_key(key):
        command.append(flag)

    # Initialize the means with mean vectors from this file.  Resolve the
    # path *before* quoting it: abspath(repr(path)) glued the current
    # directory onto the quoted string, which broke absolute paths.
    if self.parameters.has_key("initmeans"):
      command.extend(["-initmeans",
                      repr(os.path.abspath(self.parameters["initmeans"]))])

    return " ".join(command)

  def _parse_float_row(self, line):
    """Convert one whitespace-separated line of numbers into floats.

    An empty line raises ValueError (float("")), as before.
    """
    return map(float, re.split(r"\s+", line.strip()))

  def _parse_cluster_info(self, stream):
    """Parse one cluster block that follows a "#<cluster>" line.

    Layout: blank line, mean vector, blank line, covariance rows, blank line
    (or EOF).  Returns (mean, covar); the terminating blank is consumed.
    """
    stream.readline()                    # blank line after "#<cluster>"
    mean = self._parse_float_row(stream.readline())
    stream.readline()                    # blank line after the mean

    covar = []
    line = stream.readline().strip()
    while line:
      covar.append(self._parse_float_row(line))
      line = stream.readline().strip()

    return (mean, covar)

  def load_model_parameters(self):
    """Parse the output of rob's full em parameters file.

    The format is as follows things in <> are variable:
    --- start of file ---
    EM clustering of file <filename>

    EM mixture model parameters
    ---------------------------

    #<Cluster #>

    <mean vector>

    <covariance matrix>

    <repeat the above block for each cluster>

    class weights
    <weights>
    log-likelihood of model = <value>
    --- end of file ---

    Reads the file named by self.parameters["p"] and returns a
    MixtureOfGaussians built from the means, covariance matrices and class
    weights.  Raises IOError("Parse error") if the "class weights" label is
    missing, and ValueError if a numeric field cannot be parsed.
    """
    stream = open(self.parameters["p"], "r")
    try:
      # Skip the "EM clustering of file <filename>" line plus the static
      # header after it (blank, title, dashes, blank).  The old code's
      # attempt to extract <filename> was unused, and its re.sub("...")
      # erased the whole line because "." is a regex wildcard.
      for _ in range(5):
        stream.readline()

      means       = []
      covariances = []

      line = stream.readline()
      while line.startswith("#"):
        mean, covar = self._parse_cluster_info(stream)
        means.append(mean)
        covariances.append(covar)
        line = stream.readline()

      # In the documented layout, the label directly follows the last
      # cluster's blank terminator and is therefore already in ``line``.
      # Otherwise fall back to the next line, as the original code did.
      class_weights_label = line
      if class_weights_label.strip() != "class weights":
        class_weights_label = stream.readline()
      if class_weights_label.strip() != "class weights":
        raise IOError("Parse error")

      weights = self._parse_float_row(stream.readline())

      # The log-likelihood is parsed only to validate the trailer;
      # MixtureOfGaussians does not take it.
      likelihood = stream.readline()
      float(re.sub("log-likelihood of model = ", "", likelihood))
    finally:
      stream.close()

    means       = Numeric.array(means)
    covariances = Numeric.array(covariances)
    weights     = Numeric.array(weights)

    return MixtureOfGaussians(len(weights), means, covariances, weights)

if __name__ == "__main__":
  # Command-line entry point: hand a fresh, unconfigured FullEM instance and
  # the raw argv to the generic wrapper launcher.
  from compClust.mlx.wrapper import Launcher
  algorithm = FullEM()
  Launcher.main(sys.argv, algorithm)
