########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
Usage: Barkai.py parameterFilename datasetFilename resultsFilename
  
 Wrapper for the Barkai/Ihmels self-consistent search algorithm.

 Algorithm parameters include the following name value pairs.  Unless
 a default is indicated, the parameter is required.
 
     column_threshold (tC)
       Number of standard deviations a column's score must deviate from
       the mean column score (in either direction) for the column to be
       kept.  Columns that do not pass this threshold are discarded.

     row_threshold (tG)
       Number of standard deviations a row's score must exceed the mean
       row score for the row to be kept.  Rows that do not pass this
       threshold are discarded.

     std_method (hack)
       Pick how to compute the standard deviation of the column scores.
       This can be either 'model' or 'normal'.  'normal' computes the STD
       of the scores in the traditional manner, while 'model' uses the STD
       expected for random data, 1/sqrt(number of selected rows).  Any
       value other than 'model' behaves like 'normal'.  'model' is most
       useful when the dataset has a small number of columns.

     subset_size
       Number of rows to randomly pick to seed each self-consistent search.

     num_trials
       Number of successful trials to do.  A trial that ends with no rows
       or no columns is discarded and does not count.

     num_iterations
       Maximum number of iterations used to converge to a self-consistent
       set.
  
     seed
       The seed for the pseudo-random number generator that picks the
       random row subsets seeding each trial.  Defaults to 42.
