########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
Utility functions that are useful for accessing files and directories.
"""
import os
import StringIO
import sys
import gzip
import re
import Numeric
import MA
import urllib
import urllib2
import urlparse

from compClust.util.NaN import nanfloat

def getFilesByRegExp(dir_name, regexp):
  """list_of_files = getFilesByRegExp(dir_name, regexp)

  Given a directory and a regular expression, return the names of all
  entries in that directory whose name matches the regular expression.

  regexp may be a pattern string or a compiled pattern; matching uses
  re.search, so the pattern may match anywhere in the file name.
  """
  # NOTE: the redundant function-local "import os" was removed; os is
  # already imported at module level.
  files = os.listdir(dir_name)
  return [entry for entry in files if re.search(regexp, entry)]

def openStream(filename):
  """Open a read stream for a local file or a URL.

  Remote files (anything whose URL scheme is longer than one character)
  are fetched with urllib2; if the response advertises a gzip
  Content-Encoding it is decompressed transparently.  Local files ending
  in '.gz' are opened through gzip; everything else is opened as a plain
  file.  Returns a file-like object open for reading.
  """
  address_scheme = urlparse.urlparse(filename)[0]
  # HACK: Windows drive letters look an awful lot like url schemes
  # HACK: so lets ignore 1 character schemes as they're likely to be
  # HACK: windows drive letters
  if len(address_scheme) > 1:
    stream = urllib2.urlopen(filename)
    encoding = stream.headers.getheader('Content-Encoding')
    if encoding is not None and re.search('gzip', encoding):
      # FIXME: it might be cool if the filling of buffered_stream
      # happened in a separate thread
      buffered_stream = StringIO.StringIO()
      for block in stream:
        buffered_stream.write(block)
      buffered_stream.seek(0)
      stream = gzip.GzipFile(filename, mode='rb', fileobj=buffered_stream)
    # BUG FIX: previously any remote response WITHOUT gzip encoding
    # raised IOError("Unrecognized content encoding"), making every
    # plain (uncompressed) URL unreadable.  Such responses are now
    # returned unchanged.
  else:
    root, ext = os.path.splitext(filename)
    if ext == '.gz':
      stream = gzip.GzipFile(filename, mode='rb')
    else:
      stream = open(filename, 'r')

  return stream
  
###############################################################################
#
# File operations for label files
#
###############################################################################

def readLabelFile(filename, delimiter="\t"):
  """
  Read labels from either the local filesystem or a url.

  The stream is opened with openStream() and parsed by
  readLabelStream(); the resulting list of label lists is returned.
  The stream is closed even if parsing raises (previously it leaked
  on a parse error).
  """
  stream = openStream(filename)
  try:
    labels = readLabelStream(stream, delimiter)
  finally:
    stream.close()

  return labels

def readLabelStream(stream, delimiter='\t'):
  """
  Parse labels from an open stream.

  Labels are separated by either the delimiter or a newline.  Each line
  of the stream produces one list of parsed labels; a blank
  (whitespace-only) line produces an empty list.  The lists are returned
  in file order.
  """
  newline = '\n'
  # Parenthesizing the pattern makes split() return the separators too;
  # they are discarded below along with empty tokens.
  splitter = re.compile('(' + delimiter + '|' + newline + ')')
  discard = (delimiter, newline, '')

  labels = []
  for line in stream:

    #
    # A whitespace-only line carries no label(s) at all.
    #
    if line.strip() == '':
      labels.append([])
      continue

    #
    # Keep only the real tokens (drop separators and empty strings),
    # coercing each one via _parseToken.
    #
    tokens = [_parseToken(tok)
              for tok in splitter.split(line)
              if tok not in discard]
    labels.append(tokens)

  return labels

def _parseToken(token):
  """Strip whitespace off of a token, and try coercing to an int or float if possiblle
  """
  token = token.strip()
  try:
    token = int(token)
  except ValueError, e:
    try:
      token = nanfloat(token)
    except ValueError, e:
      pass
  return token
     
#############################################################################  
#
# File operations for dataset files
#
#############################################################################

def readDelimitedFile(filename, delimiter="\t"):
  """
  Open a dataset stream from a filename or url and pass it to
  readDelimitedData().

  Returns the parsed array.  The stream is closed even if parsing
  raises (previously it leaked on a parse error).  The dead
  "data = None" initialization was removed; the value was immediately
  overwritten.
  """
  stream = openStream(filename)
  try:
    data = readDelimitedData(stream, delimiter)
  finally:
    stream.close()

  return data
  

def readDelimitedData(stream, delimiter="\t", flag=0):
  """
  Read a delimited dataset from an open stream, accounting for missing
  values/NaNs.

  Returns a 2-D Numeric array of floats, or an MA.MaskedArray with the
  NaN positions masked when any NaNs are present.  (Despite the original
  wording, nothing is assigned to self.data -- this is a plain function
  and the array is simply returned.)

  The flag argument indicates whether the first column of every row
  should be stripped (true) or kept (false).  Even with flag false, the
  first column is stripped automatically -- and flag flipped on -- as
  soon as a row whose tokens do not all parse as floats is encountered.
  """

  import string
  
  newline   = '\n'
  empty     = ''
  
  #
  # Enclosing the regular expression in parentheses ensures that
  # delimiters and newlines will be returned along with all other
  # tokens as a result of the split() method.
  #
  
  pattern = re.compile('(' + delimiter + '|' + newline + ')')

  #
  # Make a hash of invalid tokens.
  # NOTE(review): this dict is built but never consulted below -- the
  # token filtering loop uses pattern.match() instead.  Left in place to
  # preserve the block byte-for-byte.
  #

  invalid = {}

  invalid[delimiter] = 0
  invalid[newline]   = 0
  invalid[empty]     = 0
  
  #
  # Process each line in the stream.
  #
  
  data = []  
  while (1):
    
    line = stream.readline()

    #
    # An empty string from readline() means EOF
    #

    if line == '':
      break
  
    strippedLine = string.strip(line)
    
    #
    # Skip blank lines and comment lines (those starting with '#')
    # without processing them; reading then continues with the next
    # line.  (The loop does NOT break here -- data may follow.)
    #
      
    if strippedLine == '' or strippedLine[0] == '#':
      continue

    #
    # Remove the undesirable tokens
    # By ignoring all the ones that are empty or are one 
    # of our delimiters
    #
    tokens = []
    for val in pattern.split(line):
      if not (len(val) == 0 or pattern.match(val)):
        tokens.append(val)
    
    #
    # Convert to a list of floats and add to the list of datum
    #

    if flag:
      # flag set: the first token is a row label -- drop it.
      data.append(map(nanfloat, tokens[1:]))
    else:

      #
      # Try to be smart.  If the whole row does not parse as floats
      # (typically because the first column is a label), even if the
      # flag was not set, set the flag ourselves and drop the first
      # token for this and all subsequent rows.
      #
      
      try:
        data.append(map(nanfloat, tokens))
      except ValueError:
        data.append(map(nanfloat, tokens[1:]))
        flag = 1
                

  #
  # Construct the Numeric array, a, from the list of lists, data.
  #
  # If the Numeric array is really a vector, maintain the 2-D structure
  # since the dataset schema is based on the data being in a 2-D matrix.
  #
  
  a = Numeric.array(data)
  
  #
  # Construct a mask array, m, containing 1's anywhere where a NaN
  # exists in a, and 0's everywhere else.  According to IEEE-754, NaN
  # is the only floating-point value where NaN != NaN is true.  Use m
  # as the mask parameter when constructing the MaskedArray.
  #
    
  m = Numeric.not_equal(a, a)
  
  if Numeric.sum(Numeric.sum(m)) != 0:
    a = MA.MaskedArray(a, copy=0, mask=m)
    
  return a
  
def readIntegratedStream(stream, delimiter="\t"):
  """
  Read an 'integrated' file containing both labelings and data.
  
  The integrated file format is a single easy to parse excel friendly file.
  
              cond1   cond2   cond3  condition
              covar   covar   covar  random_covar
  a    name1    1       2       3
  b    name2    1       2       3
  c    name3    1       2       3
  d    name4    1       2       3
  id   name
  
  the first row and the first column need to contain the primary row and 
  column labelings, which need to be unique and fully specified. The labels
  after the dataset contain the names for each of the labelings.
  """
  data_stream = open(stream, 'r')
  delimiter_re = re.compile(delimiter)

  # build map and storage for row annotations
  row_annotations = {}
  row_annotation_map = {}
  if row is not None:
    for annotation_name, column_id in row.items():
      row_annotations[annotation_name] = []
      row_annotation_map[column_id] = row_annotations[annotation_name]

  # if we have a column annotation suck it in too
  if column is not None:
    column_annotation = []
    file_col_header = tab_re.split(data_stream.readline())
    for column_index in range(len(file_col_header)):
      if column_index not in row_annotation_map.keys():
        column_annotation.append(file_col_header[column_index])

  # accumulators
  data = []
  count = 0
  # process rows out of file
  for file_row in data_stream.xreadlines():
    file_row = tab_re.split(string.strip(file_row))
    # process columns in the row
    data_row = []
    for element_index in xrange(len(file_row)):
      if row_annotation_map.has_key(element_index):
        row_annotation_map[element_index].append(file_row[element_index])
      else:
        datum = nanfloat(file_row[element_index])
        if datum == 1:
          datum = 1.000000001
        data_row.append(datum)
    data.append(data_row)
    # progress bar?
    count += 1
    if (count % 1000) == 0:
      print '\b.',
      count = 0
  ds = DataSource(data)
  # log 2 transform the data
  ds.dataset = views.FunctionView(ds.dataset, lambda x: math.log(x, 2))
  ds.name = name
  # add row annotations
  for annotation_name, labeling_data in row_annotations.items():
    ds.add_labeling(annotation_name, labeling_data, isrow=True, isannotation=True)

  # add column annotation
  if column is not None:
    ds.add_labeling(column, column_annotation, isrow=False, isannotation=True)

  return ds
  
