########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Author: Lucas Scharenbroich
# 
# Original Implementation: June 22 by Lucas Scharenbroich

"""
Usage: TSplit.py parameter_filename input_filename output_filename

 Wrapper for the tsplit algorithm.

      Note:  The class labels will have the extension you specify on the
             command line and the tsplit intermediate file, if saved, will
             have a .gtr extension.

 Depends on the following environment variables:
   TSPLIT_COMMAND   (e.g., /proj/cluster_gazing2/bin/tsplit)

 Brief Algorithm Description:

 Required Parameters:  (note: the lists enclosed in brackets are the possible
                              values each parameter can take)

         distance_metric    =  [correlation, euclidean, Bhattacharyya]

                    Bhattacharyya : takes into account not only the
                                    difference between the two mean vectors,
                                    but also the distributions of the two
                                    groups of data points.

         agglomerate_method =  [none, native, size, clusterNumber]

                   none - do not agglomerate, just generate the
                   normal tsplit output files

                   native - use tsplit built in agglomeration to
                   produce as close to K clusters as possible

                   size - perform a size threshold agglomeration.
                   Starting at the root recurse through the tree
                   attempting to agglomerate at each node stopping
                   only when the number of genes in the agglomerated
                   sub-tree is less than the parameter "size_threshold"

                   clusterNumber - return as close to K clusters as possible
                   using the "size" agglomeration method to partition
                   the tree

                   size/clusterNumber agglomeration is identical to the
                   agglomeration used in xclust

         splitting_method   = ['PCA', 'Best']

                    uses either PCA splitting, which utilizes the energy
                    parameter, or 'best' splitting (any value other than 'PCA')

         k                  = <x> 

                    where x is the target number of clusters

 Optional / Dependent Parameters:

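         min_cluster_size   = <x> (used only by 'native' agglomeration;
                                   defaults to the number of columns in
                                   the dataset)

                    where x is the minimum number of genes that will appear
                    in any given cluster.

         size_threshold     = <x> (required for 'size' agglomeration)

                    the sub-tree size threshold described under
                    agglomerate_method = size.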

         energy             = <x> (required if splitting_method = 'PCA')

                    number in the range (0, 100], indicating the quantity
                    of energy to preserve at each node.


         merge              = ['closest', 'prune']

                    Selects the method used to merge nodes back to the target
                    number of clusters.  Only applicable if 'agglomerate_method'
                    is 'native'.  'closest' is the default and merges the nodes
                    that are closest according to the chosen
                    'distance_metric'.  'prune' simply merges sibling nodes
                    together.

         save_intermediate_files = ['yes', 'no'] (default 'no')
 
                 if you choose yes, the generated tsplit files (.gtr) will be
                 saved.  Otherwise they will be deleted.
"""

import os
import shutil
import string
import sys
import tempfile

from compClust.util.TimeStampedPrintStream import TimeStampedPrintStream

from compClust.util   import Verify
from compClust.util   import WrapperUtil

from compClust.mlx.labelings import Labeling
from compClust.mlx.models import DistanceFromMean

from compClust.mlx.ML_Algorithm import ML_Algorithm

from compClust.mlx.XClustTree import XClustTree
from compClust.mlx.wrapper.TreeAgglomerator import TreeAgglomerator

import compClust.mlx.wrapper

# Prefix for every message this wrapper emits.  strftime codes: %b is the
# abbreviated month name and %M the minutes (%m would repeat the month).
MESSAGE_STREAM = TimeStampedPrintStream("%Y-%b-%d %H:%M: TSplit: ")

class TSplit(ML_Algorithm):
  """Wrapper that runs the external tsplit clustering program on a dataset.

  run() writes the dataset and a tsplit parameter file into a private
  temporary directory, launches the executable named by the TSPLIT_COMMAND
  environment variable, optionally agglomerates the resulting tree itself
  ('size' / 'clusterNumber'), and loads one cluster id per row into a
  Labeling.  See the module docstring for the recognised parameters.
  """

  def __init__(self, dataset = None, parameters = None):
    self.setMessageStream( MESSAGE_STREAM )
    self.dataset    = dataset
    self.parameters = parameters
    self.model = None
    self.labeling = None

    # Kept for backward compatibility; run() restores whatever value
    # tempfile.tempdir held before it was called.
    self.default_tempdir = tempfile.gettempdir()

  def copy(self):
    """Return a new TSplit with the same dataset, labeling and model.

    The parameters dict is shared with the copy, not duplicated.
    """
    new_obj = TSplit(self.dataset, self.parameters)
    new_obj.labeling = self.labeling
    new_obj.model = self.model
    return new_obj

  def getLabeling(self):
    """Return the Labeling produced by the last run(), or None."""
    return self.labeling

  def getModel(self):
    """Return (building and caching on first use) a DistanceFromMean model
    of the dataset under the current labeling."""
    if self.model is None:
      self.model = DistanceFromMean(data=self.dataset, labels=self.labeling)
    return self.model

  def run(self):
    """run(self)

    Prepares the inputs to the clustering algorithm (tsplit), runs it and
    loads the resulting cluster ids into self.labeling.

    All intermediate files live in a private temporary directory.  That
    directory is removed, and the global tempfile.tempdir restored to its
    previous value, even if one of the steps raises.

    Returns compClust.mlx.wrapper.WRAPPER_STATUS_DONE.
    Raises ValueError if 'agglomerate_method' is not one of 'none',
    'native', 'size' or 'clusterNumber'.
    """

    # Invalidate the current model
    self.model = None

    # Default min_cluster_size to the dimensionality of the dataset
    if "min_cluster_size" not in self.parameters:
      self.parameters["min_cluster_size"] = self.dataset.getNumCols()

    # Point tempfile at a private directory so that mktemp() below places
    # every tsplit file there; remember the old value to restore it.
    previous_tempdir = tempfile.tempdir
    work_dir = WrapperUtil.create_temporary_directory("tsplit_")
    tempfile.tempdir = work_dir
    try:
      self._run_in_temporary_directory()
    finally:
      shutil.rmtree(work_dir)
      tempfile.tempdir = previous_tempdir

    return compClust.mlx.wrapper.WRAPPER_STATUS_DONE

  def _run_in_temporary_directory(self):
    """Generate file names, run tsplit and collect its results (the body
    of run(), executed while tempfile.tempdir is the private directory)."""

    self.parameters["clusterInputFilename"]   = tempfile.mktemp(".i")
    self.parameters["clusterParamFilename"]   = tempfile.mktemp(".p")
    self.parameters["clusterOutputFilename"]  = tempfile.mktemp(".o")
    self.parameters["clusterResultsFilename"] = tempfile.mktemp(".r")

    outputFilename = self.parameters["clusterOutputFilename"]
    resultFilename = self.parameters["clusterResultsFilename"]

    # tsplit names its tree files after the output file's base name
    outputBasename     = os.path.splitext(outputFilename)[0]
    outputGTRFilename  = outputBasename + ".gtr"
    outputTreeFilename = outputBasename + ".tree"

    self.create_clustering_input_file()

    # Launch tsplit.  NOTE(review): the exit status is not checked; a failed
    # run surfaces later as a missing output/tree file.
    os.system(self.create_clustering_command_line())

    self._write_results(outputFilename, outputGTRFilename, resultFilename)
    self._load_labeling(resultFilename)

    if self.parameters.get('save_intermediate_files', 'no') == 'yes':
      self._save_intermediate_files(outputGTRFilename, outputTreeFilename)

  def _write_results(self, outputFilename, gtrFilename, resultFilename):
    """Write one cluster id per dataset row to resultFilename.

    For 'none' and 'native' agglomeration tsplit's own output file is
    copied verbatim; otherwise the .gtr tree is agglomerated here.
    """
    method = self.parameters["agglomerate_method"]

    if method == "none" or method == "native":
      source = open(outputFilename)
      try:
        lines = source.readlines()
      finally:
        source.close()
    else:
      lines = ["%d\n" % label for label in self._agglomerate(gtrFilename, method)]

    dest = open(resultFilename, "w")
    try:
      dest.writelines(lines)
    finally:
      dest.close()

  def _agglomerate(self, gtrFilename, method):
    """Return a list giving the cluster index of every dataset row, built by
    agglomerating the tsplit tree in gtrFilename with TreeAgglomerator.

    Rows that appear in no returned cluster keep index 0.
    """
    if method != 'clusterNumber' and method != 'size':
      raise ValueError("TSplit: unknown agglomerate_method %r" % (method,))

    tree = XClustTree()
    tree.read(gtrFilename)
    agglomerator = TreeAgglomerator(tree)

    if method == 'clusterNumber':
      clusters = agglomerator.getKClusters(self.parameters['k'])
    else:
      threshold = self.parameters['size_threshold']
      clusters = agglomerator.agglomerateWithSizeThreshold(threshold)

    labels = [0] * self.dataset.getNumRows()
    for clusterIndex, cluster in enumerate(clusters):
      for gene in cluster:
        # The row index is recovered from the digits of the leaf key; the
        # key is assumed to contain no other digits.
        digits = "".join([c for c in gene.key() if c in string.digits])
        labels[int(digits)] = clusterIndex
    return labels

  def _load_labeling(self, resultFilename):
    """Replace self.labeling with a Labeling built from resultFilename (one
    label per line, surrounding whitespace stripped), if the file exists."""
    if not os.path.exists(resultFilename):
      return

    stream = open(resultFilename, "r")
    try:
      text = [line.strip() for line in stream.readlines()]
    finally:
      stream.close()

    self.labeling = Labeling(self.dataset)
    self.labeling.labelRows(text)

  def _save_intermediate_files(self, gtrFilename, treeFilename):
    """Copy the tsplit tree files to <results_dir>/<save_intermediate_files_base>
    with .gtr / .tree extensions.  Does nothing unless
    'save_intermediate_files_base' is set; files tsplit did not produce are
    skipped."""
    if 'save_intermediate_files_base' not in self.parameters:
      return

    saveBase = (self.parameters['results_dir'] + os.sep +
                self.parameters['save_intermediate_files_base'])

    for source, extension in ((gtrFilename, ".gtr"), (treeFilename, ".tree")):
      if os.path.exists(source):
        # shutil.copy avoids building a shell command, so paths with
        # spaces or shell metacharacters are handled correctly.
        shutil.copy(source, saveBase + extension)

  def create_clustering_input_file(self):
    """create_clustering_input_file(self)

    Writes the two tsplit input files named in self.parameters:

      clusterInputFilename -- the dataset, via self.dataset.writeDataset
      clusterParamFilename -- tsplit settings, one "name value" pair per line
    """

    dataStream = open(self.parameters["clusterInputFilename"], "w")
    try:
      self.dataset.writeDataset(dataStream)
    finally:
      dataStream.close()

    params   = self.parameters
    num_rows = self.dataset.getNumRows()
    num_cols = self.dataset.getNumCols()

    lines = ["rows %s\n" % num_rows,
             "cols %s\n" % num_cols]

    if params["agglomerate_method"] == "native":
      lines.append("min_size %s\n" % params["min_cluster_size"])
      lines.append("k %s\n" % params["k"])
    else:
      # k = number of rows, presumably so tsplit builds the full tree that
      # run() then agglomerates itself (or returns unmerged for 'none').
      lines.append("k %s\n" % num_rows)

    if params["splitting_method"] == "PCA":
      lines.append("mode pca\n")
      lines.append("energy %s\n" % params["energy"])
    else:
      lines.append("mode best\n")

    metric = params["distance_metric"]
    if metric == "correlation":
      lines.append("distance_metric correlation\n")
    elif metric == "Bhattacharyya":
      lines.append("distance_metric bhattacharyya\n")
    else:
      lines.append("distance_metric euclidean\n")

    # mtype 1 selects 'prune' merging, 0 'closest' (the default)
    if params.get("merge", "closest") == "prune":
      lines.append("mtype 1\n")
    else:
      lines.append("mtype 0\n")

    paramStream = open(params["clusterParamFilename"], "w")
    try:
      paramStream.writelines(lines)
    finally:
      paramStream.close()

  def validate(self):
    """validate(self)

    Ensures that all parameters and environment variables necessary to run
    the clustering algorithm (tsplit) are defined, and fills in defaults for
    'save_intermediate_files' ('no') and 'merge' ('closest').

    Returns a true value when the configuration is usable, otherwise 0/False.
    """

    # These are the _required_ parameters needed.
    environment_names = [ "TSPLIT_COMMAND" ]
    parameter_names   = [ "distance_metric",
                          "agglomerate_method",
                          "splitting_method",
                          "k" ]

    # A true result from the Verify helpers is treated as failure.
    if Verify.environment_variables_exist( environment_names ):
      return 0

    if Verify.parameters_exist( parameter_names, self.parameters ):
      return 0

    err = 0

    # 'size' agglomeration reads 'size_threshold' in run()
    if self.parameters["agglomerate_method"] == "size":
      if Verify.parameters_exist(['size_threshold'], self.parameters):
        err = 1

    if self.parameters["splitting_method"] == "PCA":
      if Verify.parameters_exist(['energy'], self.parameters):
        err = 1

    if "save_intermediate_files" not in self.parameters:
      self.parameters["save_intermediate_files"] = "no"

    if "merge" not in self.parameters:
      self.parameters["merge"] = "closest"

    return not err

  def create_clustering_command_line(self):
    """command_line = create_clustering_command_line(self)

    Returns the shell command "<TSPLIT_COMMAND> <input> <params> <output>",
    i.e. the three filenames tsplit expects, in that order.  TSPLIT_COMMAND
    is inserted verbatim; the filenames are the ones generated by run().
    """
    command = [ os.environ["TSPLIT_COMMAND"],
                self.parameters["clusterInputFilename"],
                self.parameters["clusterParamFilename"],
                self.parameters["clusterOutputFilename"] ]
    return " ".join(command)
  
if __name__ == "__main__":
  # Command-line entry point: hand argv and a fresh, unconfigured TSplit
  # to the generic wrapper Launcher.
  from compClust.mlx.wrapper import Launcher
  algorithm = TSplit()
  Launcher.main(sys.argv, algorithm)