"""
     
import os
import string
import sys
import types
import warnings
import random

import Numeric
import MLab

from compClust.util import Verify
from compClust.util import WrapperUtil
from compClust.util import listOps

from compClust.util.TimeStampedPrintStream import TimeStampedPrintStream

from compClust.mlx.labelings import Labeling
import compClust.mlx.ML_Algorithm as ML_Algorithm

import compClust.mlx.wrapper

#
# MESSAGE_STREAM -- stream handed to every Barkai instance via
# setMessageStream().  The argument looks like a strftime-style prefix
# ending in the algorithm name (presumably prepended to each message).
#

MESSAGE_STREAM = TimeStampedPrintStream(
  "%Y-%b-%d %H:%M: Barkai: ")

#
# Barkai
#

class Barkai(ML_Algorithm.ML_Algorithm):
  """Wrapper for the Barkai/Ihmels self-consistent search algorithm.

  Each trial seeds a random subset of rows and then alternately selects
  the columns that are significant for the current rows and the rows that
  are significant for those columns.  A trial stops when the (rows, cols)
  pair no longer changes, when no rows survive, or after num_iterations
  refinements.  The pairs found by successful trials are kept in the model.
  """

  def __init__(self, dataset=None, parameters=None):
    """Barkai(dataset, parameters)

    Creates a new Barkai algorithm with the given dataset and
    algorithm parameters.  To run, use the run() method.

    dataset    -- object providing getData(), getNumRows() and getNumCols()
    parameters -- dict-like mapping holding the parameters described in
                  the module docstring
    """
    self.dataset    = dataset
    self.parameters = parameters
    self.labeling   = None
    self.model      = None

    self.setMessageStream( MESSAGE_STREAM )


  def copy(self):
    """Return a new Barkai sharing this instance's dataset, parameters,
    labeling and model objects (a shallow copy)."""
    new_obj = Barkai(self.dataset, self.parameters)
    new_obj.labeling = self.labeling
    new_obj.model = self.model
    return new_obj


  def getLabeling(self):
    """Return the labeling stored on this instance (None until set)."""
    return self.labeling


  def getModel(self):
    """Return the model stored on this instance.

    After run() this is a list of (rows, cols) pairs, one per successful
    trial, in the order they were found.
    """
    return self.model


  def run(self):
    """Perform num_trials successful trials of the self-consistent search.

    A trial is successful when it ends with at least one row and one
    column.  Failed trials are retried, but at most max_attempts trials
    (optional parameter, default 100 * num_trials) are attempted in total,
    so a dataset on which the search always fails cannot hang run().  If
    fewer than num_trials trials succeed, a warning is issued.

    The (rows, cols) pair of each successful trial is stored, in order,
    in self.model.
    """

    #
    # Initialize the random seed so the row subsets are reproducible
    #

    seed = self.parameters.get("seed", 42)
    random.seed(seed)

    data = self.dataset.getData()

    #
    # Whiten (normalize) the data once; every trial reuses both matrices
    #

    Eg = self.whiten_rows(data)
    Ec = self.whiten_columns(data)

    num_trials   = self.parameters["num_trials"]
    max_attempts = self.parameters.get("max_attempts", 100 * num_trials)

    modules  = []
    attempts = 0

    while len(modules) < num_trials and attempts < max_attempts:
      attempts += 1

      #
      # Perform the self-consistent algorithm
      #

      rows, cols = self.do_barkai(Eg, Ec)

      #
      # If no rows or no columns are returned, it was a failed run, so
      # ignore it (it does not count as a trial)
      #

      if len(rows) == 0 or len(cols) == 0:
        continue

      modules.append((rows, cols))

    if len(modules) < num_trials:
      warnings.warn("Barkai: only %d of %d trials succeeded after %d attempts"
                    % (len(modules), num_trials, attempts))

    # NOTE(review): no Labeling is built from the modules, so getLabeling()
    # still returns None; results are only available through getModel().
    self.model = modules


  def do_barkai(self, Eg, Ec):
    """Run one trial of the self-consistent search.

    Eg -- matrix produced by whiten_rows(), used to score columns
    Ec -- matrix produced by whiten_columns(), used to score rows

    Starts from a random row subset and all columns, then alternately
    refines the columns and the rows.  Stops when neither set changes,
    after num_iterations refinements, or when a refinement leaves no rows
    (in which case both returned lists are empty).

    Returns (rows, cols) as lists of indices.
    """

    iterations = 0
    max_iters  = self.parameters["num_iterations"]

    rows = self.pick_random_row_subset()
    cols = list(range(self.dataset.getNumCols()))

    while rows and iterations < max_iters:

      new_cols, values = self.find_self_consistent_columns(Eg, rows)
      new_rows         = self.find_self_consistent_rows(Ec, new_cols, values)

      if len(new_rows) == 0:
        rows = []
        cols = []
        break

      # converged: another refinement would give the same sets
      if (new_rows == rows) and (new_cols == cols):
        break

      rows = new_rows
      cols = new_cols

      iterations += 1

    return rows, cols


  def pick_random_row_subset(self):
    """Return a random list of subset_size distinct row indices.

    subset_size is capped at the number of rows.  When more than half of
    the rows are wanted, random rows are removed from the full list
    (result in ascending order); otherwise random rows are drawn one at a
    time (result in draw order).
    """

    subset_size = self.parameters["subset_size"]
    num_rows    = self.dataset.getNumRows()
    orig_rows   = list(range(num_rows))

    if subset_size > num_rows:
      subset_size = num_rows

    if subset_size > num_rows // 2:
      # cheaper to discard the rows we do not want
      for _ in range(num_rows - subset_size):
        index = random.randrange(len(orig_rows))
        orig_rows.pop(index)
      rows = orig_rows

    else:
      rows = []
      for _ in range(subset_size):
        index = random.randrange(len(orig_rows))
        rows.append(orig_rows.pop(index))

    return rows


  def find_self_consistent_columns(self, data, rows):
    """
    Performs the Ihmels/Barkai self-consistency step on the column data.

    data -- the matrix produced by whiten_rows() (Eg)
    rows -- non-empty list of row indices of the current row set

    Each column (experimental condition) is scored by averaging data over
    the given rows.  A column is kept when its score deviates from the
    mean column score by more than column_threshold standard deviations,
    in either direction.

    Returns (significantColIdx, significantColVals): the kept column
    indices and their scores.
    """

    #
    # subset dataset with the current selection of genes (rows)
    #

    Eg_gss = Numeric.take(data, rows)

    #
    # score each experimental (columns) condition by averaging the expression
    # change over the genes of the input set
    #

    sc = MLab.mean( Eg_gss )
    normalized_sc = abs( sc - MLab.mean( sc ) )

    #
    # Identify experiment signature (SC).  It contains conditions whose
    # absolute score is statistically significant
    #

    if self.parameters['std_method'] == 'model':
      # whiten_rows gives every column unit variance, so the mean of
      # len(rows) randomly chosen entries has a spread of ~1/sqrt(len(rows))
      sigC = 1.0 / len(rows) ** 0.5
    else:
      sigC = MLab.std( sc )

    #
    # find conditions that contain significant absolute scores
    #

    tC = self.parameters['column_threshold']

    significantColBitvector = normalized_sc > tC * sigC
    significantColIdx       = listOps.findAll(significantColBitvector, 1)

    #
    # collect the values of sc that correspond to the columns
    #

    significantColVals = Numeric.take(sc, significantColIdx)

    return significantColIdx, significantColVals


  def find_self_consistent_rows(self, data, columns, col_vals):
    """
    Performs the Ihmels/Barkai self-consistency step on the row data.

    data     -- the matrix produced by whiten_columns() (Ec)
    columns  -- list of significant column indices
    col_vals -- scores of those columns, used as weights

    Each row is scored by the average of its entries in the given columns,
    each weighted by the column's score.  A row is kept when its score
    exceeds the mean row score by more than row_threshold standard
    deviations (one-sided).  Returns the list of kept row indices, which
    is empty when no columns are given.
    """

    # no significant columns means no row can be scored
    if len(columns) == 0:
      return []

    #
    # subset colspace of Ec using only significant columns
    #

    Ec_css = Numeric.take(data, columns, 1)

    #
    # score each gene. The score contains the weighted average of the colspace
    # (weights determined by sc for only those columns in SC).  Broadcasting
    # scales column j by col_vals[j], like multiplying by diag(col_vals).
    #

    Ec_css_wt = Ec_css * col_vals

    #
    # Identify gene-signature (SG).  It contains those genes whose score is
    # statistically significant
    #

    sg   = MLab.mean( Numeric.transpose( Ec_css_wt ) )
    sigG = MLab.std( sg )

    tG   = self.parameters['row_threshold']

    significantRowBitvector = sg - MLab.mean( sg ) > tG * sigG
    significantRowIdx       = listOps.findAll(significantRowBitvector, 1)

    #
    # SG is the new genelist.
    #

    return significantRowIdx


  def whiten_rows(self, data):
    """
    Normalize data to create Eg, where Eg has mean = 0 and variance = 1
    w/respect to the rows.

    MLab.mean / MLab.std reduce over the rows, so each column of the
    result is centred on its own mean and scaled by its own standard
    deviation.  A column with zero standard deviation is only centred
    (it becomes all zeros) instead of being divided by zero.
    """

    means = MLab.mean( data )
    stds  = MLab.std ( data )

    # guard constant columns against division by zero
    stds = Numeric.where( Numeric.equal( stds, 0 ), 1.0, stds )

    return (data - means) / stds


  def whiten_columns(self, data):
    """
    Normalize data to create Ec, where Ec has mean = 0 and variance = 1
    w/respect to the columns.

    Implemented by transposing, applying whiten_rows() and transposing
    back, so each row of the result is centred and scaled independently.
    """

    data = Numeric.transpose( data )
    data = self.whiten_rows( data )

    return Numeric.transpose( data )


  def validate(self):
    """
    validate()

    Returns a true value if all parameters necessary to run barkai are
    defined, a false value otherwise.  The optional 'seed' and
    'max_attempts' parameters are not checked.
    """

    parameterNames = [ "column_threshold",
                       "row_threshold",
                       "std_method",
                       "subset_size",
                       "num_trials",
                       "num_iterations"
                     ]

    parameters = self.parameters
    error      = 0

    # NOTE(review): a true result from Verify.parameters_exist is treated
    # as an error indicator here -- confirm against compClust.util.Verify.
    if Verify.parameters_exist( parameterNames, parameters ):
      error = 1

    return not error


if __name__ == "__main__":
  # Script entry point: let the generic compClust launcher drive a fresh,
  # unconfigured Barkai instance with the command-line arguments.
  from compClust.mlx.wrapper import Launcher

  algorithm = Barkai()
  Launcher.main(sys.argv, algorithm)
