########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
# Filename     : SVM.py
# Description  : Support Vector Machine Wrapper
# Author(s)    : Ben Bornstein
# Organization : Machine Learning Systems, Jet Propulsion Laboratory
# Created      : February 2002
# Revision     : $Id: SVM.py,v 1.8 2004/04/02 02:40:14 diane Exp $
# Source       : $Source: /proj/CVS/code/python/compClust/mlx/wrapper/SVM.py,v $
#

"""
Support Vector Machine (SVM)

"""

import os
import string
import tempfile

import Numeric

import compClust.mlx.wrapper
import compClust.util

from compClust.mlx.Supervised import Supervised
from compClust.mlx.Supervised import SupervisedModel

from compClust.mlx.labelings import Labeling


class SVM (Supervised):
  """SVM

  Support Vector Machine wrapper.

  Trains (mode 'learn') or applies (any other mode) a two-class SVM by
  writing a Matlab script named startup.m into a temporary directory
  and launching Matlab from that directory.  The script calls PSVM()
  from the toolbox found through the SVM_TOOLBOX_HOME environment
  variable.

  Recognized parameters:

    - 'C'      : written verbatim into the script as C (default '1')
    - 'kernel' : kernel name handed to PSVM()      (default 'linear')

  """

  def __init__(self, dataset=None, labeling=None, parameters=None, model=None):
    """SVM(dataset, labeling, parameters) or SVM(dataset, model)

    Initializes the wrapper through Supervised.__init__(), fills in any
    missing parameters with their defaults and, when no model is
    supplied, creates a fresh SVMModel from the labeling.

    """
    Supervised.__init__(self, dataset, labeling, parameters, model)

    self.__setDefaultParameters()

    if self.model is None:
      #
      # Read self.parameters, not the 'parameters' argument: the
      # argument is None whenever the caller relies on the defaults.
      #
      self.model = SVMModel( labeling,
                             self.parameters['kernel'],
                             self.parameters['C'] )


  def run(self):
    """run() -> status code

    Runs the SVM wrapper and returns a status code.

    If getMode() == 'learn', getModel() will return a trained SVM
    model.  If getMode() == 'predict', getLabeling() will return the
    predicted target values for the given dataset, based on the given
    model.

    Returns compClust.mlx.wrapper.WRAPPER_STATUS_DONE when Matlab exits
    with status 0 and WRAPPER_STATUS_ERROR otherwise.  The temporary
    directory is removed and tempfile.tempdir is restored even when an
    exception is raised along the way.

    """
    self.__createTempFiles()

    try:
      self.__writeTempFiles()

      if self.__run() != 0:
        status = compClust.mlx.wrapper.WRAPPER_STATUS_ERROR
      else:
        status = compClust.mlx.wrapper.WRAPPER_STATUS_DONE
        self.__readTempFiles()
    finally:
      self.__removeTempFiles()

    return status


  def validate(self):
    """validate() -> boolean (i.e. 0 | 1)

    Checks the SVM_TOOLBOX_HOME environment variable through
    compClust.util.Verify.environment_variables_exist().

    """
    error = 0

    #
    # NOTE(review): a true result is treated as a failure here, so
    # environment_variables_exist() presumably returns an error flag
    # rather than a boolean "exists" -- confirm against
    # compClust.util.Verify.
    #
    if compClust.util.Verify.environment_variables_exist( [ 'SVM_TOOLBOX_HOME' ] ):
      error = 1

    return not error


  def __createTempFiles(self):
    """__createTempFiles()

    Creates a temporary directory through
    WrapperUtil.create_temporary_directory('SVM'), makes it the current
    tempfile.tempdir and records the paths of the files used inside it.
    The files themselves are written later by __writeTempFiles() or by
    Matlab.  The following are set:

      - self.commandFilename  (startup.m)
      - self.modelFilename    (model.mat)
      - self.outputFilename   (output.m)

    """
    #
    # Remember the tempdir in effect before it is overridden so that
    # __removeTempFiles() can restore it.  An already existing
    # self.default_tempdir (e.g. set by a base class) is left alone.
    #
    if not hasattr(self, 'default_tempdir'):
      self.default_tempdir = tempfile.tempdir

    tempfile.tempdir = \
      compClust.util.WrapperUtil.create_temporary_directory('SVM')

    self.commandFilename = os.path.join( tempfile.tempdir, 'startup.m' )
    self.modelFilename   = os.path.join( tempfile.tempdir, 'model.mat' )
    self.outputFilename  = os.path.join( tempfile.tempdir, 'output.m'  )

    return None


  def __readOutputPredictions(self, stream):
    """__readOutputPredictions(stream) -> Labeling

    Reads whitespace separated numeric predictions from stream and
    returns a Labeling named 'SVM Predictions' in which row i carries
    the label that the model's reverse label map associates with the
    i-th prediction.

    Raises ValueError for a non-numeric token and KeyError for a
    prediction that has no entry in the reverse label map.

    """
    reverseLabelMap = self.model.getReverseLabelMap()
    labeling        = Labeling(self.dataset, 'SVM Predictions')

    #
    # Split on any whitespace instead of per line so that blank lines
    # and several values on one line are both accepted.
    #
    predictions = [ float(token) for token in stream.read().split() ]

    row = 0
    for prediction in predictions:
      labeling.addLabelToRow( reverseLabelMap[prediction], row )
      row = row + 1

    return labeling


  def __readTempFiles(self):
    """__readTempFiles()

    Collects the results of a successful Matlab run.  In 'learn' mode
    the raw bytes of the saved model file are stored in
    self.model.binary; otherwise the output file is parsed into
    self.labeling.

    """
    if self.mode == 'learn':
      # The .mat file is binary data, so read it untranslated.
      stream = open(self.modelFilename, 'rb')
      try:
        self.model.binary = stream.read()
      finally:
        stream.close()

    else:
      stream = open(self.outputFilename, 'r')
      try:
        self.labeling = self.__readOutputPredictions(stream)
      finally:
        stream.close()

    return None


  def __removeTempFiles(self):
    """__removeTempFiles()

    Removes the temporary directory and the files created in it (see
    __createTempFiles()).  The self.*Filename variables are cleared and
    tempfile.tempdir is restored, even if the removal itself fails.

    """
    directory = tempfile.tempdir

    try:
      for name in os.listdir(directory):
        os.remove( os.path.join(directory, name) )

      os.rmdir(directory)
    finally:
      self.commandFilename  = None
      self.modelFilename    = None
      self.outputFilename   = None

      #
      # Restore tempfile.tempdir to its default.
      #
      tempfile.tempdir = self.default_tempdir


  def __run(self):
    """__run() -> exit status

    Runs PSVM by calling operating system services.  Returns the Matlab /
    PSVM exit status: 0 indicates success, non-zero indicates failure.
    A Matlab process that did not exit normally (e.g. was killed by a
    signal) is reported as 1.

    Matlab is started from the temporary directory so that it picks up
    the generated startup.m; the previous working directory is restored
    afterwards, otherwise the process would be left inside a directory
    that __removeTempFiles() is about to delete.

    """
    previousDirectory = os.getcwd()
    os.chdir(tempfile.tempdir)

    try:
      status = os.system("matlab -nodisplay -nojvm > /dev/null")
    finally:
      os.chdir(previousDirectory)

    #
    # WEXITSTATUS is only meaningful for a normal exit; without this
    # check a signal-terminated Matlab would look like status 0.
    #
    if os.WIFEXITED(status):
      return os.WEXITSTATUS(status)

    return 1


  def __setDefaultParameters(self):
    """__setDefaultParameters()

    Creates self.parameters if necessary and assigns reasonable
    default values for any unset parameters.

    """
    if self.parameters is None:
      self.parameters = {}

    parameters = self.parameters

    parameters.setdefault( 'C'     ,  '1'     )
    parameters.setdefault( 'kernel', 'linear' )

    return None


  def __writeCommandsToLearn(self, stream):
    """__writeCommandsToLearn(stream)

    Writes Matlab PSVM commands file to the given stream from the values in
    the internal parameters dictionary and SVMModel.

    The commands written are specific to learning a mapping from a dataset
    to a labeling: they define the dataset and the targets, call PSVM()
    with C and kernel, save the result to self.modelFilename and quit.
    See also __writeCommandsToPredict().

    Raises KeyError if SVM_TOOLBOX_HOME is not set in the environment.

    """
    stream.write( "addpath('%s');\n" % os.environ['SVM_TOOLBOX_HOME'] )

    self.__writeDataset( stream, self.dataset  )
    self.__writeTargets( stream, self.labeling )

    stream.write( "C      =  %s; \n" % str( self.parameters[ 'C'      ] ) )
    stream.write( "kernel = '%s';\n" % str( self.parameters[ 'kernel' ] ) )
    stream.write( "model  = PSVM(dataset, targets, C, kernel);\n"         )
    stream.write( "save('%s', 'model')\n" % self.modelFilename            )
    stream.write( "quit;\n"                                               )

    return None


  def __writeCommandsToPredict(self, stream):
    """__writeCommandsToPredict(stream)

    Writes the Matlab commands used for prediction to the given stream:
    they define the dataset, load the model from self.modelFilename,
    take the sign of PSVM(model, dataset) as targets, save the targets
    as ASCII to self.outputFilename and quit.
    See also __writeCommandsToLearn().

    Raises KeyError if SVM_TOOLBOX_HOME is not set in the environment.

    """
    stream.write( "addpath('%s');\n" % os.environ['SVM_TOOLBOX_HOME'] )

    self.__writeDataset( stream, self.dataset )

    stream.write( "load('%s');\n"          % self.modelFilename    )
    stream.write( "targets = sign( PSVM(model, dataset) );\n"      )
    stream.write( "save %s targets -ASCII\n" % self.outputFilename )
    stream.write( "quit;\n"                                        )

    return None


  def __writeDataset(self, stream, dataset):
    """__writeDataset(stream, dataset)

    Writes the dataset to stream as a Matlab matrix literal assigned to
    the variable 'dataset'.  Each dataset row becomes one bracketed,
    tab separated matrix row; the closing "]'" transposes the matrix
    on the Matlab side.

    """
    numRows = dataset.getNumRows()
    data    = dataset.getData()

    stream.write( "dataset = [\n" )

    for row in range(numRows):
      features = [ str(value) for value in data[row, :] ]
      stream.write( "[" + '\t'.join(features) + "];\n" )

    stream.write( "]';\n" )
    return None


  def __writeTargets(self, stream, labeling):
    """__writeTargets(stream, labeling)

    Writes the Matlab variable 'targets' to stream: one numeric class
    value per row of the labeling, obtained by looking each row label
    up in the model's labelMap.

    Raises KeyError for a label that is missing from the labelMap.

    """
    labelMap = self.model.labelMap
    targets  = [ str(labelMap[label]) for label in labeling.getLabelByRows() ]

    stream.write( 'targets = [%s];\n' % '\n'.join(targets) )
    return None


  def __writeTempFiles(self):
    """__writeTempFiles()

    Writes the Matlab commands file (self.commandFilename) for the
    current mode and, when the mode is not 'learn', also writes the
    stored model bytes (self.model.binary) to self.modelFilename so
    that the prediction script can load them.

    """
    stream = open(self.commandFilename, 'w')

    try:
      if self.mode == 'learn':
        self.__writeCommandsToLearn(stream)
      else:
        self.__writeCommandsToPredict(stream)
    finally:
      stream.close()

    if self.mode != 'learn':
      # The model is binary .mat data, so write it untranslated.
      modelStream = open(self.modelFilename, 'wb')
      try:
        modelStream.write(self.model.binary)
      finally:
        modelStream.close()

    return None




