########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Diane Trout, Ben Bornstein, Lucas Scharenbroich
# Last Modified: 12-Apr-2003, 10:45a
#

"""
Utility functions for operating on files and directories.
"""

import os
import re
import sys
import string
import tempfile
import getopt

from compClust.util import Usage

# Error-output stream exposed through get_error_stream() and replaced via
# set_error_stream(); it starts out as the process's standard error.
ERROR_STREAM = sys.stderr

#
# FIXME: The ERROR_STREAM variable, and the get_error_stream() and
# FIXME: set_error_stream() functions will probably be common across all
# FIXME: these modules.  Can this be re-factored somehow?
#


def create_clustering_input_file(dataset, destination_directory, header=None):
  """dest_pathname = create_clustering_input_file(dataset,
                                                  destination_directory,
                                                  header)

  Writes dataset (via its writeDataset(stream) method) into a newly
  created temporary file and returns that file's pathname.

  The file is created securely with tempfile.mkstemp() in the default
  temporary directory, under a name of the form "cluster_input*.tmp".
  The caller owns the file and is responsible for removing it.

  The data file format is of the following form:
    <row name><TAB><data_point_1><TAB><data_point_2>...<TAB><data_point_n>

  However clustering algorithms only process matrices, so the initial
  <row name><TAB> character sequence may need to be removed before the
  file is handed to an algorithm; this function does not do that itself
  (whatever writeDataset() emits is written verbatim).

  NOTE: destination_directory and header are currently ignored; they are
  kept only so existing callers keep working.

  If writeDataset() raises, the partially written file is deleted and the
  exception propagates unchanged.
  """

  # mkstemp creates and opens the file atomically, unlike mktemp() which
  # only picks a name (a race / symlink-attack window).  Passing the prefix
  # explicitly avoids mutating the process-wide tempfile.template.
  descriptor, destination_pathname = tempfile.mkstemp(".tmp", "cluster_input")
  destination_stream = os.fdopen(descriptor, 'w')

  completed = False
  try:
    try:
      dataset.writeDataset(destination_stream)
    finally:
      destination_stream.close()
    completed = True
  finally:
    if not completed:
      # Do not leave a half-written input file behind.
      os.remove(destination_pathname)

  return destination_pathname


def create_temporary_directory(dir_name_prefix="tmp"):
  """temp_dir_name = create_temporary_directory(dir_name_prefix=\"tmp\")

  Creates a new, uniquely named temporary directory whose name starts
  with dir_name_prefix, and returns its pathname.  The returned path has
  no trailing delimiter (it is exactly what tempfile.mkdtemp() returns).
  The caller owns the directory and must remove it when done.

  The directory is created under tempfile.gettempdir(), so any
  tempfile.tempdir value configured by the application is honoured.  The
  tempfile module's global settings are not modified.

  If the directory cannot be created, an OSError (IOError on old Pythons
  when no usable temporary directory exists) is raised.
  """

  # Previously this reset tempfile.template / tempfile.tempdir globally,
  # which broke tempfile.gettempprefix() for the rest of the process and
  # discarded any application-chosen temp dir.  mkdtemp needs neither.
  return tempfile.mkdtemp(prefix=dir_name_prefix)


def get_error_stream():
  """stream = get_error_stream()

  Hands back the object currently registered as this module's error
  stream: sys.stderr, unless set_error_stream() has installed another.
  """

  current_stream = ERROR_STREAM
  return current_stream


def set_error_stream(stream):
  """set_error_stream(stream)

  Installs stream as this module's error stream (the object returned by
  get_error_stream()).  stream must provide a callable write method; if it
  does not, an AttributeError is raised and the current error stream is
  left unchanged.
  """

  global ERROR_STREAM

  # Validate before assigning so a bad argument never replaces a good
  # stream.  A non-callable 'write' attribute is rejected as well.
  if not callable(getattr(stream, "write", None)):
    # Call-style raise works on both Python 2 and Python 3 (the old
    # "raise E, msg" form is a syntax error on Python 3).
    raise AttributeError(
          "compClust.util.WrapperUtil.set_error_stream(): " +
          "stream must contain a write method.")

  ERROR_STREAM = stream


