########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Author:  Lucas Scharenbroich
#
# Last Modified: Aug 10, 2001 
#

"""
The Terminator module contains functions which, when given an algorithm,
examine the Dataset, Labeling, and Model to determine if the algorithm
has reached some stopping criterion.

Every function in this module is a second-order function which returns
a function which will return true/false given a set of parameters.
"""

import sys
import compClust.mlx.wrapper.MCCV


#
# These functions return the anonymous functions we use for the
# termination tests
#
# A Terminator is passed a node from the Hierarchical tree.  The only
# defined field of the node is node.algorithm which contains the algorithm.
# A terminator can add other fields as needed for state information if
# need be.
#
# The reason for passing nodes rather than the algorithm objects themselves
# is that it is sometimes desirable to change the algorithm during run time, and
# state variables can be placed in the node rather than in the algorithm's
# parameters hash.
#

##
#
# Prologues
#
##

def clusterSize(size):
  """
  clusterSize(size)

  Builds a terminator that compares the number of rows in a node's
  dataset against a minimum.

  size = smallest number of datapoints (rows) that still passes

  The returned function takes a node, reads
  node.algorithm.getDataset().getNumRows(), and returns true while that
  count is at least <size>; it returns false once the count drops below
  <size>.
  """

  def _hasEnoughRows(node):
    numRows = node.algorithm.getDataset().getNumRows()
    return numRows >= size

  return _hasEnoughRows


def PDRatio(ratio):
  """
  PDRatio(ratio)

  Returns false if the ratio of datapoints to dimensions falls below the
  threshold <ratio>

  ratio = minimum acceptable value of (number of rows) / (number of columns)

  The returned function takes a node, reads the dataset from
  node.algorithm.getDataset(), and returns true while
  getNumRows() / getNumCols() is at least <ratio>.

  The division is always carried out in floating point, so fractional
  thresholds (e.g. 3.5) are compared against the real ratio rather than
  a floored integer quotient.  A dataset reporting zero columns raises
  ZeroDivisionError.
  """

  def _ratioIsHighEnough(node):
    # Fetch the dataset once; both counts come from the same object.
    dataset = node.algorithm.getDataset()

    # float() forces true division even where "/" on two ints would floor,
    # which would otherwise make e.g. 7 rows / 2 cols look like 3, not 3.5.
    return float(dataset.getNumRows()) / dataset.getNumCols() >= ratio

  return _ratioIsHighEnough


##
#
# Resets
#
##

##
#
# Epilogues
#
##

def turnOnMCCV(name, values, test, trials):
  """
  turnOnMCCV(name, values, test, trials)

  Builds a terminator that always returns true, but as a side effect
  switches the node's algorithm over to performing MCCV.

  name   = name of the parameter to perform MCCV over
  values = list of values for that parameter
  test   = test fraction
  trials = number of trials

  The returned function takes a node and hands it, together with the
  arguments captured here, to _turnOnMCCV.
  """

  def _enableMCCV(node):
    return _turnOnMCCV(node, name, values, test, trials)

  return _enableMCCV

def _turnOnMCCV(node, name, values, test, trials):
  """
  Chains an MCCV wrapper around the existing node algorithm if it is not
  already an MCCV wrapper.

  node   = tree node; node.algorithm is read and, if needed, replaced
  name   = name of the parameter to perform MCCV over
  values = list of values for that parameter
  test   = test fraction
  trials = number of trials

  If node.algorithm is already an MCCV instance nothing is changed.
  Otherwise the algorithm's parameters and dataset are taken from it,
  the four mccv_* entries are added to the parameters, and node.algorithm
  is replaced by a new MCCV wrapper built from the dataset, the
  parameters and the original algorithm.

  Always returns 1 (true) so that processing continues as usual.
  """

  algorithm = node.algorithm

  if not isinstance(algorithm, compClust.mlx.wrapper.MCCV.MCCV):

    #
    # Grab the old parameters

    parameters = algorithm.getParameters()
    dataset    = algorithm.getDataset()

    #
    # Clear the dataset and parameters from the (now old) algorithm;
    # the wrapper constructed below receives them instead

    algorithm.setParameters(None)
    algorithm.setDataset(None)

    #
    # Add in the MCCV parameters.  The keys must carry no stray
    # whitespace: "mccv_parameter_name" was previously written with a
    # trailing space, so the name ended up under a different key than
    # its three sibling entries.

    parameters[ "mccv_parameter_name"   ] = name
    parameters[ "mccv_parameter_values" ] = values
    parameters[ "mccv_test_fraction"    ] = test
    parameters[ "mccv_num_trials"       ] = trials

    #
    # Construct a new MCCV wrapper

    node.algorithm = compClust.mlx.wrapper.MCCV.MCCV(dataset, parameters, algorithm)

  #
  # Continue as usual

  return 1


