#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

import unittest
import os
import Numeric

from compClust.util import DistanceMetrics
import compClust.util

class DistanceMetricsTestCases(unittest.TestCase):
  """Unit tests for compClust.util.DistanceMetrics.

  Results are checked with assertClose() instead of bare ``assert x == y``.
  Bare asserts are removed when Python runs with -O, which would make every
  test pass vacuously.  Also, when either side is a Numeric array, ``==`` is
  element-wise and yields an array rather than a single yes/no answer.
  """

  def setUp(self):
    # Run each test from inside the compClust.util package directory and
    # restore the caller's working directory in tearDown().
    # NOTE(review): presumably DistanceMetrics expects to find files relative
    # to this directory -- confirm.
    self.original_dir = os.getcwd()
    os.chdir(compClust.util.__path__[0])

  def tearDown(self):
    os.chdir(self.original_dir)

  def assertClose(self, actual, expected, tolerance=1e-9):
    """Fail unless *actual* and *expected* hold the same values.

    Both arguments may be scalars, (nested) Python sequences or Numeric
    arrays.  Each is flattened to a 1-D list of floats.  The two lists must
    have the same length and agree element-wise to within *tolerance*.
    Shapes are not compared, so a scalar matches a one-element sequence.
    """
    actual_values = Numeric.ravel(Numeric.array(actual, Numeric.Float)).tolist()
    expected_values = Numeric.ravel(Numeric.array(expected, Numeric.Float)).tolist()
    if len(actual_values) != len(expected_values):
      self.fail("length mismatch: got %r, expected %r" % (actual, expected))
    for got, want in zip(actual_values, expected_values):
      # Written as "not (<=)" so that a NaN result fails instead of passing.
      if not (abs(got - want) <= tolerance):
        self.fail("got %r, expected %r (tolerance %r)"
                  % (actual, expected, tolerance))

  def test_mahalanobis(self):
    # The expected values below equal (a-b)' C^-1 (a-b) for each case.
    a = Numeric.array([1, 1])
    b = Numeric.array([0, 0])
    c = Numeric.array([[1, 0], [0, 1]])
    self.assertClose(DistanceMetrics.MahalanobisDistance(a, b, c), [2.0])

    b = Numeric.array([2, 1])
    c = Numeric.array([[2, 0], [0, 1]])
    self.assertClose(DistanceMetrics.MahalanobisDistance(a, b, c), [0.5])

    # Stacking the two cases above (rows of b, matrices of c) is expected to
    # give one distance per row.
    b = Numeric.array([[0, 0], [2, 1]])
    c = Numeric.array([[[1, 0], [0, 1]], [[2, 0], [0, 1]]])
    self.assertClose(DistanceMetrics.MahalanobisDistance(a, b, c), [2.0, 0.5])

  def test_ranks(self):
    a = Numeric.array([2, 3, 4, 5, 6, 7, 8])
    self.assertClose(DistanceMetrics.ranks(a),
                     [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])

    a = Numeric.array([8, 7, 6, 5, 4])
    self.assertClose(DistanceMetrics.ranks(a), [5.0, 4.0, 3.0, 2.0, 1.0])

    # Tied values are expected to share the average of their ranks.
    a = Numeric.array([1, 2, 3, 2, 4, 5, 1, 6])
    self.assertClose(DistanceMetrics.ranks(a),
                     [1.5, 3.5, 5.0, 3.5, 6.0, 7.0, 1.5, 8.0])

    # For a 2-D input, each row is expected to be ranked independently.
    a = Numeric.array([[1, 2, 3], [1, 1, 1], [1, 2, 2], [2, 4, 3], [4, 3, 2]])
    self.assertClose(DistanceMetrics.ranks(a),
                     [[1., 2.,  3. ],
                      [2., 2.,  2. ],
                      [1., 2.5, 2.5],
                      [1., 3.,  2. ],
                      [3., 2.,  1. ]])

  def test_negative_correlation(self):
    # NOTE(review): a and b are constant vectors, so an expected value of
    # -1.0 presumably relies on CorrelationDistance not mean-centring its
    # inputs -- confirm.
    a = Numeric.array([1.,  1.])
    b = Numeric.array([-1., -1.])
    # NOTE(review): all-zero middle argument to PearsonCorrelation;
    # presumably a centre/mean vector -- confirm.
    c = Numeric.array([0., 0. ])

    self.assertClose(DistanceMetrics.CorrelationDistance(a, b), -1.0)
    self.assertClose(DistanceMetrics.PearsonCorrelation(a, c, b), -1.0)
    self.assertClose(DistanceMetrics.SpearmanCorrelation(a, b), -1.0)

  def test_zero_correlation(self):
    a = Numeric.array([1.,  1.])
    b = Numeric.array([1., -1.])
    c = Numeric.array([0., 0. ])

    self.assertClose(DistanceMetrics.CorrelationDistance(a, b), 0.0)
    self.assertClose(DistanceMetrics.PearsonCorrelation(a, c, b), 0.0)
    self.assertClose(DistanceMetrics.SpearmanCorrelation(a, b), 0.0)

  def test_identity(self):
    a = Numeric.array([1., 1., 1.])
    b = Numeric.array([1., 1., 1.])
    c = Numeric.array([0., 0., 0.])

    # The true distances are 0 for identical vectors.
    self.assertClose(DistanceMetrics.EuclideanDistance(a, b), 0.0)
    self.assertClose(DistanceMetrics.ManhattanDistance(a, b), 0.0)
    self.assertClose(DistanceMetrics.MaximumDistance(a, b), 0.0)
    # Despite its name, CorrelationDistance is expected to return 1.0 here
    # (a similarity), as are the correlation functions.
    self.assertClose(DistanceMetrics.CorrelationDistance(a, b), 1.0)
    self.assertClose(DistanceMetrics.PearsonCorrelation(a, c, b), 1.0)
    self.assertClose(DistanceMetrics.SpearmanCorrelation(a, b), 1.0)

  def test_distance(self):
    # The generic distance(a, b, name) dispatcher must agree with calling
    # the corresponding named function directly.
    a = Numeric.array([1, 2, 4, 5, 3])
    b = Numeric.array([5, -2, 3, 1, -8])

    self.assertClose(DistanceMetrics.EuclideanDistance(a, b),
                     DistanceMetrics.distance(a, b, 'euclidean'))
    self.assertClose(DistanceMetrics.ManhattanDistance(a, b),
                     DistanceMetrics.distance(a, b, 'manhattan'))
    self.assertClose(DistanceMetrics.MaximumDistance(a, b),
                     DistanceMetrics.distance(a, b, 'maximum'))
    self.assertClose(DistanceMetrics.CorrelationDistance(a, b),
                     DistanceMetrics.distance(a, b, 'correlation'))
    self.assertClose(DistanceMetrics.SpearmanCorrelation(a, b),
                     DistanceMetrics.distance(a, b, 'spearman'))
    
def suite(**kw):
  """Return a TestSuite with every test_* method of DistanceMetricsTestCases.

  Keyword arguments are accepted and ignored.
  NOTE(review): presumably this lets a driver call each module's suite()
  uniformly -- confirm.
  """
  # The original also had an unreachable ``return suite`` after this line.
  # Had it run, it would have returned this function rather than a suite.
  return unittest.makeSuite(DistanceMetricsTestCases)

if __name__ == "__main__":
  # When run as a script, build the tests from the module-level suite()
  # above rather than from unittest's default test discovery.
  unittest.main(defaultTest="suite")
