"""
Tests for @cfunc and friends.
"""

from __future__ import division, print_function, absolute_import

import ctypes
import os
import subprocess
import sys
from collections import namedtuple

import numpy as np

from numba import unittest_support as unittest
from numba import cfunc, carray, farray, types, typing, utils, njit
from numba import cffi_support, numpy_support
from .support import TestCase, tag, captured_stderr
from .test_dispatcher import BaseCacheTest

# Decorator (usable on test classes or methods) that skips cffi-dependent
# tests unless cffi_support.SUPPORTED is true, i.e. the cffi module is usable.
skip_cffi_unsupported = unittest.skipUnless(
    cffi_support.SUPPORTED,
    "CFFI not supported -- please install the cffi module",
)


def add_usecase(a, b):
    """Return the sum of *a* and *b*; compiled with cfunc() by the tests."""
    total = a + b
    return total

def div_usecase(a, b):
    """Return a / b, computed through the intermediate local ``c``.

    The local name ``c`` is significant: TestCFunc.test_locals overrides its
    type via cfunc(..., locals={'c': types.int64}), so keep it as is.
    """
    c = a / b
    return c

def square_usecase(a):
    """Return *a* raised to the second power."""
    squared = a ** 2
    return squared

# Signature strings passed to cfunc() when compiling the usecases above.
add_sig = "float64(float64, float64)"

div_sig = "float64(int64, int64)"

square_sig = "float64(float64)"

def objmode_usecase(a, b):
    """Add *a* and *b* after a call to the builtin ``object()``.

    The ``object()`` call is deliberate: TestCFunc.test_object_mode expects
    compiling this function with cfunc() to fail with
    "Untyped global name 'object'".
    """
    object()
    return a + b

# Test functions for carray() and farray()

# Length of the output buffer filled by the carray()/farray() usecases:
# ndim, shape (2 slots), strides (2 slots), c_contiguous, f_contiguous and a
# weighted sum of the input elements.
CARRAY_USECASE_OUT_LEN = 8

def make_cfarray_usecase(func):
    """
    Build a test kernel around the array constructor *func* (carray or farray).

    The kernel views *in_ptr* as an (m, n) array and *out_ptr* as a 1-d array
    of CARRAY_USECASE_OUT_LEN elements, then records in the latter the input
    array's ndim, shape, strides, C/F contiguity flags and a weighted sum of
    its elements.
    """

    def cfarray_usecase(in_ptr, out_ptr, m, n):
        # Shape given as a tuple...
        src = func(in_ptr, (m, n))
        # ...and as a plain integer.
        dst = func(out_ptr, CARRAY_USECASE_OUT_LEN)
        dst[0] = src.ndim
        dst[1:3] = src.shape
        dst[3:5] = src.strides
        dst[5] = src.flags.c_contiguous
        dst[6] = src.flags.f_contiguous
        # np.ndindex (rather than range over m/n) keeps the indices signed
        # even when the compiled signature uses unsigned sizes.
        total = 0
        for row, col in np.ndindex(m, n):
            total += src[row, col] * (row - col)
        dst[7] = total

    return cfarray_usecase


# Pure-Python kernels; the tests also compile them with cfunc() and compare.
carray_usecase, farray_usecase = map(make_cfarray_usecase, (carray, farray))


def make_cfarray_dtype_usecase(func):
    """
    Variant of make_cfarray_usecase() whose calls to *func* pass an explicit
    np.float32 dtype: as a keyword for the input array and positionally for
    the output array, so both calling forms are exercised.
    """

    def cfarray_usecase(in_ptr, out_ptr, m, n):
        # Tuple shape, keyword dtype.
        src = func(in_ptr, (m, n), dtype=np.float32)
        # Integer shape, positional dtype.
        dst = func(out_ptr, CARRAY_USECASE_OUT_LEN, np.float32)
        dst[0] = src.ndim
        dst[1:3] = src.shape
        dst[3:5] = src.strides
        dst[5] = src.flags.c_contiguous
        dst[6] = src.flags.f_contiguous
        total = 0
        for row, col in np.ndindex(m, n):
            total += src[row, col] * (row - col)
        dst[7] = total

    return cfarray_usecase


