#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

import os
import re
import string
from StringIO import StringIO
import tempfile
import unittest
import urllib2

from compClust.util import FileIO
import compClust.util

# make some sample data: six rows of five floats.  Row k (k = 1..6) holds
# k, k+1, ..., k+4 taken modulo 10, so the final row wraps its last value
# around to 0.0.
data = [[float((start + offset) % 10) for offset in range(5)]
        for start in range(1, 7)]
data_stream = StringIO()

def make_data_stream(array, stream=None, delimiter="\t"):
  """Write a list of rows to a stream as delimited text and rewind it.

  Each row of *array* becomes one line: its values converted with str(),
  joined by *delimiter*, and terminated by os.linesep.  When *stream* is
  None a new StringIO is created.  The stream is returned positioned at
  offset 0 so it can be read back immediately.
  """
  out = stream
  if out is None:
    out = StringIO()
  for values in array:
    fields = [str(value) for value in values]
    out.write(delimiter.join(fields) + os.linesep)
  out.seek(0)
  return out

# sample labeling: six single-character labels, as many as there are rows
# in ``data``
labeling = list('abcdef')
def make_label_stream(vector, stream=None, delimiter='\t'):
  """Write *vector* as a single delimited line to a stream and rewind it.

  The strings in *vector* are joined by *delimiter* and terminated by
  os.linesep.  When *stream* is None a new StringIO is created.  The stream
  is returned positioned at offset 0.
  """
  target = stream
  if target is None:
    target = StringIO()
  line = delimiter.join(vector) + os.linesep
  target.write(line)
  target.seek(0)
  return target