def remove_column_from_file(column_to_remove,
                            source_stream,
                            destination_stream,
                            separator="\t"):
  """
  num_rows, num_cols = remove_column_from_file(column_to_remove,
                                               source_stream,
                                               destination_stream,
                                               separator='\t')

  Removes the Nth column (column_to_remove, zero-based; negative values
  count from the end as with list indexing) of the separator-delimited
  text read line by line from source_stream.  The remaining columns of
  each line are joined with separator, terminated by a newline, and
  written to destination_stream.

  separator is a literal string, not a regular expression.  It is also
  the string written between output fields.

  The column count is taken from the first line.  Surrounding whitespace
  is stripped from every line before splitting: trailing whitespace would
  otherwise inflate the column count and make some clustering algorithms
  fail.  Lines that are empty after stripping (e.g. a trailing blank
  line) are skipped and not counted.

  Returns (num_rows, num_cols) for the data written, or None if removing
  the column would leave no columns at all (nothing is written then).

  Raises IndexError if column_to_remove is out of range for the first
  line, or if a later line has too few fields.
  """

  source_line = source_stream.readline()
  columns = source_line.strip().split(separator)

  # Every column index except the one being removed.
  selected_indices = list(range(len(columns)))
  del selected_indices[column_to_remove]
  num_cols = len(selected_indices)

  if num_cols == 0:
    return None

  num_rows = 0
  # readline() returns "" only at EOF; a blank line reads as "\n".
  while source_line != "":
    stripped = source_line.strip()
    if stripped:
      fields = stripped.split(separator)
      destination_stream.write(
        separator.join([fields[index] for index in selected_indices]) + "\n")
      num_rows += 1
    source_line = source_stream.readline()

  return num_rows, num_cols


def load_parameter_file(parameter_filename):
  """
  parameters = load_parameter_file(parameter_filename)

  Executes the Python file parameter_filename in a fresh, empty namespace
  and returns every name it defines (minus '__builtins__') as a
  dictionary.  Exceptions raised while reading, compiling or running the
  file propagate to the caller.

  SECURITY: the file is executed as arbitrary Python code.  Only load
  parameter files from trusted sources.
  """
  parameters = {}

  # Read as bytes so compile() can honour a PEP 263 coding declaration.
  parameter_file = open(parameter_filename, 'rb')
  try:
    source = parameter_file.read()
  finally:
    parameter_file.close()

  # compile()+exec() replaces execfile(), which does not exist on
  # Python 3.  Passing the filename keeps tracebacks pointing at the file.
  exec(compile(source, parameter_filename, 'exec'), parameters)

  # exec() injects '__builtins__'; it is not a parameter.
  parameters.pop('__builtins__', None)
  return parameters


def createOptTree(flags, longopts=None, argv=None):
  """
  optTree, args = createOptTree(flags, longopts=None, argv=None)

  Parses command line options with getopt.getopt() and returns a tuple
  (optTree, args):

    optTree -- dictionary mapping each option flag as reported by getopt
               (e.g. '-v' or '--verbose') to its value; options that take
               no argument map to ''.  If an option is given more than
               once, the last occurrence wins.
    args    -- list of the remaining non-option arguments.

  flags is the string of short option letters to recognize.  Options that
  require an argument are followed by a colon, the same format that Unix
  getopt() uses.  longopts, if given, is a list of long option names
  without the leading '--'.  Names of options that require an argument
  end with an equal sign ('=').

  argv is the argument list to parse, without the leading program name.
  It defaults to sys.argv[1:].

  getopt.GetoptError is raised for unrecognized options or missing option
  arguments.
  """
  # None sentinels instead of mutable/import-time defaults.
  if longopts is None:
    longopts = []
  if argv is None:
    argv = sys.argv[1:]

  opts, args = getopt.getopt(argv, flags, longopts)

  # dict() keeps the last value for a repeated option, exactly like
  # assigning the (flag, value) pairs one at a time.
  optTree = dict(opts)

  return (optTree, args)