carray_dtype_usecase, farray_dtype_usecase = map(make_cfarray_dtype_usecase,
                                                 (carray, farray))

def _carray_usecase_sig(ptr_type):
    # void(ptr_type in_ptr, ptr_type out_ptr, intp m, intp n): the cfunc
    # signature of the carray()/farray() usecases for a given pointer type.
    return types.void(ptr_type, ptr_type, types.intp, types.intp)

carray_float32_usecase_sig = _carray_usecase_sig(types.CPointer(types.float32))

carray_float64_usecase_sig = _carray_usecase_sig(types.CPointer(types.float64))

carray_voidptr_usecase_sig = _carray_usecase_sig(types.voidptr)


class TestCFunc(TestCase):

    @tag('important')
    def test_basic(self):
        """
        Basic usage and properties of a cfunc.
        """
        compiled = cfunc(add_sig)(add_usecase)

        # Metadata is taken from the wrapped Python function.
        self.assertEqual(compiled.__name__, "add_usecase")
        self.assertEqual(compiled.__qualname__, "add_usecase")
        self.assertIs(compiled.__wrapped__, add_usecase)

        # The native symbol is a str mentioning the Python name.
        native = compiled.native_name
        self.assertIsInstance(native, str)
        self.assertIn("add_usecase", native)

        address = compiled.address
        self.assertIsInstance(address, utils.INT_TYPES)

        # The ctypes wrapper must point at that same address and be callable.
        ctypes_fn = compiled.ctypes
        as_voidp = ctypes.cast(ctypes_fn, ctypes.c_void_p)
        self.assertEqual(as_voidp.value, address)

        self.assertPreciseEqual(ctypes_fn(2.0, 3.5), 5.5)

    @tag('important')
    @skip_cffi_unsupported
    def test_cffi(self):
        """
        The cffi function pointer of a cfunc can be called from C code.
        """
        from . import cffi_usecases
        _ffi, lib = cffi_usecases.load_inline_module()

        compiled = cfunc(square_sig)(square_usecase)

        # Presumably the C helper calls the pointer with 1.5 (see expected
        # value below).
        result = lib._numba_test_funcptr(compiled.cffi)
        self.assertPreciseEqual(result, 2.25)  # 1.5 ** 2

    def test_locals(self):
        """
        Typing the intermediate ``c`` as int64 truncates the quotient, so
        the final result is truncated too.
        """
        compiled = cfunc(div_sig, locals={'c': types.int64})(div_usecase)
        self.assertPreciseEqual(compiled.ctypes(8, 3), 2.0)

    @tag('important')
    def test_errors(self):
        """
        Exceptions raised inside a cfunc are printed to stderr and ignored.
        """
        compiled = cfunc(div_sig)(div_usecase)

        # A successful call writes nothing to stderr.
        with captured_stderr() as stream:
            self.assertPreciseEqual(compiled.ctypes(5, 2), 2.5)
        self.assertEqual(stream.getvalue(), "")

        with captured_stderr() as stream:
            result = compiled.ctypes(5, 0)
            # This is just a side effect of Numba zero-initializing
            # stack variables, and could change in the future.
            self.assertPreciseEqual(result, 0.0)
        output = stream.getvalue()
        self.assertIn("ZeroDivisionError: (\'division by zero\',)", output)
        # The "ignored exception" wording differs between Python 2 and 3.
        marker = ("Exception ignored" if sys.version_info >= (3,)
                  else " ignored")
        self.assertIn(marker, output)

    def test_llvm_ir(self):
        """
        The LLVM IR mentions the native symbol and a double addition.
        """
        compiled = cfunc(add_sig)(add_usecase)
        llvm_ir = compiled.inspect_llvm()
        for needle in (compiled.native_name, "fadd double"):
            self.assertIn(needle, llvm_ir)

    def test_object_mode(self):
        """
        Object mode is currently unsupported.
        """
        # Explicitly forcing object mode is rejected outright...
        with self.assertRaises(NotImplementedError):
            cfunc(add_sig, forceobj=True)(add_usecase)
        # ...and code needing object mode fails nopython typing.
        with self.assertTypingError() as raises:
            cfunc(add_sig)(objmode_usecase)
        message = str(raises.exception)
        self.assertIn("Untyped global name 'object'", message)


