# -*- coding: utf-8 -*-
"""
Unit tests for nltk.classify. See also: nltk/test/classify.doctest
"""
from __future__ import absolute_import
from nose import SkipTest
from nltk import classify

# Training data: (featureset, label) pairs over three binary features
# 'a', 'b', 'c', with labels drawn from {'x', 'y'}.  Some featuresets are
# repeated with conflicting labels on purpose.
TRAIN = [
    ({'a': 1, 'b': 1, 'c': 1}, 'y'),
    ({'a': 1, 'b': 1, 'c': 1}, 'x'),
    ({'a': 1, 'b': 1, 'c': 0}, 'y'),
    ({'a': 0, 'b': 1, 'c': 1}, 'x'),
    ({'a': 0, 'b': 1, 'c': 1}, 'y'),
    ({'a': 0, 'b': 0, 'c': 1}, 'y'),
    ({'a': 0, 'b': 1, 'c': 0}, 'x'),
    ({'a': 0, 'b': 0, 'c': 0}, 'x'),
    ({'a': 0, 'b': 1, 'c': 1}, 'y'),
]

# Featuresets to classify.  Each entry is paired, by position, with the
# expected (P(x), P(y)) values in RESULTS.
TEST = [
    {'a': 1, 'b': 0, 'c': 1},  # not present in TRAIN
    {'a': 1, 'b': 0, 'c': 0},  # not present in TRAIN
    {'a': 0, 'b': 1, 'c': 1},  # in TRAIN 3 times, labels x, y, y
    {'a': 0, 'b': 1, 'c': 0},  # in TRAIN once, label x
]

# Expected (P('x'), P('y')) for each TEST featureset, in the same order.
RESULTS = [
    (0.16, 0.84),
    (0.46, 0.54),
    (0.41, 0.59),
    (0.76, 0.24),
]


def assert_classifier_correct(algorithm, tolerance=1e-2):
    """
    Train a MaxentClassifier on TRAIN and compare its predicted
    probabilities on TEST against RESULTS.

    :param algorithm: name of the training algorithm passed to
        ``classify.MaxentClassifier.train`` (e.g. 'MEGAM', 'TADM').
    :param tolerance: maximum absolute difference allowed between each
        predicted probability and its expected value (default 1e-2).
    :raise SkipTest: if training raises LookupError or AttributeError.
        NOTE(review): presumably this means the algorithm is unavailable
        in the current environment -- confirm.
    :raise AssertionError: if the fixtures have mismatched lengths or any
        probability falls outside ``tolerance``.
    """
    try:
        classifier = classify.MaxentClassifier.train(
            TRAIN, algorithm, trace=0, max_iter=1000
        )
    except (LookupError, AttributeError) as e:
        # Skip rather than fail when training cannot be performed here.
        raise SkipTest(str(e))

    # zip() stops at the shorter sequence, so mismatched fixtures would
    # silently drop checks; fail loudly instead.
    assert len(RESULTS) == len(TEST), (len(RESULTS), len(TEST))

    for (px, py), featureset in zip(RESULTS, TEST):
        pdist = classifier.prob_classify(featureset)
        prob_x = pdist.prob('x')
        prob_y = pdist.prob('y')
        assert abs(prob_x - px) < tolerance, (featureset, prob_x, px)
        assert abs(prob_y - py) < tolerance, (featureset, prob_y, py)


def test_megam():
    """Run the shared MaxentClassifier check with the 'MEGAM' algorithm."""
    algorithm_name = 'MEGAM'
    assert_classifier_correct(algorithm_name)


def test_tadm():
    """Run the shared MaxentClassifier check with the 'TADM' algorithm."""
    algorithm_name = 'TADM'
    assert_classifier_correct(algorithm_name)
