"""Nearest neighbor search for MLX datasets
"""
#########################
# NN search
import os
import string

from compClust.mlx.datasets import Dataset
from compClust.mlx.labelings import Labeling
from compClust.util.DistanceMetrics import EuclideanDistance
from compClust.mlx.views import SortedView


def nn_row(dataset, labeling, target, metric=EuclideanDistance, name=None):
  """Return a view of the dataset sorted by each row's distance from target.

  labeling may be a Labeling instance or a labeling name on the dataset.
  target must map to exactly one row in that labeling, otherwise a
  ValueError is raised.  Returns a (SortedView, distance Labeling) tuple.
  """
  # Resolve a labeling name into the actual Labeling object.
  if not isinstance(labeling, Labeling):
    labeling = dataset.getLabeling(labeling)

  # Locate the single row the target label refers to.
  matches = labeling.getRowsByLabel(target)
  if len(matches) != 1:
    raise ValueError("There must be one and only one row for the provided target found %d instead" % (len(matches)))
  reference = dataset.getRowData(matches[0])

  view = SortedView(dataset, name=name)

  # Label every row in the view with its distance to the reference vector.
  # NOTE(review): this uses dataset.getData() while the reference row above
  # uses dataset.getRowData() -- confirm both accessors are valid/equivalent.
  distances = Labeling(view, metric.__name__)
  for key in view.getRowKeys():
    d = metric(dataset.getData(key), reference).toscalar()
    distances.addLabelToKey(d, key)

  # Order the view by ascending distance label.
  distances.sortDatasetByLabel()
  return (view, distances)

def dataset_row_report(stream, dataset, output_labelings, count=None, skip=0, include_data=True, include_col_header=True):
  """Produce a tab delimited text report with each row's labels and (possibly) data

  stream is a file stream to write to
  dataset is rather obviously an mlx dataset
  output_labelings is a list of labelings (or their names) to include
  count is how many rows to output starting from skip (None means all)
  skip is the index of the first row to output
  include_data is a flag indicating if we should include the row data
  include_col_header is a flag indicating if we should write a header line
  """
  # make our list of labelings, resolving names into Labeling objects
  labelings = []
  for l in output_labelings:
    if isinstance(l, Labeling):
      labelings.append(l)
    else:
      labelings.append(dataset.getLabeling(l))

  # write the header (unnamed labelings get an empty column title)
  if include_col_header:
    header = []
    for l in labelings:
      if l.getName() is not None:
        header.append(l.getName())
      else:
        header.append("")
    stream.write("\t".join(map(str, header)))
    stream.write(os.linesep)

  # compute the [skip, end) range of rows to emit
  # BUG FIX: the old code sliced an xrange object (row_indicies[:count]),
  # which raises TypeError in Python 2 -- xrange supports indexing but not
  # slicing.  Compute the end bound explicitly instead.
  end = dataset.numRows
  if count is not None:
    end = min(skip + count, end)

  # write body: one tab-delimited line per row (skipped if it has no columns)
  for row in range(skip, end):
    row_contents = []
    for l in labelings:
      row_contents.append(l.getLabelByRow(row))
    if include_data:
      row_contents.extend(dataset.getRowData(row))
    if len(row_contents) > 0:
      stream.write("\t".join(map(str, row_contents)))
      stream.write(os.linesep)
                 
###############
# Tests
import unittest

class NNTestCases(unittest.TestCase):
  """Exercise nearest-neighbor sorting over a small in-memory dataset."""

  def setUp(self):
    # Ten 6-dimensional points; most differ from row 0 in exactly one
    # coordinate, so distances from row 0 are easy to reason about.
    points = [[1, 1, 1, 1, 1, 1],
              [1, 1, 1, 1, 1, 2.1],
              [1, 1, 1, 1, 1, 3],
              [1, 1, 1, 1, 1.9, 1],
              [1, 1, 1, 3, 1, 1],
              [1, 1, 4, 1, 1, 1],
              [2, 2, 2, 2, 2, 2],
              [3, 3, 3, 3, 3, 3],
              [-1, 1, 1, 1, 1, 1],
              [1, 1, 1, 1, 1, 0.2]]
    self.dataset = Dataset(points)

  def tearDown(self):
    pass

  def testNN(self):
    # Sort a view of the dataset by each row's distance from the first row.
    view = SortedView(self.dataset)
    origin = self.dataset.getRowData(0)
    view.sort(lambda row: EuclideanDistance(row, origin))
    view.writeDataset()

def suite():
  """Return a TestSuite containing every test in NNTestCases."""
  loader = unittest.TestLoader()
  return loader.loadTestsFromTestCase(NNTestCases)

if __name__ == "__main__":
  # Run this module's test suite when executed as a script.
  unittest.main(defaultTest="suite")