class TestCFuncCache(BaseCacheTest):
    """
    Caching tests for cfuncs defined in cfunc_cache_usecases.py.
    """

    here = os.path.dirname(__file__)
    usecases_file = os.path.join(here, "cfunc_cache_usecases.py")
    modname = "cfunc_caching_test_fodder"

    def run_in_separate_process(self):
        """
        Import the cached module in a fresh interpreter and check that the
        cached cfuncs are loaded from the on-disk cache there too.

        Raises AssertionError, including the child's output, if the child
        process exits with a non-zero code.
        """
        # Cached functions can be run from a distinct process.
        # Explicit checks (not ``assert``) so they still run if the child
        # interpreter has assertions disabled, e.g. via PYTHONOPTIMIZE.
        code = """if 1:
            import sys

            sys.path.insert(0, %(tempdir)r)
            mod = __import__(%(modname)r)
            mod.self_test()

            for name in ("add_usecase", "outer", "div_usecase"):
                hits = getattr(mod, name).cache_hits
                if hits != 1:
                    raise AssertionError((name, hits))
            """ % dict(tempdir=self.tempdir, modname=self.modname)

        popen = subprocess.Popen([sys.executable, "-c", code],
                                 stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = popen.communicate()
        if popen.returncode != 0:
            # Decode leniently: the default codec (ASCII on Python 2) could
            # raise UnicodeDecodeError and hide the real failure.
            raise AssertionError(
                "process failed with code %s: stderr follows\n%s\n"
                "stdout follows\n%s\n"
                % (popen.returncode,
                   err.decode('utf-8', 'replace'),
                   out.decode('utf-8', 'replace')))

    def check_module(self, mod):
        mod.self_test()

    @tag('important')
    def test_caching(self):
        self.check_pycache(0)
        mod = self.import_module()
        self.check_pycache(6)  # 3 index, 3 data

        self.assertEqual(mod.add_usecase.cache_hits, 0)
        self.assertEqual(mod.outer.cache_hits, 0)
        self.assertEqual(mod.add_nocache_usecase.cache_hits, 0)
        self.assertEqual(mod.div_usecase.cache_hits, 0)
        self.check_module(mod)

        # Reload module to hit the cache
        mod = self.import_module()
        self.check_pycache(6)  # 3 index, 3 data

        self.assertEqual(mod.add_usecase.cache_hits, 1)
        self.assertEqual(mod.outer.cache_hits, 1)
        self.assertEqual(mod.add_nocache_usecase.cache_hits, 0)
        self.assertEqual(mod.div_usecase.cache_hits, 1)
        self.check_module(mod)

        self.run_in_separate_process()


class TestCArray(TestCase):
    """
    Tests for carray() and farray().
    """

    def run_carray_usecase(self, pointer_factory, func):
        """
        Call the usecase *func* on a fixed 2x3 float32 input, passing both
        buffers through *pointer_factory*, and return the filled output.
        """
        a = np.arange(10, 16).reshape((2, 3)).astype(np.float32)
        out = np.empty(CARRAY_USECASE_OUT_LEN, dtype=np.float32)
        func(pointer_factory(a), pointer_factory(out), *a.shape)
        return out

    def check_carray_usecase(self, pointer_factory, pyfunc, cfunc):
        """
        Check that the compiled callable *cfunc* produces exactly the same
        output as the pure-Python *pyfunc*.  (The *cfunc* parameter shadows
        the module-level cfunc() decorator within this method only.)
        """
        expected = self.run_carray_usecase(pointer_factory, pyfunc)
        got = self.run_carray_usecase(pointer_factory, cfunc)
        self.assertPreciseEqual(expected, got)

    def make_voidptr(self, arr):
        """Return an untyped ctypes pointer (c_void_p) to *arr*'s data."""
        return arr.ctypes.data_as(ctypes.c_void_p)

    def make_float32_pointer(self, arr):
        """Return a ctypes float* pointer to *arr*'s data."""
        return arr.ctypes.data_as(ctypes.POINTER(ctypes.c_float))

    def make_float64_pointer(self, arr):
        """Return a ctypes double* pointer to *arr*'s data."""
        return arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

    def check_carray_farray(self, func, order):
        """
        Exercise the pure-Python array constructor *func* on a float32 array
        laid out in *order* ('C' or 'F'), including its error cases.
        """
        def eq(got, expected):
            # Same layout, dtype, shape, etc.
            self.assertPreciseEqual(got, expected)
            # Same underlying data
            self.assertEqual(got.ctypes.data, expected.ctypes.data)

        base = np.arange(6).reshape((2, 3)).astype(np.float32).copy(order=order)

        # With typed pointer and implied dtype
        a = func(self.make_float32_pointer(base), base.shape)
        eq(a, base)
        # Integer shape
        a = func(self.make_float32_pointer(base), base.size)
        eq(a, base.ravel('K'))

        # With typed pointer and explicit dtype
        a = func(self.make_float32_pointer(base), base.shape, base.dtype)
        eq(a, base)
        a = func(self.make_float32_pointer(base), base.shape, np.float32)
        eq(a, base)

        # With voidptr and explicit dtype
        a = func(self.make_voidptr(base), base.shape, base.dtype)
        eq(a, base)
        a = func(self.make_voidptr(base), base.shape, np.int32)
        eq(a, base.view(np.int32))

        # voidptr without dtype
        with self.assertRaises(TypeError):
            func(self.make_voidptr(base), base.shape)
        # Invalid pointer type
        with self.assertRaises(TypeError):
            func(base.ctypes.data, base.shape)
        # Mismatching dtype
        with self.assertRaises(TypeError) as raises:
            func(self.make_float32_pointer(base), base.shape, np.int32)
        self.assertIn("mismatching dtype 'int32' for pointer",
                      str(raises.exception))

    @tag('important')
    def test_carray(self):
        """
        Test pure Python carray().
        """
        self.check_carray_farray(carray, 'C')

    def test_farray(self):
        """
        Test pure Python farray().
        """
        self.check_carray_farray(farray, 'F')

    def make_carray_sigs(self, formal_sig):
        """
        Generate a bunch of concrete signatures by varying the width
        and signedness of size arguments (see issue #1923).
        """
        for actual_size in (types.intp, types.int32, types.intc,
                            types.uintp, types.uint32, types.uintc):
            args = tuple(actual_size if a == types.intp else a
                         for a in formal_sig.args)
            yield formal_sig.return_type(*args)

    def check_numba_carray_farray(self, usecase, dtype_usecase):
        """
        Compare cfunc-compiled versions of *usecase* / *dtype_usecase*
        against the pure-Python versions for several signatures.
        """
        # With typed pointers and implicit dtype
        pyfunc = usecase
        for sig in self.make_carray_sigs(carray_float32_usecase_sig):
            f = cfunc(sig)(pyfunc)
            self.check_carray_usecase(self.make_float32_pointer, pyfunc, f.ctypes)

        # With typed pointers and explicit (matching) dtype
        pyfunc = dtype_usecase
        for sig in self.make_carray_sigs(carray_float32_usecase_sig):
            f = cfunc(sig)(pyfunc)
            self.check_carray_usecase(self.make_float32_pointer, pyfunc, f.ctypes)
        # With typed pointers and mismatching dtype
        with self.assertTypingError() as raises:
            cfunc(carray_float64_usecase_sig)(pyfunc)
        self.assertIn("mismatching dtype 'float32' for pointer type 'float64*'",
                      str(raises.exception))

        # With voidptr.  Feed both genuine void pointers (so the void* path
        # of the pure-Python reference is exercised too) and typed float*
        # pointers passed where void* is expected.
        pyfunc = dtype_usecase
        pointer_factories = (self.make_voidptr, self.make_float32_pointer)
        for sig in self.make_carray_sigs(carray_voidptr_usecase_sig):
            f = cfunc(sig)(pyfunc)
            for pointer_factory in pointer_factories:
                self.check_carray_usecase(pointer_factory, pyfunc, f.ctypes)

    @tag('important')
    def test_numba_carray(self):
        """
        Test Numba-compiled carray() against pure Python carray()
        """
        self.check_numba_carray_farray(carray_usecase, carray_dtype_usecase)

    def test_numba_farray(self):
        """
        Test Numba-compiled farray() against pure Python farray()
        """
        self.check_numba_carray_farray(farray_usecase, farray_dtype_usecase)


@skip_cffi_unsupported
class TestCffiStruct(TestCase):
    # C declarations shared by the tests: a struct with scalar and
    # fixed-size array members, a struct with a bit-field member, and a
    # function-pointer typedef taking a pointer to the first struct.
    c_source = """
typedef struct _big_struct {
    int    i1;
    float  f2;
    double d3;
    float  af4[9];
} big_struct;

typedef struct _error {
    int bits:4;
} error;

typedef double (*myfunc)(big_struct*, size_t);
"""

    def get_ffi(self, src=c_source):
        """
        Return a new cffi FFI object with the C declarations *src* loaded.
        """
        from cffi import FFI

        declared = FFI()
        declared.cdef(src)
        return declared

    def test_type_parsing(self):
        """
        cffi struct and function typedefs map to a Numba Record and Signature.
        """
        ffi = self.get_ffi()

        # Struct typedef -> Record with one entry per C field.
        record = cffi_support.map_type(ffi.typeof('big_struct'),
                                       use_record_dtype=True)
        self.assertIsInstance(record, types.Record)
        self.assertEqual(len(record), 4)
        expected_fields = [
            ('i1', types.int32),
            ('f2', types.float32),
            ('d3', types.float64),
            ('af4', types.NestedArray(dtype=types.float32, shape=(9,))),
        ]
        for field_name, field_type in expected_fields:
            self.assertEqual(record.typeof(field_name), field_type)

        # Function typedef -> Signature over a pointer to that Record.
        signature = cffi_support.map_type(ffi.typeof('myfunc'),
                                          use_record_dtype=True)
        self.assertIsInstance(signature, typing.Signature)
        self.assertEqual(signature.args[0], types.CPointer(record))
        self.assertEqual(signature.args[1], types.uintp)
        self.assertEqual(signature.return_type, types.float64)

    def test_cfunc_callback(self):
        """
        A cfunc compiled with a cffi-derived signature reads an array of
        structs and matches the result computed on a NumPy record view.
        """
        ffi = self.get_ffi()
        nb_big_struct = cffi_support.map_type(ffi.typeof('big_struct'),
                                              use_record_dtype=True)
        sig = cffi_support.map_type(ffi.typeof('myfunc'), use_record_dtype=True)

        @njit
        def calc(base):
            acc = 0
            for i in range(base.size):
                elem = base[i]
                acc += elem.i1 * elem.f2 / elem.d3
                acc += base[i].af4.sum()
            return acc

        @cfunc(sig)
        def foo(ptr, n):
            base = carray(ptr, n)
            return calc(base)

        # Fill three structs through a cffi pointer.
        mydata = ffi.new('big_struct[3]')
        ptr = ffi.cast('big_struct*', mydata)
        for idx in range(3):
            ptr[idx].i1 = idx * 123
            ptr[idx].f2 = idx * 213
            ptr[idx].d3 = (1 + idx) * 213
            for k in range(9):
                ptr[idx].af4[k] = idx * 10 + k

        # Call the cfunc with the raw address of the data.
        address = int(ffi.cast('size_t', ptr))
        got = foo.ctypes(address, 3)

        # Reference result: a NumPy record array over the same cffi buffer.
        record_dtype = numpy_support.as_dtype(nb_big_struct)
        array = np.ndarray(buffer=ffi.buffer(mydata), dtype=record_dtype,
                           shape=3)
        expect = calc(array)
        self.assertEqual(got, expect)

    def test_unsupport_bitsize(self):
        """
        Mapping a struct with a bit-field member raises ValueError.
        """
        ffi = self.get_ffi()
        with self.assertRaises(ValueError) as raises:
            cffi_support.map_type(ffi.typeof('error'), use_record_dtype=True)
        # When bitsize is provided, bitshift defaults to 0.
        message = str(raises.exception)
        self.assertEqual(
            "field 'bits' has bitshift, this is not supported",
            message,
        )


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
