########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
This module provides interactive 3-D visualizations of clustering model
views, drawn with the VPython `visual` package.

"""

# standard modules

# contrib modules
import visual
import MLab


# local modules 

from compClust.mlx import Dataset
from compClust.mlx import Labeling
from compClust.mlx import Model

from compClust.util import transformations
from compClust.util import DistanceMetrics
from compClust.util import WrapperUtil
from compClust.util import Usage

def plotPointsFromDataset(display, dataset, color):

    """
    plotPointsFromDataset(display, dataset, color)

    Draw each row of `dataset` into the VPython `display` as a small
    sphere (radius .1) of the given `color`.  The first three columns of
    a row are used as its x, y and z coordinates.
    """

    # Make `display` the target for the spheres created below.
    display.select()
    for row in dataset.getData():
        center = visual.vector(row[0], row[1], row[2])
        visual.sphere(pos=center, radius=.1, color=color)


def getMouseClicks(display):

    """
    getMouseClicks(display)

    Traps the mouse clicks in a given display: waits for each click in
    `display` and prints the object that was picked by it.  Loops
    forever and never returns.

    NOTE:  NOT FULLY IMPLEMENTED YET - clicks are only printed, not
    returned or otherwise acted upon.
    """

    while 1:
        # getclick() waits for the next click event.  The previous version
        # printed display.mouse.pick on every loop iteration, which spun the
        # CPU and flooded stdout whether or not anything was clicked.
        click = display.mouse.getclick()
        print(click.pick)
        

def distanceFromMeansModel3dView(dataset, labeling):

    """
    distanceFromMeansModel3dView(dataset, labeling)

    ##FIXME: this should take in just a model and model should contain
    ##both the labeling and the dataset... but for now that isn't the
    ##case

    Using the visual package for display, this projects the dataset
    into its top 3 principal components and draws every data point as a
    small sphere, giving each cluster (taken in sorted key order) its
    own hue.  At each cluster mean it draws two rings, one with its axis
    along y and one with its axis along x, whose radius is the mean of
    the cluster's MLab.std values.  It also labels the mean with the
    cluster key.  The rings are skipped when that standard deviation
    cannot be computed or is zero.
    """

    displayWindow = visual.display()
    pcaDataset = transformations.PCA(dataset, dims=3)
    clusters = pcaDataset.subsets(labeling)

    # Sorting the keys makes the hue assigned to each cluster deterministic.
    for count, clusterKey in enumerate(sorted(clusters.keys())):
        cluster = clusters[clusterKey]
        data = cluster.getData()

        hue = (1.0/labeling.get_k())*count
        color = visual.color.hsv_to_rgb((hue, .8, .8))
        plotPointsFromDataset(displayWindow, cluster, color)

        ## FIXME: Perhaps this should use a Model class, but for now
        ## FIXME: it doesn't give me any std dev - thus it is of little
        ## FIXME: use
        meanVector = visual.vector(MLab.mean(data))
        try:
            std = MLab.mean(MLab.std(data))
        except:
            # Bare except kept on purpose: Numeric-era libraries may raise
            # string exceptions, which `except Exception` would not catch.
            # Presumably this fires for degenerate (e.g. single-point)
            # clusters -- TODO confirm.
            std = None

        if std:
            visual.ring(pos=meanVector, axis=(0,1,0), radius=std, thickness=.05, color=color)
            visual.ring(pos=meanVector, axis=(1,0,0), radius=std, thickness=.05, color=color)

        # The label's `space` is presumably numeric.  Previously std=None was
        # passed straight through when the std computation failed.  Fall back
        # to 0, which is the value MoGModelView uses.
        visual.label(pos=meanVector, text=clusterKey, xoffset=10, yoffset=10,
                     space=std or 0, height=10, border=6)


def MoGModelView(dataset, labeling):

    """
    MoGModelView(dataset, labeling)

    ##FIXME: this should take in just a model and model should contain
    ##both the labeling and the dataset... but for now that isn't the
    ##case

    Using the visual package, displays the dataset in its top 3 PCA
    dimensions and draws each data point as a small sphere colored by its
    cluster.  The hue is set by cluster membership.  The value
    (brightness) is set by the point's Mahalanobis distance to its
    cluster mean, scaled so that the farthest point of each cluster gets
    value 0 and a point at the mean would get value 1.  Each cluster mean
    is labeled with the cluster's label.
    """

    displayWindow = visual.display()
    pcaDataset = transformations.PCA(dataset, dims=3)
    clusters = pcaDataset.subsets(labeling)

    means = Model.compute_model_means(pcaDataset, labeling)
    # The mixture weights are not used by this view.
    covariances, weights = Model.compute_model_covariances_weights(pcaDataset, labeling, means)

    saturation = .8
    for c in range(len(means)):
        hue = (1.0/labeling.get_k())*c
        clusterLabel = labeling.classes_to_label_map[c]

        # Pair every point of the cluster with its distance to the mean.
        points = []
        for vector in clusters[clusterLabel].getData():
            try:
                distance = DistanceMetrics.MahalanobisDistance(vector, means[c], covariances[c])
            except:
                # Bare except kept on purpose: Numeric-era libraries may raise
                # string exceptions, which `except Exception` would not catch.
                # Fall back to a fixed distance so the point is still drawn.
                print("distance calculation failed")
                distance = 1
            points.append((distance, vector))

        # An empty cluster (max of an empty sequence) or one whose points all
        # sit on the mean (division by zero) used to crash here.  Such points
        # are now drawn at full brightness instead.
        maxDistance = 0
        if points:
            maxDistance = MLab.max([d for d, v in points])

        for distance, datavector in points:
            if maxDistance:
                value = 1 - (distance / float(maxDistance))
            else:
                value = 1.0
            position = visual.vector(datavector[0], datavector[1], datavector[2])
            color = visual.color.hsv_to_rgb((hue, saturation, value))
            visual.sphere(pos=position, radius=.1, color=color)

        position = visual.vector(means[c][0], means[c][1], means[c][2])
        visual.label(pos=position, text=clusterLabel,
                     xoffset=10, yoffset=10, space=0, height=10, border=6)




def main(opts, argv):

    """
    modelViews -v <view> -d <dataset> -l <labeling>

    modelViews will produce an interactive 3d plot of the given dataset
    projected into its top 3 principal components.  Depending on the
    view selected, various coloring schemes will be imposed.

    possible views:

    mog:  Mixture Of Gaussians.  This view models each cluster as a
    gaussian with a full covariance matrix and displays each cluster
    in a separate color, with the intensity depending on the
    Mahalanobis distance of the data point to the cluster center.

    distance: distance from means model.  This view colors each
    cluster a different color and then draws 1 std rings around
    the cluster centers.

    options:

    -v or --view
    -d or --dataset
    -l or --labeling
    -h
    --help

    """

    # Maps each accepted --view value to the function that renders it.
    views = {'mog': MoGModelView,
             'distance': distanceFromMeansModel3dView}

    if opts.has_key('--help'):
        Usage.showHelp(main, exit=1)
    if opts.has_key('-h'):
        Usage.showUsage(main, exit=1)

    datasetName = opts.get('-d', opts.get('--dataset', None))
    labelingName = opts.get('-l', opts.get('--labeling', None))
    view = opts.get('-v', opts.get('--view'))

    if datasetName and labelingName and view in views:
        dataset = Dataset.Dataset(datasetName)
        labeling = Labeling.Labeling(labelingName)
        views[view](dataset, labeling)
    else:
        # A missing dataset or labeling, or an unknown view, shows the help.
        Usage.showHelp(main, exit=1)
    
    

if __name__ == "__main__":

    # Command-line spec: -v/--view, -d/--dataset and -l/--labeling take a
    # value; -h and --help are flags.
    shortOptions = "v:d:l:h"
    longOptions = ["view=", "dataset=", "labeling=", "help"]

    opts, args = WrapperUtil.createOptTree(shortOptions, longOptions)
    main(opts, args)



