import string
import itertools

from compClust.gui.LabelingSource import LabelingSource
from compClust.mlx.labelings.Labeling import Labeling, ClusteredLabeling

def get_labelings_by_clustering(dataset, l, output_labeling_names):
  """Group values of several labelings by the clusters of labeling l.

  For each cluster label in l (in sorted order), look up the member keys
  and collect, for every name in output_labeling_names, the values that
  dataset's labeling of that name assigns to those keys.

  Returns a dict: {cluster_label: {output_name: [values...]}}.
  """
  output = {}
  # sorted() instead of labels.sort(): avoids mutating the list object
  # returned by getLabels(), which may be shared internal state.
  for label in sorted(l.getLabels()):
    keys = l.getKeysByLabel(label)
    output[label] = dict(
        (name, dataset.getLabeling(name).getLabelByKeys(keys))
        for name in output_labeling_names)
  return output

def dump_labels_by_clustering(filename, output_labeling_names, output_labeling_values):
  """Write a tab-separated report of per-cluster labeling values.

  filename -- path of the report file to (over)write
  output_labeling_names -- column names, in output order
  output_labeling_values -- {cluster_name: {output_name: [values...]}},
      as produced by get_labelings_by_clustering

  Each data row is: cluster_name, then one value per output labeling.
  """
  print(filename)
  # 'with' guarantees the handle is closed even if a write fails
  # (the previous version leaked it on error, and shadowed builtin 'file').
  with open(filename, 'w') as out:
    out.write("cluster_name\t%s\n" % '\t'.join(output_labeling_names))
    for cluster_name in sorted(output_labeling_values):
      cluster_contents = output_labeling_values[cluster_name]
      parallel_lists = [cluster_contents[x] for x in output_labeling_names]
      # zip transposes the per-labeling lists into rows; it also tolerates
      # an empty output_labeling_names (old code raised IndexError).
      for row_values in zip(*parallel_lists):
        row_detail = [cluster_name] + list(row_values)
        out.write('\t'.join(map(str, row_detail)) + '\n')

def dump_clusterings_for_dataset(dataset, output_labeling_names, LabelingSubClass=ClusteredLabeling):
  """Write one label-dump file for every clustered labeling on dataset.

  Only labelings that are instances of LabelingSubClass are dumped; the
  output filename is the dataset name concatenated with the labeling name.
  """
  for labeling in dataset.getLabelings():
    if not isinstance(labeling, LabelingSubClass):
      continue
    print(labeling)
    labels_by_clusters = get_labelings_by_clustering(dataset, labeling,
                                                     output_labeling_names)
    print(len(labels_by_clusters))
    dump_labels_by_clustering(dataset.getName() + labeling.getName(),
                              output_labeling_names,
                              labels_by_clusters)
             
def dump_all_clusterings(dataset_list, output_labeling_names):
  """Dump every clustered labeling of every dataset in dataset_list."""
  for ds in dataset_list:
    dump_clusterings_for_dataset(ds, output_labeling_names)

# call to run the above code
# brian.dump_all_clusterings([d.dataset for d in s.datamanager.values()], output_labeling_names)


#####
# brian's nearest neighbor request

import nn
from kenji import output_labeling_names 

from compClust.util.DistanceMetrics import CorrelationDistance
def brian_nn(dataset, nn_list = ['X72914_at',  'L07925_at', 'X15958_at']):
  """Write a nearest-neighbor row report for each probe in nn_list.

  For every probe id, rank the dataset rows by CorrelationDistance from
  that probe's row and write the report (output labelings plus the
  distance column) to a file named <dataset>_<metric>_<probe>.

  NOTE: the mutable default nn_list is safe here because it is only
  iterated, never mutated.
  """
  metric = CorrelationDistance
  output_labelings = list(output_labeling_names) + [metric.__name__]

  for nn_item in nn_list:
    sorted_set, distances = nn.nn_row(dataset, 'Probe Set ID', nn_item, metric)
    # 'with' closes each report file; the old code leaked one handle
    # per probe in nn_list.
    with open("%s_%s_%s" % (dataset.name, metric.__name__, nn_item), 'w') as stream:
      nn.dataset_row_report(stream, sorted_set, output_labelings)

def copy_labeling(destination, source, labeling_name, destination_name=None):
  """Copy a labeling attached to source onto destination.

  The two datasets are aligned through their primary labelings: for each
  row of source we read its primary label, find the key carrying the same
  primary label in destination, and attach the copied label there.

  destination, source -- wrappers exposing .dataset, .primary, .name and
      get_labeling_by_name / _labeling_sources
  labeling_name -- name of the labeling (attached to source) to copy
  destination_name -- name for the new labeling; defaults to
      "<source name>_<labeling name>"

  Raises ValueError when a row does not map to exactly one key on either
  side.
  """
  labeling = source.get_labeling_by_name(labeling_name)

  source_primary_labeling = source.dataset.getLabeling(source.primary.name)
  # BUG FIX: this previously re-read source.dataset/source.primary, so the
  # "destination" keys actually indexed into the source dataset.
  # NOTE(review): assumes destination's own primary labeling carries the
  # same label values as source's -- confirm against the data model.
  destination_primary_labeling = destination.dataset.getLabeling(destination.primary.name)
  if destination_name is None:
    destination_name = source.name + '_' + labeling.name

  destination_labeling = Labeling(destination.dataset, destination_name)

  for row in xrange(source.dataset.numRows):
    # look up our primary label name for this row
    source_pkey_list = source_primary_labeling.getLabelsByRow(row)
    if len(source_pkey_list) != 1:
      # BUG FIX: the row number (not the list length) belongs in %d
      raise ValueError("Wrong number of primary keys for row %d, got %s" % \
                       (row, source_pkey_list))
    # get the key for the source labeling
    source_key_list = source_primary_labeling.getKeysByLabel(source_pkey_list[0])
    if len(source_key_list) != 1:
      # BUG FIX: report the key list that was actually checked
      raise ValueError("source primary labeling had %d keys, which were %s" % \
                       (len(source_key_list), source_key_list))

    # get the label for the key
    label_to_copy = labeling.getLabelByKey(source_key_list[0])
    # convert the primary label into a key pointing into the destination dataset
    destination_key_list = destination_primary_labeling.getKeysByLabel(source_pkey_list[0])
    if len(destination_key_list) != 1:
      raise ValueError("destination primary labeling had %d keys, which were %s" % \
                       (len(destination_key_list), destination_key_list))

    destination_labeling.addLabelToKey(label_to_copy, destination_key_list[0])

  print(destination_labeling)
  # register the new labeling as a row, non-annotation source on destination
  labeling_source = LabelingSource(destination_name, isrow=True, isannotation=False)
  destination._labeling_sources[destination_name] = labeling_source
  print(destination.dataset.getLabelings())


# this is just a convienence function
# to copy the clustering over for the current variant datasets of interest.
def brian_copy_mock_to_fcig(datamanager, k_list):
  import re
  labeling_name = "MultiRun_cor mul=%s mul=5 mul=k" % str(k_list)
  mock = [ x for x in datamanager.values() if re.search("Mock", x.name) ][0]
  fcig = [ x for x in datamanager.values() if re.search("FcIg", x.name) ][0]
  copy_labeling(fcig, mock, labeling_name )
