#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
# Last Modified: 29-Aug-2002, 2:37p
#
#
# Given a dataset and a full GTR tree, writes out a PPM picture of the
# tree structure and the data
#

import sys
import MLab
import string

def dumpGTR(gtrTree, dataset, filename, padding=8, pix_width=5, pix_height=5):
  """
  Write out a graphical representation of a GTR clustering as a binary
  (P6) PPM image.

  The left part of the image holds the tree, drawn by gtraverse(); the
  right part is a heat map of the data, one pix_width x pix_height block
  per data cell, with the rows in the order returned by gtraverse().

  gtrTree    : an XClustTree object
  dataset    : a Dataset object
  filename   : output file
  padding    : (optional) number of pixels between tree levels
  pix_width  : (optional) width of the blocks drawn for the dataset
  pix_height : (optional) height of the blocks drawn for the dataset
  """

  K = dataset.getNumRows()
  N = dataset.getNumCols()

  #
  # Fetch the maximum height of the tree

  tree_height = process_tree(gtrTree)

  #
  # Parameterize the image: the tree occupies the first horz_pad columns,
  # the data blocks everything to the right of it

  horz_pad   = padding * (tree_height + 1)
  img_width  = (pix_width * N) + horz_pad
  img_height = (pix_height * K)

  # Start from an all-white canvas
  img = MLab.ones((img_height, img_width, 3))
  img = img * 255

  #
  # Create the color table
  # Entry 0 is green, the middle entry is black and the last entry is red;
  # the entries in between are linear blends of their two anchors

  color_level = 100

  color_table = MLab.zeros((color_level, 3))
  color_table[0] = (0, 255, 0)
  color_table[color_level - 1] = (255, 0, 0)

  m = color_level // 2

  # NOTE(review): the array arithmetic below is left exactly as it was;
  # it presumably relies on MLab/Numeric integer arrays -- confirm before
  # porting to another array package.
  diff = color_table[m] - color_table[0]
  for level in range(m):
    color_table[level] = color_table[0] + (diff * level / m)

  diff = color_table[color_level - 1] - color_table[m]
  for level in range(m + 1, color_level):
    color_table[level] = color_table[m] + (diff * (level - m) / (color_level - m - 1))

  #
  # Find the range of the data (names chosen so the min/max builtins are
  # not shadowed)

  data = dataset.getData()

  data_min = MLab.min(MLab.min(data))
  data_max = MLab.max(MLab.max(data))

  span = data_max - data_min

  #
  # Run through the tree which should give us the proper order to
  # iterate through the data (this also draws the tree into img)

  order = gtraverse(gtrTree, img, tree_height, padding, pix_height)

  #
  # plot all the data in the numeric array

  top_level = color_level - 1

  for k in range(K):
    y1 = pix_height * k
    y2 = y1 + pix_height
    row = order[k]
    for n in range(N):
      if span == 0:
        # Every value is identical: there is no range to scale by, so use
        # the middle (black) entry instead of dividing by zero
        level = m
      else:
        level = int(top_level * (data[row, n] - data_min) / span)
      x1 = (pix_width * n) + horz_pad
      x2 = x1 + pix_width
      img[y1:y2,x1:x2] = color_table[level]

  #
  # write out the picture to file
  # P6 pixel data is raw bytes, so the file must be opened in binary mode;
  # text mode may rewrite bytes that happen to look like newlines

  fp = open(filename, 'wb')
  try:
    fp.write("P6\n")
    fp.write(str(img_width) + " " + str(img_height) + "\n")
    fp.write("255\n")

    # One write per scanline instead of one write per byte
    for i in range(img_height):
      scanline = img[i]
      fp.write("".join([chr(scanline[j][c])
                        for j in range(img_width)
                        for c in range(3)]))
  finally:
    # Release the handle even if a write fails part-way through
    fp.close()

def process_tree(tree):
  """
  Return the largest depth reported for any leaf of the tree.

  tree : object providing iterator('LRC'), isLeaf(key) and depth(key);
         the iterator's next() is expected to return None once the
         traversal is exhausted

  Returns 0 when the traversal yields no leaves.
  """

  deepest = 0

  walker = tree.iterator('LRC')
  node = walker.next()

  # Identity test against None so a node's own comparison methods are
  # never consulted to detect the end of the traversal
  while node is not None:
    key = node.key()
    if tree.isLeaf(key):
      deepest = max(deepest, tree.depth(key))

    node = walker.next()

  return deepest

