# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Tests for ``astropy.utils.diff``.

Some of this functionality might already be tested indirectly in
``astropy.io.fits.tests``.
"""
import io

import numpy as np
import pytest

from astropy.utils.diff import diff_values, report_diff_values, where_not_allclose
from astropy.table import Table


@pytest.mark.parametrize('a', [np.nan, np.inf, 1.11, 1, 'a'])
def test_diff_values_false(a):
    """
    A value compared against itself must not be reported as different.

    The parametrized inputs cover NaN, infinity, a float, an int and a
    string.
    """
    differs = diff_values(a, a)
    assert not differs


@pytest.mark.parametrize(
    ('a', 'b'),
    [
        (np.inf, np.nan),
        (1.11, 1.1),
        (1, 2),
        (1, 'a'),
        ('a', 'b'),
    ])
def test_diff_values_true(a, b):
    """
    Each parametrized pair of unequal values must be reported as different.

    The pairs mix non-finite floats, close-but-unequal floats, ints, strings
    and values of different types.
    """
    differs = diff_values(a, b)
    assert differs


def test_float_comparison():
    """
    Regression test for https://github.com/spacetelescope/PyFITS/issues/21

    Two ``float32`` values built from 0.029751372 and 0.029751368 must be
    reported as different, and the report must contain a line for each.
    """
    first = np.float32(0.029751372)
    second = np.float32(0.029751368)

    stream = io.StringIO()
    assert not report_diff_values(first, second, fileobj=stream)
    report = stream.getvalue()

    # The exact wording of the report is deliberately not pinned; it is
    # enough that both values show up with their respective markers.
    for marker in ('a>', 'b>'):
        assert marker in report


def test_diff_types():
    """
    Regression test for https://github.com/astropy/astropy/issues/4122

    Comparing the float ``1.0`` with the string ``'1.0'`` must be reported
    as a difference, with each value labelled by its type name.
    """
    expected = ("  (float) a> 1.0\n"
                "    (str) b> '1.0'\n"
                "           ? +   +\n")

    stream = io.StringIO()
    assert not report_diff_values(1.0, '1.0', fileobj=stream)
    assert stream.getvalue() == expected

def test_diff_numeric_scalar_types():
    """
    Test comparison of different numeric scalar types.

    The float ``1.0`` and the int ``1`` are expected to be reported as not
    identical, with each value labelled by its type name.
    """
    stream = io.StringIO()
    identical = report_diff_values(1.0, 1, fileobj=stream)
    assert not identical

    expected = ('  (float) a> 1.0\n'
                '    (int) b> 1\n')
    assert stream.getvalue() == expected

def test_array_comparison():
    """
    Test diff-ing two arrays.

    Every element of the two 3x3 arrays differs; the expected report lists
    the first three differing indices and summarizes the remaining six.
    """
    first = np.arange(9).reshape(3, 3)
    second = first + 1

    expected = ('  at [0, 0]:\n'
                '    a> 0\n'
                '    b> 1\n'
                '  at [0, 1]:\n'
                '    a> 1\n'
                '    b> 2\n'
                '  at [0, 2]:\n'
                '    a> 2\n'
                '    b> 3\n'
                '  ...and at 6 more indices.\n')

    stream = io.StringIO()
    assert not report_diff_values(first, second, fileobj=stream)
    assert stream.getvalue() == expected


def test_diff_shaped_array_comparison():
    """
    Test diff-ing two differently shaped arrays.

    A ``(1, 2, 3)`` array is compared with its own first element, which has
    shape ``(2, 3)``; the expected report describes the shape mismatch.
    """
    cube = np.empty((1, 2, 3))
    plane = cube[0]

    expected = ('  Different array shapes:\n'
                '    a> (1, 2, 3)\n'
                '     ?  ---\n'
                '    b> (2, 3)\n')

    stream = io.StringIO()
    assert not report_diff_values(cube, plane, fileobj=stream)
    assert stream.getvalue() == expected


def test_tablediff():
    """
    Test diff-ing two simple Table objects.

    The second table changes one ``mag_v`` value, changes one ``obs_date``
    and appends a row; the expected report marks exactly those differences.
    Comparing a table with itself must be reported as identical.
    """
    original = Table.read("""name    obs_date    mag_b  mag_v
M31     2012-01-02  17.0   16.0
M82     2012-10-29  16.2   15.2
M101    2012-10-31  15.1   15.5""", format='ascii')
    modified = Table.read("""name    obs_date    mag_b  mag_v
M31     2012-01-02  17.0   16.5
M82     2012-10-29  16.2   15.2
M101    2012-10-30  15.1   15.5
NEW     2018-05-08   nan    9.0""", format='ascii')

    expected = ('     name  obs_date  mag_b mag_v\n'
                '     ---- ---------- ----- -----\n'
                '  a>  M31 2012-01-02  17.0  16.0\n'
                '   ?                           ^\n'
                '  b>  M31 2012-01-02  17.0  16.5\n'
                '   ?                           ^\n'
                '      M82 2012-10-29  16.2  15.2\n'
                '  a> M101 2012-10-31  15.1  15.5\n'
                '   ?               ^\n'
                '  b> M101 2012-10-30  15.1  15.5\n'
                '   ?               ^\n'
                '  b>  NEW 2018-05-08   nan   9.0\n')

    stream = io.StringIO()
    assert not report_diff_values(original, modified, fileobj=stream)
    assert stream.getvalue() == expected

    # A table compared with itself must be reported as identical.
    same = report_diff_values(original, original, fileobj=stream)
    assert same


@pytest.mark.parametrize('kwargs', [{}, {'atol': 0, 'rtol': 0}])
def test_where_not_allclose(kwargs):
    """
    Test ``where_not_allclose`` on 1-D arrays containing NaN and infinity.

    The two arrays agree at index 0, hold swapped NaN/inf values at indices
    1 and 2, and differ numerically (4.5 vs 4.6) at index 3.  The test
    expects only index 3 to be reported, both with the default tolerances
    and with ``atol`` and ``rtol`` set to zero, and expects no indices at
    all when an array is compared with itself.
    """
    a = np.array([1, np.nan, np.inf, 4.5])
    b = np.array([1, np.inf, np.nan, 4.6])

    mismatched = where_not_allclose(a, b, **kwargs)
    # Compare the index array explicitly instead of relying on tuple
    # equality: that would call bool() on a NumPy array, which is only
    # well-defined for a single element and gives an opaque error otherwise.
    assert len(mismatched) == 1
    np.testing.assert_array_equal(mismatched[0], [3])

    assert len(where_not_allclose(a, a, **kwargs)[0]) == 0
