########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Benjamin J. Bornstein
#                Christopher Hart
#                Lucas Scharenbroich
#
# Last Modified: July 18, 2001 
#

"""
Usage: XClust.py parameter_filename input_filename output_filename

 Wrapper for XClust algorithm

     Note:  The class labels will have the extension you specify on the
             command line and the two xclust intermediate files if
             saved will have a .cdt and .gtr extension.
 
 Depends on the following environment variables:
   XCLUST_COMMAND   (e.g., /proj/cluster_gazing2/bin/xclust)

 Brief Algorithm Description:

             See: Eisen et al. 1998, PNAS 95 (25) for a description of xclust.
                  We have implemented several methods for constructing a
                  partitioning based on the phylogenetic tree output by xclust.

 Required Parameters:  (note: the lists enclosed in the brackets are the
                              possible values each parameter can take)

         transform_method   =  [log, none]
         distance_metric    =  [correlation,correlation_centered, euclidean]
         cluster_on         =  [columns, rows]

                    rows   : cluster Genes
                    columns: cluster conditions

         agglomerate_method =  [none, size, clusterNumber]

                   none - do not agglomerate, just generate the
                   normal xclust output files

                   size - perform a size threshold agglomeration.
                   Starting at the root recurse through the tree
                   attempting to agglomerate at each node stopping
                   only when the number of genes in the agglomerated
                   sub-tree is less than the parameter "size_threshold"

                   clusterNumber - return as close to K clusters as possible
                   using the "size" agglomeration method to partition
                   the tree

 Optional / Dependent Parameters:

        size_threshold = <x>  (required if agglomerate_method = size)

                   where x is an integer and 0 < x <= number of data vectors;
                   see above for a complete description.

        k = <x> (required if agglomerate_method = "clusterNumber")
                   where  x is the number of clusters you'd like

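        save_intermediate_files = ['yes', 'no']
                 (optional; validate() sets it to 'no' when not given)

                 if you choose yes, the original xclust files (the .cdt
                 and the .gtr, or .atr when clustering columns) are copied
                 to results_dir/save_intermediate_files_base, which must
                 therefore be set.  Otherwise they are deleted along with
                 the temporary working directory.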

       use_intermediate_files = ['yes', 'no']
                 if yes, will load an existing GTR/CDT file pair instead
                 of building the tree again.  Will determine the
                 intermediate filenames either from the input file (if
                 run from the shell wrapper) or from the
                 save_intermediate_files_base and results_dir optional
                 parameters.

        save_intermediate_files_base = 'filename'
                The base filename for intermediate files (i.e., without the
                .gtr/.atr or .cdt extension).

        results_dir
                Directory in which intermediate files are saved and looked up.


"""
        

import os
import re
import shutil
import string
import sys
import tempfile
import types

from compClust.util.TimeStampedPrintStream import TimeStampedPrintStream

from compClust.util   import Verify
from compClust.util   import WrapperUtil

from compClust.mlx.labelings import Labeling, ClusteredLabeling
from compClust.mlx.models import DistanceFromMean
from compClust.mlx.ML_Algorithm import ML_Algorithm
from compClust.mlx.XClustTree import XClustTree

from compClust.mlx.wrapper.TreeAgglomerator import TreeAgglomerator
import compClust.mlx.wrapper

from compClust.util.WrapperParameterDescription import WrapperParameterDescription as WPD

# Prefix stamped on every progress message written by this wrapper.
# NOTE(review): assumes TimeStampedPrintStream expands time.strftime codes;
# under that convention %M is minutes, whereas %m (used previously) is the
# month and produced "HH:<month>" timestamps.
MESSAGE_STREAM = TimeStampedPrintStream("%Y-%b-%d %H:%M: XClust: ")

def getDefaultParameterDescriptions():
  """Return the user-visible parameter descriptions for the XClust wrapper.

  The result is a dict mapping parameter name -> WPD
  (WrapperParameterDescription).  The positional WPD arguments used below
  appear to be (default, type, help text, required flag[, allowed values]);
  TODO confirm against WrapperParameterDescription.
  """
  choice  = types.ListType
  integer = types.IntType

  agglomerateHelp = 'how to agglomerate the xclust tree. None performs no agglomeration. Size performs a size threshold agglomeration, starting the root recurse through the tree attempting to aggloberate at each node stopping only when the number of genes in the agglomerated sub-tree is less than the parameter "size_threshold". clusterNumber return as close to K clusters as possible using the "size" agglomeration method to partition the tree.'

  descriptions = {
    # Required parameters, each restricted to a fixed set of choices.
    'transform_method':   WPD('none', choice, 'Transform method, either log or none', True, ['log', 'none']),
    'distance_metric':    WPD('euclidean', choice, 'Distance metric to use', True, ['correlation', 'correlation_centered', 'euclidean']),
    'cluster_on':         WPD('rows', choice, 'Run xclust on rows or columns', True, ['rows', 'columns']),
    'agglomerate_method': WPD('none', choice, agglomerateHelp, True, ['none', 'size', 'clusterNumber']),

    # Optional parameters, only consulted for particular agglomerate methods.
    'size_threshold': WPD(0, integer, 'an integer between 0 and number of data vectors. required if agglomerate_method is size', False),
    'k':              WPD(2, integer, 'number of clusters to look for, required if agglomerate method is clusterNumber', False),
  }

  # NOTE: XClust.run() also reads the parameters below, but they are
  # deliberately left out of this list as too advanced for a GUI:
  #  'save_intermediate_files'      = WPD('no', types.ListType, 'save xclust cdt and gtr files', False, ['no', 'yes'])
  #  'use_intermediate_files'       = WPD('no', types.ListType, 'load an existing GTR/CTD file instead of building the tree again. Will determine the intermediate filenames either from the input file or the save_intermediate_files_base and results_dir optional parameters', False, ['no', 'yes'])
  #  'save_intermediate_files_base' = WPD('', types.StringType, 'the base filename for intermediate GTR and CTD files.', False)
  #  'results_dir'                  = WPD('', types.StringType, 'directory to save result files', False)
  return descriptions

#
# Wrapper for the xclust program.  This unsupervised clustering algorithm
# builds a hierarchical cluster tree from the dataset (see Eisen et al. 1998
# in the module docstring), which can then be cut into flat clusters.
#

class XClust(ML_Algorithm):
  """Wrapper around the external xclust hierarchical clustering program.

  run() writes the dataset to an xclust input file in a scratch directory,
  invokes the program named by the XCLUST_COMMAND environment variable,
  optionally saves the resulting CDT and tree files, and optionally
  agglomerates the tree into a flat ClusteredLabeling (self.labeling).
  """

  def __init__(self, dataset = None, parameters = None):
    self.setMessageStream( MESSAGE_STREAM )
    self.dataset    = dataset
    self.parameters = parameters
    self.labeling   = None
    self.model      = None

    # run() temporarily redirects tempfile.tempdir to its own scratch
    # directory; this is the value it restores afterwards.
    self.default_tempdir = tempfile.gettempdir()

  def copy(self):
    """Return a new XClust sharing this instance's dataset, parameters,
    labeling and model objects (a shallow copy)."""
    new_obj = XClust(self.dataset, self.parameters)
    new_obj.labeling = self.labeling
    new_obj.model = self.model
    return new_obj

  def getLabeling(self):
    """Return the labeling produced by the last run(), or None."""
    return self.labeling

  def getModel(self):
    """Return a DistanceFromMean model of the current labeling, building
    and caching it on first use (run() invalidates the cache)."""
    if self.model is None:
      self.model = DistanceFromMean(data=self.dataset, labels=self.labeling)
    return self.model

  def run(self):
    """run()

    Prepares the inputs to the clustering algorithm (XClust) and runs it.

    Steps:
      1. write self.dataset to an xclust input file in a fresh scratch
         directory;
      2. run XCLUST_COMMAND on it, unless use_intermediate_files == 'yes'
         and a saved tree already exists at
         results_dir/save_intermediate_files_base;
      3. if xclust ran and save_intermediate_files == 'yes', copy its CDT
         and tree files to results_dir/save_intermediate_files_base;
      4. unless agglomerate_method is "none", agglomerate the tree and
         load the result into self.labeling as a ClusteredLabeling.

    The scratch directory is always deleted and tempfile.tempdir restored,
    even if a step raises.

    Returns compClust.mlx.wrapper.WRAPPER_STATUS_DONE.
    """

    # Invalidate the current model; it depends on the labeling that is
    # about to be recomputed.
    self.model = None

    # Create a temporary directory for xclust input and output files and
    # point the (process-wide) tempfile module at it so mktemp() names
    # land inside it.
    scratchDir = WrapperUtil.create_temporary_directory("XClust_")
    WrapperUtil.tempfile.tempdir = scratchDir

    try:
      inputFilename = WrapperUtil.tempfile.mktemp(".tmp")
      self.parameters[ "clustering_input_filename" ] = inputFilename

      # xclust names its outputs after the input file: <base>.cdt plus a
      # tree file whose extension depends on the clustered axis.
      outputBasename = os.path.splitext(inputFilename)[0]
      if self.parameters['cluster_on'] == 'columns':
        suffix = '.atr'
      else:
        suffix = '.gtr'
      outputTreeFilename = outputBasename + suffix
      outputCDTFilename  = outputBasename + ".cdt"

      clusterResultsFilename = WrapperUtil.tempfile.mktemp(".r")

      self.create_clustering_input_file()
      command_line = self.create_clustering_command_line()

      # Reuse a previously saved tree only when asked to AND it exists.
      # The existence check uses the same suffix that will be read
      # (.atr when clustering columns), not always .gtr.
      treeBasename   = outputBasename
      reuseSavedTree = 0
      if self.parameters.get('use_intermediate_files', 'no') == 'yes':
        savedBasename = os.path.join(
          self.parameters.get('results_dir', ''),
          self.parameters.get('save_intermediate_files_base', ''))
        if os.path.exists(savedBasename + suffix):
          treeBasename   = savedBasename
          reuseSavedTree = 1

      if not reuseSavedTree:
        os.system(command_line)
        # Only freshly generated files are copied; when a saved tree is
        # reused, nothing new exists in the scratch directory to save.
        if self.parameters.get('save_intermediate_files', 'no') == 'yes':
          self._save_intermediate_files(outputCDTFilename,
                                        outputTreeFilename, suffix)

      # Compare by value: `is not` on strings depends on interning, which
      # made "none" behave differently depending on where the parameter
      # value came from.
      if self.parameters['agglomerate_method'] != "none":
        self.agglomerate(clusterResultsFilename,
                         treeBasename + ".cdt",
                         treeBasename + suffix)

      # Load in the clustering results (absent when no agglomeration ran,
      # in which case the previous labeling is left untouched).
      if os.access(clusterResultsFilename, os.F_OK):
        self.labeling = self._read_labeling(clusterResultsFilename)
    finally:
      try:
        self._remove_directory(scratchDir)
      finally:
        # Return tempfile.tempdir to its default tempdir.
        tempfile.tempdir = self.default_tempdir

    return compClust.mlx.wrapper.WRAPPER_STATUS_DONE

  def _save_intermediate_files(self, cdtFilename, treeFilename, suffix):
    """Copy xclust's CDT and tree outputs to
    results_dir/save_intermediate_files_base{.cdt,suffix}.

    Does nothing when save_intermediate_files_base is not set.  Source
    files that do not exist (e.g. xclust failed) are skipped, matching the
    previous best-effort `cp` behaviour.
    """
    if 'save_intermediate_files_base' not in self.parameters:
      return

    # os.path.join avoids writing to the filesystem root when results_dir
    # is empty, and copes with a trailing separator.
    saveBase = os.path.join(self.parameters.get('results_dir', ''),
                            self.parameters['save_intermediate_files_base'])

    for source, destination in ((treeFilename, saveBase + suffix),
                                (cdtFilename,  saveBase + ".cdt")):
      if os.path.exists(source):
        # shutil.copy instead of os.system("cp ...") so paths containing
        # spaces or shell metacharacters are handled safely.
        shutil.copy(source, destination)

  def _read_labeling(self, filename):
    """Build a ClusteredLabeling from `filename`, one class label per line."""
    stream = open(filename, "r")
    try:
      text = [line.strip() for line in stream.readlines()]
    finally:
      stream.close()

    labeling = ClusteredLabeling(self.dataset, self.__class__, self.parameters)
    labeling.labelRows(text)
    return labeling

  def _remove_directory(self, directory):
    """Delete every file in `directory`, then the directory itself."""
    for name in os.listdir(directory):
      os.remove(os.path.join(directory, name))
    os.rmdir(directory)

  def create_clustering_input_file(self):
    """create_clustering_input_file()

    Writes self.dataset to self.parameters["clustering_input_filename"] as
    tab-separated text:
      - a header row: UID, NAME, GWEIGHT, then each column index;
      - an EWEIGHT row giving every column a weight of 1;
      - one row per data vector: row key, empty name, weight 1, values.
    """
    num_columns = self.dataset.getNumCols()
    num_rows    = self.dataset.getNumRows()
    data        = self.dataset.getData()
    # A real list (not a lazy map) because it is indexed below.
    labels      = [str(key) for key in self.dataset.getRowKeys()]

    destination_file = open(self.parameters["clustering_input_filename"], 'w')
    try:
      # Column names
      header = ["UID", "NAME", "GWEIGHT"]
      header.extend([str(column_index) for column_index in range(num_columns)])
      destination_file.write('\t'.join(header) + '\n')

      # Column weights (two empty cells keep it aligned with the header)
      destination_file.write('\t'.join(["EWEIGHT", "", ""] + ['1'] * num_columns) + '\n')

      # Rows: Rowname, Annotations (blank), Row Weights (1), Data Values...
      for row_index in range(num_rows):
        fields = [labels[row_index], '', '1']
        for value in data[row_index]:
          fields.append(str(value))
        destination_file.write('\t'.join(fields) + '\n')
    finally:
      destination_file.close()

  def create_clustering_command_line(self):
    """command_line = create_clustering_command_line()

    Returns the shell command string that runs XCLUST_COMMAND on
    self.parameters["clustering_input_filename"].  Flags emitted:

      -l 1|0    Transform Method: 1 when transform_method == "log"
      -s 0      Partition Method: None
      -r 0      Randomize Partition: No
      -p 1|0    Distance Metric: 1 for correlation / correlation_centered
      -g / -e   the clustered axis (-g rows, -e columns) gets 2 for
                correlation_centered and 1 otherwise; the other axis gets 0
      -f FILE   input file
    """
    parameters = self.parameters
    metric     = parameters["distance_metric"]

    if parameters["transform_method"] == "log":
      log_flag = '1'
    else:
      log_flag = '0'

    if metric in ("correlation", "correlation_centered"):
      metric_flag = '1'
    else:
      metric_flag = '0'

    if metric == "correlation_centered":
      axis_flag = '2'
    else:
      axis_flag = '1'

    if parameters["cluster_on"] == "columns":
      axis_args = ["-g 0", "-e", axis_flag]
    else:
      axis_args = ["-e 0", "-g", axis_flag]

    command = ([os.environ["XCLUST_COMMAND"],
                "-l", log_flag,
                "-s 0",
                "-r 0",
                "-p", metric_flag]
               + axis_args
               + ["-f", parameters["clustering_input_filename"]])

    return " ".join(command)

  def agglomerate(self, outputFilename, CDTFilename, GTRFilename):
    """
    agglomerate(outputFilename, CDTFilename, GTRFilename)

    Agglomerates over the tree in GTRFilename according to
    self.parameters['agglomerate_method'] and writes one class label per
    dataset row (in getRowKeys() order) to outputFilename.  The CDT file
    is used only to map tree leaf keys back to row names.

    "none" performs a size-threshold agglomeration with threshold 1; run()
    itself no longer calls this method for "none", but direct callers may.

    Raises ValueError for an unrecognised agglomerate_method.

    NOTE(review): when cluster_on is 'columns' the tree leaves are array
    nodes, while the CDT index below is built from data rows; that
    combination looks unsupported -- confirm.
    """

    MESSAGE_STREAM.write("agglomerating data...\n")

    leafAndGeneNameIndex = self._read_cdt_name_index(CDTFilename)

    tree = XClustTree()
    tree.read(GTRFilename)
    agglomerator = TreeAgglomerator(tree)

    method = self.parameters['agglomerate_method']
    if method == 'clusterNumber':
      clusters = agglomerator.getKClusters(self.parameters['k'])
    elif method == 'size':
      clusters = agglomerator.agglomerateWithSizeThreshold(self.parameters['size_threshold'])
    elif method == 'none':
      clusters = agglomerator.agglomerateWithSizeThreshold(1)
    else:
      # Previously fell through and failed later with UnboundLocalError.
      raise ValueError("unknown agglomerate_method: %r" % (method,))

    # Map each row name to the index of the cluster containing its leaf.
    classLabels = {}
    classCount  = 0
    for cluster in clusters:
      for gene in cluster:
        classLabels[leafAndGeneNameIndex[gene.key()]] = str(classCount)
      classCount += 1

    # Write out the labeling, one line per dataset row.
    outfile = open(outputFilename, 'w')
    try:
      for key in self.dataset.getRowKeys():
        outfile.write("%s\n" % (classLabels[str(key)],))
    finally:
      outfile.close()

    MESSAGE_STREAM.write("done agglomerating data\n")

  def _read_cdt_name_index(self, CDTFilename):
    """Return a dict mapping the first field of each CDT data line to the
    second and vice versa.

    The first (header) line is skipped.  Lines with fewer than two
    tab-separated fields, such as a trailing blank line, are ignored
    instead of raising IndexError.
    """
    index  = {}
    infile = open(CDTFilename, 'r')
    try:
      infile.readline()
      for line in infile.readlines():
        tokens = line.split('\t')
        if len(tokens) < 2:
          continue
        index[tokens[0]] = tokens[1]
        index[tokens[1]] = tokens[0]
    finally:
      infile.close()
    return index

  def validate(self):
    """validate()

    Ensures that all parameters and environment variables necessary
    to run the clustering algorithm (XClust) are defined, and that
    XCLUST_COMMAND is executable.

    Side effect: sets parameters["save_intermediate_files"] to "no" when
    it is absent.  Returns a true value when validation passes and a false
    value (0 or False) otherwise.
    """

    environment_names = [ "XCLUST_COMMAND" ]
    parameter_names   = [ "transform_method"       ,
                          "distance_metric"        ,
                          "cluster_on"             ,
                          "agglomerate_method"
                        ]

    err = 0

    # NOTE(review): these Verify helpers are treated as returning a true
    # value when something is missing -- confirm against compClust.util.Verify.
    if Verify.environment_variables_exist( environment_names ):
      return 0

    if Verify.parameters_exist( parameter_names, self.parameters ):
      return 0

    if not self.parameters.has_key("save_intermediate_files"):
      self.parameters["save_intermediate_files"] = "no"

    if self.parameters['agglomerate_method'] != "none":
      if Verify.parameters_exist(['save_intermediate_files'], self.parameters):
        err = 1

      if self.parameters['agglomerate_method'] == "size":
        if Verify.parameters_exist(['size_threshold'], self.parameters):
          err = 1

      if self.parameters['agglomerate_method'] == "clusterNumber":
        if Verify.parameters_exist(['k'], self.parameters):
          err = 1

    #
    # If we cannot execute the command, fail
    #

    if Verify.fs_objects_have_permissions( os.environ['XCLUST_COMMAND'],
                                           os.X_OK ) == 0:
      err = 1

    return not err
  
if __name__ == "__main__":
  # Command-line entry point: hand argv and a fresh XClust instance to the
  # shared compClust wrapper launcher.
  from compClust.mlx.wrapper import Launcher
  wrapper = XClust()
  Launcher.main(sys.argv, wrapper)

