#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
A brutal test of the interaction of various combinations of views,
labelings, global labelings and datasets.  This is _intended_ to break things,
as it will build nefarious configurations which will, hopefully, make our
schema much more robust.
"""

import unittest
import random
import MLab
import operator
import sys

from compClust.mlx.views import *
from compClust.mlx.datasets import Dataset
from compClust.util.Counter import *
from compClust.mlx.labelings import *

import Numeric
    
class BrutalTestCases(unittest.TestCase):
  """
  Randomized stress tests for views, labelings and global labelings.

  The tests start from a multiplication-table dataset (see build_mult) whose
  rows and columns are labeled with their factors, stack randomly chosen
  views on top of it, and then verify with check_view that every cell
  carrying both a row and a column label still equals the product of the two.
  """

  def setUp(self):
    """
    Register the view factories that build_random_chain chooses from.

    Each factory takes a parent view and returns a new, named view that also
    carries a `mutate` callable used by the mutable_* tests.
    """

    self.rand_init = [self.rand_subset_rows, self.rand_subset_cols,
                      self.rand_sorted_view, self.rand_transpose_view,
                      self.rand_row_filt_view, self.rand_cached_view,
                      self.rand_row_superset, self.rand_col_superset,
                      self.rand_row_aggregate_view]

  #
  # views
  #

  def mutate_nop(self, v):
    """
    Mutation hook for views that have nothing to re-randomize; does nothing.
    """
    pass

  def foo(self, view, data, row):
    """
    Aggregate function installed by rand_row_aggregate_view.

    Fetches the labeling named "X" from view and returns a Numeric array
    holding row_label * col_label for every column index of data[0], where
    row_label is that labeling's label for `row`.

    NOTE(review): the labeling is queried without a view argument here
    (getLabelByRow(row)), unlike check_view which passes the view first --
    presumably view.getLabeling returns a view-local labeling; confirm
    against the labelings API.
    """

    labeling = view.getLabeling("X")
    row_label = labeling.getLabelByRow(row)

    products = []
    for col in range(len(data[0])):
      products.append(row_label * labeling.getLabelByCol(col))
    return Numeric.array(products)

  def rand_row_aggregate_view(self, v):
    """
    Wrap v in an AggregateFunctionView over a random number of random row sets.

    Between 1 and 2*rows-1 key sets are generated, each holding 1 to 4
    random row indices of v.  The aggregate function is foo, which is meant
    to keep data == row_label * col_label.
    """

    rows = v.getNumRows()
    num = random.randrange(1, 2 * rows)

    keyset = []
    for i in range(num):
      size = random.randrange(1, 5)
      keyset.append([random.randrange(0, rows) for j in range(size)])

    view = AggregateFunctionView(v, keyset)
    # The function needs the view it is installed on, hence the closure.
    view.setFunction(lambda x, k: self.foo(view, x, k))

    view.setName("Aggregate")
    view.mutate = self.mutate_nop

    return view

  def rand_row_superset(self, v):
    """
    Wrap v in a RowSupersetView built from v and itself.
    """
    view = RowSupersetView(v, v)
    view.setName('RowSuper')
    view.mutate = self.mutate_nop
    return view

  def rand_col_superset(self, v):
    """
    Wrap v in a ColumnSupersetView built from v and itself.
    """
    view = ColumnSupersetView(v, v)
    view.setName('ColSuper')
    view.mutate = self.mutate_nop
    return view

  def rand_cached_view(self, v):
    """
    Wrap v in a CachedView.
    """
    view = CachedView(v)
    view.setName("Cached")
    view.mutate = self.mutate_nop
    return view

  def rand_subset_rows(self, v):
    """
    Wrap v in a RowSubsetView over a random list of row indices.

    The list length is drawn from [1, 2*rows), so it can exceed the number
    of rows and the same row may appear more than once.
    """

    rows = v.getNumRows()
    num = random.randrange(1, 2 * rows)
    subset = [random.randrange(0, rows) for i in range(num)]

    view = RowSubsetView(v, subset)
    view.setName('RowSubset')
    view.mutate = self.mutate_nop
    return view

  def rand_subset_cols(self, v):
    """
    Wrap v in a ColumnSubsetView over a random list of column indices.

    The list length is drawn from [1, 2*cols), so it can exceed the number
    of columns and the same column may appear more than once.
    """

    cols = v.getNumCols()
    num = random.randrange(1, 2 * cols)
    subset = [random.randrange(0, cols) for i in range(num)]

    view = ColumnSubsetView(v, subset)
    view.setName('ColSubset')
    view.mutate = self.mutate_nop
    return view

  def rand_sorted_view(self, v):
    """
    Wrap v in a SortedView with randomly permuted rows and columns.

    The view's mutate hook re-shuffles both permutations.
    """

    view = SortedView(v)
    view.setName('Sorted')

    self.mutate_sorted_view(view)

    view.mutate = self.mutate_sorted_view
    return view

  def mutate_sorted_view(self, v):
    """
    Apply fresh random row and column permutations to the sorted view v.
    """

    rows = range(v.getNumRows())
    cols = range(v.getNumCols())

    random.shuffle(rows)
    random.shuffle(cols)

    v.permuteRows(rows)
    v.permuteCols(cols)

  def rand_transpose_view(self, v):
    """
    Wrap v in a TransposeView.
    """
    view = TransposeView(v)
    view.setName('Xpose')
    view.mutate = self.mutate_nop
    return view

  def rand_row_filt_view(self, v):
    """
    Wrap v in a RowFilteredView whose filter accepts rows where
    row % step == 0, for a random step in 1..4.

    The view's mutate hook installs a filter with a newly drawn step.
    """
    step = random.randrange(1, 5)
    view = RowFilteredView(v, lambda ds, row: row % step == 0)
    view.setName('RowFilt')

    view.mutate = self.mutate_row_file_view
    return view

  def mutate_row_file_view(self, v):
    """
    Replace the filter of the row-filtered view v with one using a newly
    drawn random step in 1..4.
    """

    step = random.randrange(1, 5)
    v.setFilter(lambda ds, row: row % step == 0)

  def build_random_chain(self, ds, depth):
    """
    Stack depth-1 randomly chosen views on top of ds and return the last one.

    Returns ds itself when depth <= 1.
    """

    v = ds
    for i in range(depth - 1):
      v = random.choice(self.rand_init)(v)
    return v

  def remove_chain(self, v):
    """
    Tear down the chain ending in v.

    Walks the first lineage of v and, for each pair of consecutive entries,
    calls removeView on the later entry with the earlier entry as argument.

    NOTE(review): this presumably relies on getLineage()[0] being ordered
    from v back towards the root dataset -- confirm against the views API.
    """
    views = v.getLineage()[0]
    for i in range(1, len(views)):
      views[i].removeView(views[i - 1])

  def check_view(self, view, labeling):
    """
    If there is a row and column label, the product must equal the data.

    Rows and columns whose label is None are skipped.  The first cell whose
    integer value differs from row_label * col_label fails the test with a
    message naming the view lineage and the offending cell.
    """

    data = view.getData()

    for x in range(view.getNumRows()):
      xval = labeling.getLabelByRow(view, x)
      if xval is None:
        continue

      for y in range(view.getNumCols()):
        yval = labeling.getLabelByCol(view, y)
        if yval is None:
          continue

        if (xval * yval) != int(data[x][y]):
          # Report which stack of views produced the inconsistency.
          names = [str(anc.getName()) for anc in view.getLineage()[0]]
          self.fail("%s: data[%d][%d] is %s but row label %s * "
                    "column label %s is %s"
                    % (" --> ".join(names), x, y, data[x][y],
                       xval, yval, xval * yval))

  def build_mult(self, a):
    """
    Return a Dataset holding the multiplication table of the sequence a,
    i.e. cell [i][j] is a[i] * a[j].
    """

    data = []
    for i in a:
      data.append([i * j for j in a])

    return Dataset(data)

  def _mirror_labeling(self, source, view):
    """
    Create a GlobalLabeling on view and copy into it, row by row and column
    by column, the labels that source reports for view.
    """
    mirror = GlobalLabeling(view)
    for i in range(view.getNumRows()):
      mirror.addLabelToRow(view, source.getLabelByRow(view, i), i)

    for i in range(view.getNumCols()):
      mirror.addLabelToCol(view, source.getLabelByCol(view, i), i)
    return mirror

  def _mutate_lineage(self, v):
    """
    Invoke the mutate hook of every entry in the first lineage of v.
    """
    for view in v.getLineage()[0]:
      view.mutate(view)

  def mutable_view_fork(self):
    """
    Build two random chains off the same root, mutate both several times and
    check them against the root labeling and a labeling mirrored onto the
    first chain.
    """

    a = range(1, 11)
    ds = self.build_mult(a)
    ds.setName('Root')
    # The root is part of every lineage, so it needs a mutate hook as well.
    ds.mutate = self.mutate_nop

    sys.stdout.write("Checking mutable view tree (single fork)")

    for q in range(10):
      sys.stdout.write(".")
      sys.stdout.flush()

      l1 = GlobalLabeling(ds, "X")
      v = self.build_random_chain(ds, random.randrange(2, 15))
      w = self.build_random_chain(ds, random.randrange(2, 15))

      l1.labelRows(ds, a)
      l1.labelCols(ds, a)

      l2 = self._mirror_labeling(l1, v)

      for m in range(5):
        self._mutate_lineage(v)
        self._mutate_lineage(w)

        self.check_view(v, l1)
        self.check_view(w, l1)
        self.check_view(v, l2)
        self.check_view(w, l2)

      # NOTE(review): l2 was created on v but is removed through ds --
      # presumably global labelings can be removed from any view; confirm.
      ds.removeLabeling(l1)
      ds.removeLabeling(l2)

      self.remove_chain(v)
      self.remove_chain(w)

    sys.stdout.write("\n")

  def mutable_view_chain(self):
    """
    Builds a random view chain, but changes parameters of the views several
    times.  i.e. permute, change functions, etc....
    """

    sys.stdout.write("Checking mutable view chain")

    a = range(1, 11)
    ds = self.build_mult(a)
    ds.setName('Root')
    # The root is part of every lineage, so it needs a mutate hook as well.
    ds.mutate = self.mutate_nop

    for q in range(10):
      sys.stdout.write(".")
      sys.stdout.flush()

      l = GlobalLabeling(ds, "X")
      v = self.build_random_chain(ds, random.randrange(5, 10))

      l.labelRows(ds, a)
      l.labelCols(ds, a)

      for m in range(5):
        self._mutate_lineage(v)
        self.check_view(v, l)

      ds.removeLabeling(l)
      self.remove_chain(v)
    sys.stdout.write("\n")

  def basic_view_chain(self):
    """
    Test the global labeling out on randomized view chains.
    """

    sys.stdout.write("Checking view chain")

    a = range(1, 11)
    ds = self.build_mult(a)
    ds.setName('Root')

    for q in range(10):
      sys.stdout.write(".")
      sys.stdout.flush()

      l = GlobalLabeling(ds, "X")
      v = self.build_random_chain(ds, random.randrange(2, 15))

      l.labelRows(ds, a)
      l.labelCols(ds, a)

      self.check_view(v, l)

      ds.removeLabeling(l)
      self.remove_chain(v)
    sys.stdout.write("\n")

  def basic_view_fork(self):
    """
    Build two random chains off the same root and check both against the
    root labeling and a labeling mirrored onto the first chain.
    """

    a = range(1, 11)
    ds = self.build_mult(a)
    ds.setName('Root')

    sys.stdout.write("Checking view tree (single fork)")

    for q in range(10):
      sys.stdout.write(".")
      sys.stdout.flush()

      l1 = GlobalLabeling(ds, "X")
      v = self.build_random_chain(ds, random.randrange(2, 15))
      w = self.build_random_chain(ds, random.randrange(2, 15))

      l1.labelRows(ds, a)
      l1.labelCols(ds, a)

      l2 = self._mirror_labeling(l1, v)

      self.check_view(v, l1)
      self.check_view(w, l1)
      self.check_view(v, l2)
      self.check_view(w, l2)

      # NOTE(review): l2 was created on v but is removed through ds --
      # presumably global labelings can be removed from any view; confirm.
      ds.removeLabeling(l1)
      ds.removeLabeling(l2)

      self.remove_chain(v)
      self.remove_chain(w)

    sys.stdout.write("\n")

  def check_failure(self):
    """
    Diagnose specific failures.

    Currently checks two stacked aggregate views on top of the root.
    """

    a = range(1, 11)
    ds = self.build_mult(a)
    ds.setName('Root')

    l = GlobalLabeling(ds, "X")
    l.labelRows(ds, a)
    l.labelCols(ds, a)

    v = self.rand_row_aggregate_view(ds)
    v = self.rand_row_aggregate_view(v)

    self.check_view(v, l)

    ds.removeLabeling(l)
    # NOTE(review): v is the outer aggregate view, whose parent is the inner
    # aggregate view rather than ds; the other tests use remove_chain for
    # tear-down -- confirm that removeView on the root is intended here.
    ds.removeView(v)

    
def suite(**kw):
    """
    Return the brutal test suite -- currently disabled, so always None.

    A notice is printed and None is returned before any test is added.
    Keyword arguments are accepted and ignored.  The suite construction
    below the early return is kept so the tests can be re-enabled by
    deleting that return.
    """
    # Parenthesised single-argument form prints the same text on Python 2
    # and also parses on Python 3.
    print("ATTN: Need superset functions are a bit behind in the API, skipping brutal tests")
    return None

    # Intentionally unreachable until the superset views catch up with the API.
    tests = unittest.TestSuite()
    tests.addTest(BrutalTestCases("basic_view_chain"))
    tests.addTest(BrutalTestCases("basic_view_fork"))
    tests.addTest(BrutalTestCases("mutable_view_chain"))
    tests.addTest(BrutalTestCases("mutable_view_fork"))
    tests.addTest(BrutalTestCases("check_failure"))
    return tests

# Script entry point: hand control to unittest, naming the module-level
# suite() function as the default test to load.
if __name__ == "__main__":
  unittest.main(defaultTest="suite")