class SVMModel (SupervisedModel):
  """SVMModel

  A Support Vector Machine (SVM) model.  For now this is a plain
  container whose public attributes hold the most important facts
  about a support vector machine and the model it produces.

  """

  def __init__(self, labeling, kernel='linear', C=1):
    """SVMModel(labeling, kernel='linear', C=1) -> SVMModel

    Builds a new SVMModel from the given labeling, kernel type and C
    parameter.

    Only the SVM wrapper should create an SVMModel.

    The labels reported by labeling.getLabels() are used to build the
    labelMap and its inverse, the reverseLabelMap.  All remaining
    attributes start out empty or zero.

    """
    forwardMap = self.__createLabelMap(labeling)

    self.labelMap        = forwardMap
    self.reverseLabelMap = self.__createReverseMap(forwardMap)

    # Settings the model was created with.
    self.kernel   = kernel
    self.C        = C
    self.evalFunc = 'PSVM'

    # Placeholders; nothing in this class fills them in.
    self.Xk   = []
    self.HTH  = []
    self.v    = []
    self.rids = []
    self.w    = []

    self.trainSECS = 0
    self.b         = 0

    # Raw contents of the saved Matlab model file, once available.
    self.binary = None


  def __createLabelMap(self, labeling):
    """__createLabelMap(labeling) -> dictionary

    Returns the default SVM LabelMap: a dictionary keyed on each label
    from labeling.getLabels().  The first label maps to -1 and every
    later label maps to 1 (two-class case).

    """
    labelMap   = {}
    classValue = -1

    for label in labeling.getLabels():
      labelMap[label] = classValue
      # Everything after the first label belongs to the positive class.
      classValue = 1

    return labelMap


  def __createReverseMap(self, dictionary):
    """__createReverseMap(dictionary) -> dictionary

    Returns a new dictionary in which the keys and values of the given
    dictionary have traded places.  It is used to build the reverse
    LabelMap, so that a label can be looked up from its integer class
    value.  When several keys share a value only one of them survives.

    See also __createLabelMap().

    """
    swapped = {}

    for key, value in dictionary.items():
      swapped[value] = key

    return swapped




#
# FIXME: Launcher assumes usage: command <parameters> <input> <output>,
# FIXME: but should now handle targets and models.
#

#
# if __name__ == "__main__":
#   Launcher.main(sys.argv, SVM())
#
