########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
# Written By:  Lucas Scharenbroich
# Date      :  August 2001
# 
# Purpose   :  Agglomerate over any tree structure which conforms to the tree
#              object schema
#

import sys
import Numeric
import string
import re

class TreeAgglomerator:
    """ Agglomerates a tree structure into flat clusters of leaves.

    The tree object is used through: root(), isLeaf(key) and
    iterator(mode).  The iterator must support setRoot(), next()
    (returning None when exhausted), reverse(), firstLeaf() and
    lastLeaf().  Nodes returned by next() must provide key().

    A cluster is a list of leaf nodes; clustering methods return a
    list of clusters.
    """

    def __init__(self, tree):
        """ Index the leaves of 'tree' for fast subtree-leaf lookups.

        tree -- a tree object conforming to the tree schema, or None.
                With None the agglomerator is empty and the clustering
                methods return an empty list.
        """

        self.tree = tree
        self.root = None

        #
        # Build a left-to-right list of leaves for fast access in
        # __getLeaves.  To be able to retrieve slices, also maintain a
        # dictionary of node -> index pairs.  Defaults are set up front
        # so the attributes exist even when no tree is given.
        #

        self.leafIndex = {}
        self.leaves = []
        self.leafCount = 0

        if tree is not None:
            self.root = tree.root()
            self.leaves, self.leafCount = self.__buildLeafList()

            for i, leaf in enumerate(self.leaves):
                self.leafIndex[leaf] = i

    def __buildLeafList(self):
        """
        Return (leafList, leafCount): the leaf nodes of the tree in the
        order the 'LRC' iterator visits them (left to right), and the
        number of those leaves.
        """

        it = self.tree.iterator('LRC')

        leafList = []
        node = it.next()
        while node is not None:
            if self.tree.isLeaf(node.key()):
                leafList.append(node)
            node = it.next()

        return leafList, len(leafList)

    def __getLeaves(self, root='root'):
        """ Return the leaves of the subtree rooted at 'root', left to right.

        root -- subtree root handed to the iterator's setRoot(); the
                string 'root' (default) means the whole tree.
        """

        if root == 'root':
            root = self.root

        it = self.tree.iterator('LCR')
        it.setRoot(root)

        #
        # Only the first and last leaf of the subtree are needed: because
        # self.leaves is ordered left to right, the subtree's leaves are
        # the contiguous slice between those two positions.
        #

        start = self.leafIndex[it.firstLeaf()]
        finish = self.leafIndex[it.lastLeaf()] + 1

        return self.leaves[start:finish]

    def agglomerateWithSizeThreshold(self, sizeThreshold, root='root',
                                     mode='CLR'):
        """ Partition the leaves into clusters using a subtree-size limit.

        sizeThreshold -- maximum number of leaves a subtree may hold to be
                         taken as a single cluster.
        root          -- node to start from; 'root' means the tree root.
        mode          -- iterator order; the default 'CLR' visits a node
                         before its children.

        Returns a list of clusters (each a list of leaf nodes), or an
        empty list when the agglomerator was built without a tree.
        """

        if self.tree is None:
            return []

        if root == "root":
            root = self.root

        tree = self.tree
        it = tree.iterator(mode)
        it.setRoot(root)

        clusters = []
        node = it.next()
        while node is not None:
            key = node.key()
            leafList = self.__getLeaves(key)

            #
            # Once a node's subtree is small enough (or the node is a
            # leaf), take its leaves as one cluster and call reverse() so
            # the iteration halts at this node and backs up the tree
            # instead of descending into it.
            #
            if len(leafList) <= sizeThreshold or tree.isLeaf(key):
                clusters.append(leafList)
                it.reverse()

            node = it.next()

        return clusters

    def getKClusters(self, k, mode='CLR'):
        """
        Binary-search the size threshold to get as close to k clusters
        as possible.

        k    -- desired number of clusters.
        mode -- iterator order forwarded to agglomerateWithSizeThreshold.

        Returns the clustering, among the thresholds tried, whose number
        of clusters is closest to k (an exact match stops the search).
        """

        upper = self.leafCount
        lower = 0
        clusters = self.agglomerateWithSizeThreshold(upper, 'root', mode)
        best = clusters

        # 'upper > lower + 1' (rather than '!=') guarantees termination
        # even when upper <= lower, e.g. for a tree with no leaves.
        while len(clusters) != k and upper > lower + 1:
            # Floor division keeps the threshold an integer leaf count.
            threshold = (upper + lower) // 2
            clusters = self.agglomerateWithSizeThreshold(threshold, 'root',
                                                         mode)

            # Remember the closest result; ties favour the newest one.
            if abs(len(clusters) - k) <= abs(len(best) - k):
                best = clusters

            if len(clusters) > k:
                # too many clusters.. make the threshold larger
                lower = threshold
            else:
                # too few clusters.. make the threshold smaller
                upper = threshold

        return best


