#
# Copyright 2005 California Institute of Technology
# 
# This file is part of the CompClust package 
# FIXME: need to update to current license
#
# It is based on NR contaminated C code by Becky Castano and Tobias Man from 
# the JPL machine learning group. 
# 
# The python translation was by Diane Trout

"""Utilities to generate Synthetic Data
"""
import string

import RandomArray
import Numeric as numerix

from compClust.mlx.datasets import Dataset
from compClust.mlx.labelings import Labeling

class HierarchicalSyntheticData:
  """Generate a hierarchical (tree-structured) set of synthetic data.

  Level 0 of the tree is `branching[0]` means drawn with
  `tree_distribution(zeros, variance, shape)`.  Every following level
  draws `branching[i]` children around each mean of the level above,
  again with `tree_distribution`, using a spread that is multiplied by
  `variance_ratio` at every level.  The final level consists of `points`
  samples drawn around each leaf mean.

  One `Dataset` per level is appended to `self.datasets` (top level
  first), and a ground-truth `Labeling` is created for each of them.
  """
  def __init__(self, branching, dimensionality, points, variance=1.0, variance_ratio=1.0, seed=None):
    """
    :Parameters:
      - `branching`: how many branches per tree level (any sequence)
          e.g. [2,3,4] would be 2 top level clusters, each with 3 children,
          and each of those children would have 4 children
      - `dimensionality`: the number of dimensions for the generated dataset
      - `points`: the number of points generated around each leaf mean;
          it is treated as one extra, final branching level, so the last
          dataset holds product(branching) * points rows
      - `variance`: the starting spread (defaults to 1.0); it is passed as
          the second argument of `tree_distribution`
      - `variance_ratio`: for each level of the tree multiply the previous
          variance by the variance_ratio to determine the next level's variance.
      - `seed`: if not None, an integer used to seed RandomArray's global
          generator before any data is drawn.  NOTE(review): Numeric's
          RandomArray.seed falls back to a clock-based seed when given 0,
          so 0 is presumably not reproducible -- confirm for the backend in use.
    """
    # initialize our parameters
    self.points = points
    # The per-leaf point count is simply the last level of the tree.
    # list() lets callers pass any sequence (e.g. a tuple), not only a list.
    self.branching = list(branching) + [self.points]
    self.dimensionality = dimensionality
    self.variance = variance
    self.variance_ratio = variance_ratio
    self.seed = seed

    # Callables invoked as f(mean, spread, shape).
    self.tree_distribution = RandomArray.normal
    # NOTE(review): leaf_distribution is never used by this class; the final
    # level is drawn with tree_distribution like every other level.
    self.leaf_distribution = RandomArray.normal

    self.datasets = []
    # NOTE(review): labels is never populated by this class; kept so that
    # external code reading the attribute keeps working.
    self.labels = []

    if seed is not None:
      # The seed used to be stored but never applied, so the documented
      # reproducibility did not exist.  RandomArray.seed takes two integers.
      RandomArray.seed(int(seed), int(seed))
    self.generate_tree_means()

  def generate_tree_means(self):
    """Populate `self.datasets` with one Dataset per tree level.

    Level 0 rows are all labeled 'r' in a Labeling named "ground truth".
    At level i (i >= 1), each row is labeled with the comma-joined path of
    its parent (e.g. "0,1") in a Labeling named "ground truth <i>".
    """
    spread = self.variance
    shape = (self.branching[0], self.dimensionality)

    # generate the top of the tree
    top = Dataset(self.tree_distribution(numerix.zeros(shape), self.variance, shape),
                  "synth data %s" % (self.branching[0]))
    self.datasets = [top]
    labeling = Labeling(top, "ground truth")
    labeling.labelRows(['r'] * top.numRows)

    # Path specifier for every row of the current level: [[0], [1], [2], ...].
    # Each level extends its parent's path with the child's index.
    parent_paths = [[index] for index in range(top.numRows)]

    # for all the levels of the tree
    for depth in range(1, len(self.branching)):
      branch = self.branching[depth]
      # NOTE(review): this value is called a variance, but RandomArray.normal
      # (the default tree_distribution) documents its second argument as a
      # standard deviation -- confirm which one is intended.
      spread = spread * self.variance_ratio
      means, row_labels, parent_paths = self._grow_level(
          self.datasets[-1].getData(), parent_paths, branch, spread)
      level = Dataset(means, "synth data %s" % (self.branching[:depth + 1]))
      # depth equals the number of path levels built so far, matching the
      # original "ground truth %d" numbering.
      labeling = Labeling(level, "ground truth %d" % (depth))
      labeling.labelRows(row_labels)
      self.datasets.append(level)

  def _grow_level(self, parent_means, parent_paths, branch, spread):
    """Draw `branch` children around each parent mean.

    Returns ``(means, row_labels, child_paths)``:
      - `means`: the new rows, in parent order, as a flat list
      - `row_labels`: the comma-joined parent path, repeated once per child
        (becomes the ground-truth labeling of the new level)
      - `child_paths`: parent path + child index for every new row
        (becomes the seed for labeling the next level)
    """
    means = []
    row_labels = []
    child_paths = []
    shape = (branch, self.dimensionality)
    # parent_means and parent_paths have one entry per row of the parent level.
    for parent_mean, parent_path in zip(parent_means, parent_paths):
      means.extend(self.tree_distribution(parent_mean, spread, shape))
      parent_label = ",".join([str(step) for step in parent_path])
      row_labels.extend([parent_label] * branch)
      child_paths.extend([parent_path + [child] for child in range(branch)])
    return means, row_labels, child_paths

if __name__ == "__main__":
  # Demo: 3 top-level clusters, 2 children each, 4 grandchildren each,
  # 2-D data, 3 points around every leaf mean, starting variance 1.
  demo = HierarchicalSyntheticData([3, 2, 4], 2, 3, variance=1)
  for level in demo.datasets:
    # Show each level, its labelings, and the rows of its first labeling.
    print(level)
    print(level.getLabelings())
    print(level.getLabelings()[0].getAllRowLabels())