def gtraverse(tree, img, tree_height, padding, pix_height):
  """
  Draw the tree into img and return the order in which the leaves appear
  from the top of the image to the bottom.

  tree        : object providing iterator('LRC'), depth(key), isLeaf(key),
                children(key) and find(key)
  img         : image array, indexed img[row][column]; img.shape[0] is the
                height in pixels.  Modified in place.
  tree_height : maximum depth of the tree (see process_tree)
  padding     : number of pixels between tree levels
  pix_height  : height in pixels of one data row

  The first leaf visited is drawn on the bottom data row and each later
  leaf one row higher.  Every visited node gets a 'pix' attribute holding
  the pixel row it was drawn on; an internal node reads the 'pix' of its
  first two children, so children must be visited before their parent.

  Leaf keys must contain 'GENE<number>X'; the number is what is collected
  in the returned list (dumpGTR uses it as a row index into the data).
  """

  order  = []
  black  = (0, 0, 0)
  right  = (tree_height + 1) * padding
  # Floor division keeps these integral under true division as well
  total  = img.shape[0] // pix_height
  offset = pix_height // 2
  leaves_seen = 0

  walker = tree.iterator('LRC')
  node = walker.next()

  while node is not None:

    key = node.key()
    left = (tree.depth(key) + 1) * padding

    if tree.isLeaf(key):
      leaves_seen += 1

      # Data row r occupies pixel rows [r * pix_height, (r + 1) * pix_height),
      # so its centre line is r * pix_height + offset.  Scaling by
      # (height - 1) / total instead lands one pixel too high for every
      # row but the first, and makes rows 0 and 1 collide when
      # pix_height is 1.
      row = total - leaves_seen
      i = row * pix_height + offset
      node.pix = i

      # Horizontal line from the leaf's level to the edge of the data
      for j in range(left, right):
        img[i][j] = black

      order.append(int(key.split("GENE")[1].split("X")[0]))

    else:

      children = tree.children(key)
      ia = tree.find(children[1]).pix
      ib = tree.find(children[0]).pix
      node.pix = (ia + ib) // 2

      # Vertical connector between the two children; sort the end points
      # so the line is drawn whichever child sits higher
      for i in range(min(ia, ib), max(ia, ib) + 1):
        img[i][left] = black

      # Horizontal stubs reaching from this level to each child's level
      for j in range(padding):
        img[ia][left + j] = black
        img[ib][left + j] = black

    node = walker.next()

  # Leaves were collected bottom-up; callers want them top-down
  order.reverse()
  return order

def main():
  """
  Command-line driver.

  usage: <script> tree_file dataset_file output_file
                  [padding [pix_width [pix_height]]]

  tree_file is read into an XClustTree, dataset_file is loaded as a
  Dataset and the picture is written to output_file.  The optional
  trailing arguments are converted to integers and passed to dumpGTR()
  in that order.

  With no arguments the usage text is printed and the exit status is 0.
  Any other unsupported argument count prints the usage text and exits
  with status 1 instead of silently producing nothing.
  """

  from compClust.mlx.XClustTree import XClustTree
  from compClust.mlx.datasets import Dataset

  argv = sys.argv
  usage = ("usage: %s tree_file dataset_file output_file "
           "[padding [pix_width [pix_height]]]\n" % argv[0])

  if len(argv) == 1:
    sys.stderr.write(usage)
    sys.exit(0)

  if len(argv) < 4 or len(argv) > 7:
    sys.stderr.write(usage)
    sys.exit(1)

  # Convert the optional sizes first so a malformed number is reported
  # before any input file is loaded
  sizes = [int(arg) for arg in argv[4:]]

  tree = XClustTree()
  tree.read(argv[1])

  dumpGTR(tree, Dataset(argv[2]), argv[3], *sizes)

# Run the command-line driver only when this file is executed as a script.
if __name__ == "__main__":
  main()
