########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
This module provides functions which are meant to help minimize the
amount of typing one has to do when using the compClust tools from the
Python interpreter.  No new functionality is implemented in these
functions; they are simply wrappers that make data analysis inside our
schema faster and easier from the command line.

"""

from compClust.util import TimeStampedPrintStream
from compClust.util import WrapperUtil

from compClust.mlx.labelings import Labeling
from compClust.mlx.views import RowSubsetView
from compClust.mlx.views import RowFunctionView

import compClust.mlx.wrapper

import MLab
import Numeric

import types
import tempfile
import os
import thread


def safeStdDev(array, dimension = 0, missingValue=0):

  """
  safeStdDev(array, dimension = 0, missingValue=0)

  Return MLab.std(array, dimension).  If that call raises an error,
  return instead a Numeric array filled with missingValue: one entry
  per element of array[0] when dimension == 0, otherwise one entry per
  element of array.

  Note the array must be 2d (array[0] is indexed to size the fallback
  result).
  """

  try:
    return MLab.std(array, dimension)
  except Exception:
    # Best-effort: any computation error falls back to the placeholder
    # result below.  Catching Exception (not a bare except) lets
    # KeyboardInterrupt / SystemExit propagate.
    pass

  if dimension == 0:
    length = len(array[0])
  else:
    length = len(array)
  return Numeric.array([missingValue] * length)


def wrapperRunThread(wrapper, quiet= 0):

  """
  wrapperRunThread(wrapper, quiet = 0)

  Launch a primed (ready to run) wrapper's run() method in a new thread
  with thread.start_new_thread.  Unless quiet is true, 'Algo Running'
  is first written to a TimeStampedPrintStream.

  The call returns immediately without waiting for run() to finish;
  run()'s return value is not reported.  Returns the thread identifier
  given back by thread.start_new_thread.
  """
  outstream = TimeStampedPrintStream.TimeStampedPrintStream()
  if not quiet:
    outstream.write('Algo Running\n')

  return thread.start_new_thread(wrapper.run, ())

def _isOptionOn(parametersDict, name):
  """Return true when name (in upper or lower case) is set to 'on'."""
  return (parametersDict.get(name.upper()) == 'on' or
          parametersDict.get(name.lower()) == 'on')


def _loadClusterParameters(parameters, quiet, outstream):
  """
  Normalize the parameters argument of cluster() to a dictionary.

  A string is treated as a parameter file path and loaded with
  WrapperUtil.load_parameter_file; a dictionary is returned as-is (not
  copied).  Returns None, after reporting the problem unless quiet,
  when the file cannot be loaded or the argument has any other type.
  """
  if isinstance(parameters, types.StringTypes):
    try:
      return WrapperUtil.load_parameter_file(parameters)
    except Exception:
      if not quiet:
        outstream.write("Can't open parameter file %s\n"%(parameters) )
      return None
  if isinstance(parameters, types.DictType):
    return parameters
  if not quiet:
    outstream.write("Parameters of unknown type\n")
  return None


def _buildClusterAlgo(wrapper, parametersDict, dataset, quiet, outstream):
  """
  Instantiate the algorithm for cluster(): the MCCV meta wrapper when
  MCCV/mccv is 'on', else the Hierachical meta wrapper when
  HIERARCHICAL/hierarchical is 'on', else the plain wrapper.
  """
  if _isOptionOn(parametersDict, "MCCV"):
    if not quiet:
      outstream.write("Using MCCV Meta Wrapper\n")
    return compClust.mlx.wrapper.MCCV(dataset, parametersDict, wrapper())

  if _isOptionOn(parametersDict, "HIERARCHICAL"):
    if not quiet:
      outstream.write("Using  Hierachical Meta Wrapper\n")
    return compClust.mlx.wrapper.Hierachical(dataset, parametersDict, wrapper())

  return wrapper(dataset, parametersDict)


def cluster(wrapper, parameters, dataset, quiet=0, useThreading=1):

  """
  runAlgo = cluster(wrapper, parameters, dataset, quiet=0, useThreading=1)

  wrapper is one of our compClust algorithm wrappers
  (ie. compClust.mlx.wrapper.KMeans.KMeans).

  parameters can either be a parameters dictionary or the path of a
  parameter file.  A dictionary is modified in place: when it has no
  'results_dir' entry, a fresh temporary directory is created and stored
  there, and 'save_intermediate_files_base' is set to 'interpreterClust'.
  Setting MCCV/mccv to 'on' runs the wrapper inside the MCCV meta
  wrapper; otherwise HIERARCHICAL/hierarchical = 'on' selects the
  Hierachical meta wrapper.

  dataset is a compClust.mlx.Dataset object.

  This utility mirrors the usage of the command line utilities to an
  extent.

  With useThreading true, the validated algorithm is started in a new
  thread and returned immediately, before its run has finished.  With
  useThreading false it is run synchronously.  Progress messages are
  suppressed when quiet is true.

  cluster returns the algorithm object if successful, otherwise None
  (unusable parameters, failed validation, or - when run synchronously -
  a run status other than 1 or 2).  If you are only interested in the
  labeling, you can do something like

  >>> lab = cluster(KMeans.KMeans, parameters, dataset,
  ...               useThreading=0).getLabeling()
  """

  outstream = TimeStampedPrintStream.TimeStampedPrintStream()

  #
  # Cast input parameters if need be
  #
  parametersDict = _loadClusterParameters(parameters, quiet, outstream)
  if parametersDict is None:
    return None

  #
  # Point the output directories at a fresh temporary directory.
  # mkdtemp creates the directory atomically (mktemp + mkdir is racy).
  #
  if 'results_dir' not in parametersDict:
    parametersDict["results_dir"] = tempfile.mkdtemp('intTools')
    parametersDict["save_intermediate_files_base"] = 'interpreterClust'

  #
  # Deal with the special Meta Wrappers - MCCV and Hierarchical
  #
  algo = _buildClusterAlgo(wrapper, parametersDict, dataset, quiet, outstream)

  #
  # Validate and Run
  #
  if not algo.validate():
    if not quiet:
      outstream.write('Validation FAILED\n')
    # Failure is reported through the return value even when quiet.
    return None

  if not quiet:
    outstream.write('Validation Passed. Running...\n')

  if useThreading:
    if not quiet:
      outstream.write('starting new thread...\n')
    # The run must be started regardless of quiet; only the messages
    # depend on it.
    wrapperRunThread(algo, quiet)
    status = 2
  else:
    status = algo.run()

  if status == 1:
    if not quiet:
      outstream.write('Run Finished\n')
  elif status == 2:
    if not quiet:
      outstream.write('Running in new thread\n')
  else:
    if not quiet:
      outstream.write('Run FAILED\n')
    algo = None

  return algo


def clusterSizes(labeling):

  """
  clusterSizes(labeling)

  Return a dictionary mapping each label from labeling.getLabels() to
  len(labeling.getRowsByLabel(label)), i.e. the size of that cluster.
  """

  sizeDict = {}
  for label in labeling.getLabels():
    # setdefault keeps the first count should getLabels() repeat a label.
    sizeDict.setdefault(label, len(labeling.getRowsByLabel(label)))
  return sizeDict


  
def findSharedRows(ds1, ds2, lab1, lab2):

  """
  findSharedRows(ds1, ds2, lab1, lab2)

  Using the labelings lab1 and lab2, return two RowSubsetViews of ds1
  and ds2 that contain only the rows whose labels occur in both lab1
  and lab2.  Both subsets list the shared labels in the same order (the
  order returned by lab2.getLabelByRows()), so row i of one subset
  corresponds to row i of the other and they are ready - if needed - to
  be supersetted.

  Each subset also gets a new Labeling with the name of, and populated
  via labelFrom from, the corresponding original labeling.

  Note, labels are assumed to be unique and point to only one row; when
  a label maps to several keys only the first one is used.

  ie) Return two subset views that have the same set of genes where
      lab1 and lab2 are the gene name labels for ds1 and ds2
  """

  # Lookup table of lab1's labels for constant-time membership tests.
  labels1 = dict.fromkeys(lab1.getLabelByRows())
  # Walk lab2 so both subsets share one row order.
  intersection = [label for label in lab2.getLabelByRows() if label in labels1]

  keys1 = [lab1.getKeysByLabel(label)[0] for label in intersection]
  keys2 = [lab2.getKeysByLabel(label)[0] for label in intersection]

  subset1 = RowSubsetView(ds1, keys1)
  subset2 = RowSubsetView(ds2, keys2)

  # The new labelings are not returned.  NOTE(review): presumably
  # constructing a Labeling on a view attaches it there so the labels
  # travel with the subset - confirm against Labeling.
  newLab1 = Labeling(subset1, lab1.name)
  newLab1.labelFrom(lab1)

  newLab2 = Labeling(subset2, lab2.name)
  newLab2.labelFrom(lab2)

  return (subset1, subset2)


def mergeAndSubset(ds1, ds2, lab1, lab2):

  """
  mergeAndSubset(ds1, ds2, lab1, lab2)

  Planned behavior: find the rows ds1 and ds2 have in common according
  to the labelings lab1 and lab2, build the intersection subset of each
  dataset, make a superset from those two subsets, and return two
  subsets of the intersecting rows through which labelings can be
  shared between the datasets.

  NOTE: not implemented yet - this placeholder ignores its arguments
  and always returns None.
  """
  return None
  

def normalizationByRowMedian(ds, threshold):

  """
  normalizationByRowMedian(ds, threshold)

  Return a RowFunctionView of the dataset where each row is divided by
  its median (MLab.median) when that median is above threshold; all
  other rows are passed through unchanged.  A zero median never
  triggers the division (possible when threshold is negative), so no
  row is divided by zero.
  """

  def normFunc(row, thr=threshold):
    import MLab
    med = MLab.median(row)
    # Compare against thr, the threshold bound when normFunc was defined.
    if med > thr and med != 0:
      return (row/med)
    else:
      return row

  return (RowFunctionView(ds, normFunc))


