from __future__ import print_function, absolute_import

import numpy as np
import math

import numba.unittest_support as unittest
from numba import roc
from numba import utils


class TestMath(unittest.TestCase):
    """Compare ``math`` functions compiled with ``roc.jit`` against NumPy.

    Every test hands a ``math`` function and a reference callable to one
    of the two generic helpers, which apply the function element-wise
    through a ``roc.jit`` compiled function and compare the output with
    the reference using ``np.testing.assert_allclose``.
    """

    def _get_tol(self, math_fn, ty):
        """Return the relative tolerance for ``math_fn`` on inputs of ``ty``.

        A handful of (function, type) pairs get an explicit tolerance;
        everything else uses 1e-15 for ``np.float64`` and 1e-6 otherwise.
        """
        relaxed = {
            (math.gamma, np.float64): 1e-14,
            (math.lgamma, np.float64): 1e-13,
            (math.asin, np.float64): 1e-9,
            (math.acos, np.float64): 4e-9,
            (math.sqrt, np.float64): 2e-8,
        }
        key = (math_fn, ty)
        if key in relaxed:
            return relaxed[key]
        if ty == np.float64:
            return 1e-15
        return 1e-6

    def _generic_test_unary(self, math_fn, npy_fn,
                            cases=None,
                            span=(-1., 1.), count=128,
                            types=(np.float32, np.float64)):
        """Check a one-argument ``math_fn`` against ``npy_fn``.

        For each dtype in ``types`` the input is ``cases`` when given,
        otherwise ``count`` evenly spaced values from ``span[0]`` to
        ``span[1]``.  The jitted result must match ``npy_fn`` applied to
        the same input within the tolerance from ``_get_tol``.
        """

        @roc.jit
        def kernel(out, inp):
            gid = roc.get_global_id(0)
            if gid < out.size:
                out[gid] = math_fn(inp[gid])

        for dtype in types:
            if cases is not None:
                src = np.array(cases, dtype=dtype)
            else:
                start, stop = span[0], span[1]
                src = np.linspace(start, stop, count).astype(dtype)

            dst = np.zeros_like(src)
            kernel[src.size, 1](dst, src)

            expected = npy_fn(src)
            tol = self._get_tol(math_fn, dtype)
            # the message identifies which function/dtype pair failed
            label = '{0} ({1})'.format(math_fn.__name__, dtype.__name__)
            np.testing.assert_allclose(dst, expected, rtol=tol,
                                       err_msg=label)

    def _generic_test_binary(self, math_fn, npy_fn,
                             cases=None,
                             span=(-1., 1., 1., -1.), count=128,
                             types=(np.float32, np.float64)):
        """Check a two-argument ``math_fn`` against ``npy_fn``.

        For each dtype in ``types`` the operands are ``cases[0]`` and
        ``cases[1]`` when ``cases`` is given.  Otherwise the first
        operand is ``count`` evenly spaced values from ``span[0]`` to
        ``span[1]`` and the second from ``span[2]`` to ``span[3]``.
        """

        @roc.jit
        def kernel(out, lhs, rhs):
            gid = roc.get_global_id(0)
            if gid < out.size:
                out[gid] = math_fn(lhs[gid], rhs[gid])

        for dtype in types:
            if cases is not None:
                src1 = np.array(cases[0], dtype=dtype)
                src2 = np.array(cases[1], dtype=dtype)
            else:
                src1 = np.linspace(span[0], span[1], count).astype(dtype)
                src2 = np.linspace(span[2], span[3], count).astype(dtype)

            dst = np.zeros_like(src1)
            kernel[dst.size, 1](dst, src1, src2)

            expected = npy_fn(src1, src2)
            tol = self._get_tol(math_fn, dtype)
            # the message identifies which function/dtype pair failed
            label = '{0} ({1})'.format(math_fn.__name__, dtype.__name__)
            np.testing.assert_allclose(dst, expected, rtol=tol,
                                       err_msg=label)

    def test_trig(self):
        # math and numpy share the names of these functions
        for name in ('sin', 'cos', 'tan'):
            self._generic_test_unary(getattr(math, name), getattr(np, name),
                                     span=(-np.pi, np.pi))

    def test_trig_inv(self):
        # numpy spells the inverse functions with an "arc" prefix
        self._generic_test_unary(math.asin, np.arcsin)
        self._generic_test_unary(math.acos, np.arccos)
        self._generic_test_unary(math.atan, np.arctan)

    def test_trigh(self):
        for name in ('sinh', 'cosh', 'tanh'):
            self._generic_test_unary(getattr(math, name), getattr(np, name),
                                     span=(-4.0, 4.0))

    def test_trigh_inv(self):
        # each inverse function is sampled over its own span
        self._generic_test_unary(math.asinh, np.arcsinh, span=(-4, 4))
        self._generic_test_unary(math.acosh, np.arccosh, span=(1, 9))
        self._generic_test_unary(math.atanh, np.arctanh, span=(-0.9, 0.9))

    def test_classify(self):
        # non-finite values first, then a few ordinary numbers
        special = tuple(float(text)
                        for text in ('nan', 'inf', '-inf', '-nan'))
        cases = special + (0, 3, -2)
        for name in ('isnan', 'isinf'):
            self._generic_test_unary(getattr(math, name), getattr(np, name),
                                     cases=cases)

    def test_floor_ceil(self):
        for fn in (math.ceil, math.floor):
            reference = getattr(np, fn.__name__)
            # cases with varied decimals
            self._generic_test_unary(fn, reference,
                                     span=(-1013.14, 843.21))
            # cases that include "exact" integers
            self._generic_test_unary(fn, reference,
                                     span=(-16, 16), count=129)

    def test_fabs(self):
        self._generic_test_unary(math.fabs, np.fabs, span=(-63.3, 63.3))

    def test_unary_exp(self):
        self._generic_test_unary(math.exp, np.exp, span=(-30, 30))

    def test_unary_expm1(self):
        self._generic_test_unary(math.expm1, np.expm1, span=(-30, 30))

    def test_sqrt(self):
        self._generic_test_unary(math.sqrt, np.sqrt, span=(0, 1000))

    def test_log(self):
        for name in ('log', 'log10', 'log1p'):
            self._generic_test_unary(getattr(math, name), getattr(np, name),
                                     span=(0.1, 2500))

    def test_binaries(self):
        for name in ('copysign', 'fmod'):
            self._generic_test_binary(getattr(math, name), getattr(np, name))

    def test_pow(self):
        self._generic_test_binary(math.pow, np.power)

    def test_atan2(self):
        self._generic_test_binary(math.atan2, np.arctan2)

    def test_erf(self):
        # the reference is the host math function vectorized over arrays
        for fn in (math.erf, math.erfc):
            self._generic_test_unary(fn, np.vectorize(fn))

    def test_gamma(self):
        # the reference is the host math function vectorized over arrays
        for fn in (math.gamma, math.lgamma):
            self._generic_test_unary(fn, np.vectorize(fn), span=(1e-4, 4.0))


# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