class FileIOTestCases(unittest.TestCase):
  """Unit tests for the reading helpers in compClust.util.FileIO."""

  def setUp(self):
    # remember where we started so tearDown can restore the cwd
    self.original_dir = os.getcwd()
    # directory of the compClust.util package, used as a known file listing
    self.local_dir = compClust.util.__path__[0]

  def tearDown(self):
    os.chdir(self.original_dir)

  def testGetFilesByRegExp(self):
    """GetFiles by Regular Expression
    """
    files = FileIO.getFilesByRegExp(self.local_dir, r".*\.py$")
    self.failUnless("FileIO.py" in files)

    files = FileIO.getFilesByRegExp(self.local_dir, r"^FileIO\.py$")
    self.failUnlessEqual(len(files), 1)

  def testOpenStream(self):
    """Test the open stream function
    """
    # NOTE: this test needs network access to woldlab.caltech.edu
    s = FileIO.openStream("http://woldlab.caltech.edu/compclust/examples/pgc_diabetes/compclust-labeling-Trigs.clab.gz")
    self.failUnlessEqual(float(s.readline()), float('1.19'))

    # sometimes apache is sneaky and will return a gzip file without it being part of the name
    self.failUnlessRaises(urllib2.HTTPError, FileIO.openStream, "http://woldlab.caltech.edu/i-hope-i-never-make-a-file-called-this.html")

  def testReadDelimitedFile(self):
    """Can we load delimited data from a named file?
    """
    data_stream = make_data_stream(data, tempfile.NamedTemporaryFile())
    try:
      read_data = FileIO.readDelimitedFile(data_stream.name)
      self.__compareDataStream(data_stream, read_data)
    finally:
      # closing a NamedTemporaryFile also removes it from disk
      data_stream.close()

  def testReadDelimitedData(self):
    """Can we load a stream
    """
    data_stream = make_data_stream(data)
    read_data = FileIO.readDelimitedData(data_stream)
    self.__compareDataStream(data_stream, read_data)

  def testReadColumnData(self):
    """What happens when we have a single row with the labels as a column?
    """
    column_data = 'alpha\tbeta\tdelta\tgamma\tepsilon 3'
    data_stream = StringIO(column_data)
    read_data = FileIO.readLabelStream(data_stream)
    self.failUnlessEqual(read_data[0], column_data.split('\t'))

  def testReadColumnNumericData(self):
    """Does the new numeric coercion code work for columns?
    """
    column_numeric = [3.1, 2.1, 4.3, 7.9, 3]
    column_strings = [str(x) for x in column_numeric]
    column_text = string.join(column_strings, '\t')+os.linesep
    data_stream = StringIO(column_text)
    read_data = FileIO.readLabelStream(data_stream)
    self.failUnlessEqual(read_data[0], column_numeric)

  def testReadNumericData(self):
    """Does the numeric coercion code work for rows?
    """
    row_numeric = [3.1, 2.1, 4.3, 7.9, 3]
    row_strings = [str(x) for x in row_numeric]
    row_text = string.join(row_strings, os.linesep)+os.linesep
    data_stream = StringIO(row_text)
    read_data = FileIO.readLabelStream(data_stream)

    self.failUnlessEqual(read_data, [[x] for x in row_numeric])

  def testReadDelimitedDataDelimiters(self):
    """Can we read numeric data with different delimiters?
    """
    # (delimiter used to write the data, delimiter pattern used to read it)
    for join_delimiter, split_delimiter in [('\t', '\t'), ('   ', ' +'), (',', ',')]:
      data_stream = make_data_stream(data, delimiter=join_delimiter)
      read_data = FileIO.readDelimitedData(data_stream, split_delimiter)
      self.__compareDataStream(data_stream, read_data)

  def testParseToken(self):
    """Make sure that parse token works correctly
    """
    tests = [( "", "" ),
             ( "  ", "" ),
             ( " 3rd st", "3rd st" ),
             ( "3", 3),
             # dodged that whole binary repeating decimal thing by searching for a safe float
             ( "1.01", 1.01), ]
    for test, result in tests:
      self.failUnlessEqual(FileIO._parseToken(test), result)

    self.failUnlessRaises(AttributeError, FileIO._parseToken, None)

  def __compareDataStream(self, data_stream, read_data):
    """Assert that read_data matches the module-level ``data`` fixture.

    data_stream is not inspected; it is kept so existing callers need not
    change.  Values are compared with failUnlessAlmostEqual because the
    fixture floats round-trip through their text representation.
    """
    self.failUnlessEqual(len(read_data), len(data))
    for read_row, expected_row in zip(read_data, data):
      self.failUnlessEqual(len(read_row), len(expected_row))
      for read_value, expected_value in zip(read_row, expected_row):
        # compare each parsed value against the original fixture value
        self.failUnlessAlmostEqual(read_value, expected_value)

  def testReadLabelFile(self):
    """Can we load one-label-per-line data from a named file?
    """
    label_stream = tempfile.NamedTemporaryFile()
    try:
      make_label_stream(labeling, label_stream, '\n')
      read_labels = FileIO.readLabelFile(label_stream.name)
      self.__compareLabels(labeling, read_labels)
    finally:
      # closing a NamedTemporaryFile also removes it from disk
      label_stream.close()

  def testReadLabelStream(self):
    """Can we load one-label-per-line data from a stream?
    """
    label_stream = StringIO()
    make_label_stream(labeling, label_stream, os.linesep)
    read_labels = FileIO.readLabelStream(label_stream)
    self.__compareLabels(labeling, read_labels)

  def __compareLabels(self, labeling, read_labels):
    """Assert each row of read_labels starts with the matching expected label."""
    self.failUnlessEqual(len(read_labels), len(labeling))
    for index in range(len(read_labels)):
      self.failUnlessEqual(read_labels[index][0], labeling[index])

  def testReadIntegratedStream(self):
    """Test reading a stream that mixes data with row and column labels
    """
    # skip this test while we're developing it
    return
    tab_file = """  |    |cond1|cond2|cond3|cond4|conditions
  |    |alpha|beta |     |alpha|color
  |    |0 h  |1 h  |2 h  |3 h  |time
1 |HOX |1    |1    |1    |1    |
2 |SHH |1    |1    |1    |1    |
3 |YFG |1    |1    |1    |1    |
4 |YFG2|1    |1    |1    |1    |
5 |PAX7|1    |1    |1    |1    |
id|name|     |     |     |     |
"""
    # The table is drawn with '|' for readability; turn the pipes into tabs.
    # The pipe must be escaped: a bare "|" is an empty alternation that
    # matches between every character instead of matching the pipes.
    tab_file = re.sub(r"\|", "\t", tab_file)
    stream = StringIO(tab_file)

    data, rowlabels, collabels = FileIO.readIntegratedStream(stream)
   
   
def suite(**kw):
  """Return a TestSuite containing every ``test*`` method of FileIOTestCases.

  Keyword arguments are accepted for interface compatibility and ignored.
  """
  loader = unittest.TestLoader()
  return loader.loadTestsFromTestCase(FileIOTestCases)

if __name__ == "__main__":
  # When executed as a script, hand control to unittest's command-line
  # driver and have it run the suite() factory defined in this module.
  unittest.main(module="__main__", defaultTest="suite")
