from __future__ import absolute_import, print_function, division

from collections import OrderedDict
import ctypes
import sys

import numpy as np

from numba import (float32, float64, int16, int32, boolean, deferred_type,
                   optional)
from numba import njit, typeof, errors
from numba import unittest_support as unittest
from numba import jitclass
from .support import TestCase, MemoryLeakMixin, tag
from numba.jitclass import _box
from numba.runtime.nrt import MemInfo
from numba.errors import LoweringError


def _get_meminfo(box):
    """
    Return a ``MemInfo`` wrapping the meminfo pointer stored in *box*.

    ``acquire()`` is called on the wrapper before it is returned
    (presumably so the returned object holds its own reference to the
    underlying allocation -- confirm against numba.runtime.nrt).
    """
    meminfo = MemInfo(_box.box_get_meminfoptr(box))
    meminfo.acquire()
    return meminfo


class TestJitClass(TestCase, MemoryLeakMixin):
    """
    Tests for ``jitclass``: spec validation, use from both the interpreter
    and nopython code, reference counts of boxed instances, properties,
    private-name handling and the errors raised for unsupported features.
    """
    # NOTE(review): TestCase precedes MemoryLeakMixin in the bases, so any
    # setUp/tearDown hooks the mixin defines only run if TestCase's own
    # versions chain to them via super() -- confirm against .support.

    def _check_spec(self, spec):
        """
        Build a trivial jitclass from *spec* and check its struct layout.

        *spec* must declare ``x`` as int32 followed by ``y`` as float32;
        both the order and the types of the fields are asserted.
        """
        @jitclass(spec)
        class Test(object):

            def __init__(self):
                pass

        clsty = Test.class_type.instance_type
        names = list(clsty.struct.keys())
        values = list(clsty.struct.values())
        self.assertEqual(names[0], 'x')
        self.assertEqual(names[1], 'y')
        self.assertEqual(values[0], int32)
        self.assertEqual(values[1], float32)

    def test_ordereddict_spec(self):
        """A spec given as an OrderedDict is accepted, order preserved."""
        spec = OrderedDict()
        spec['x'] = int32
        spec['y'] = float32
        self._check_spec(spec)

    def test_list_spec(self):
        """A spec given as a list of (name, type) pairs is accepted."""
        spec = [('x', int32),
                ('y', float32)]
        self._check_spec(spec)

    def test_spec_errors(self):
        """Invalid spec keys or values are rejected with a TypeError."""
        # spec1: the builtin ``int`` is not a Numba type instance.
        spec1 = [('x', int), ('y', float32[:])]
        # spec2: the first key is not a string.
        spec2 = [(1, int32), ('y', float32[:])]

        class Test(object):

            def __init__(self):
                pass

        with self.assertRaises(TypeError) as raises:
            jitclass(spec1)(Test)
        self.assertIn("spec values should be Numba type instances",
                      str(raises.exception))
        with self.assertRaises(TypeError) as raises:
            jitclass(spec2)(Test)
        self.assertEqual(str(raises.exception),
                         "spec keys should be strings, got 1")

    def _make_Float2AndArray(self):
        """
        Return a new jitclass with float32 fields ``x`` and ``y`` and a
        ``float32[:]`` field ``arr``.

        A fresh class is created on every call.
        """
        spec = OrderedDict()
        spec['x'] = float32
        spec['y'] = float32
        spec['arr'] = float32[:]

        @jitclass(spec)
        class Float2AndArray(object):

            def __init__(self, x, y, arr):
                self.x = x
                self.y = y
                self.arr = arr

            def add(self, val):
                # Mutates both scalar fields and hands *val* back unchanged.
                self.x += val
                self.y += val
                return val

        return Float2AndArray

    def _make_Vector2(self):
        """
        Return a new jitclass with int32 fields ``x`` and ``y``.

        A fresh class is created on every call.
        """
        spec = OrderedDict()
        spec['x'] = int32
        spec['y'] = int32

        @jitclass(spec)
        class Vector2(object):

            def __init__(self, x, y):
                self.x = x
                self.y = y

        return Vector2

    @tag('important')
    def test_jit_class_1(self):
        """Jitclasses can be constructed and used inside njit functions."""
        Float2AndArray = self._make_Float2AndArray()
        Vector2 = self._make_Vector2()

        @njit
        def bar(obj):
            return obj.x + obj.y

        @njit
        def foo(a):
            obj = Float2AndArray(1, 2, a)
            obj.add(123)

            vec = Vector2(3, 4)
            return bar(obj), bar(vec), obj.arr

        inp = np.ones(10, dtype=np.float32)
        a, b, c = foo(inp)
        # add(123) bumped both x and y before bar() summed them.
        self.assertEqual(a, 123 + 1 + 123 + 2)
        self.assertEqual(b, 3 + 4)
        self.assertPreciseEqual(c, inp)

    @tag('important')
    def test_jitclass_usage_from_python(self):
        """
        Instances created in the interpreter expose attributes, methods and
        setters, and can be passed through njit functions; the meminfo
        refcount is checked along the way.
        """
        Float2AndArray = self._make_Float2AndArray()

        @njit
        def identity(obj):
            return obj

        @njit
        def retrieve_attributes(obj):
            return obj.x, obj.y, obj.arr

        arr = np.arange(10, dtype=np.float32)
        obj = Float2AndArray(1, 2, arr)
        obj_meminfo = _get_meminfo(obj)
        # Presumably one reference for the boxed object plus the one
        # acquired by _get_meminfo().
        self.assertEqual(obj_meminfo.refcount, 2)
        self.assertEqual(obj_meminfo.data, _box.box_get_dataptr(obj))
        self.assertEqual(obj._numba_type_.class_type,
                         Float2AndArray.class_type)
        # Use jit class instance in numba
        other = identity(obj)
        other_meminfo = _get_meminfo(other)  # duplicates MemInfo object to obj
        self.assertEqual(obj_meminfo.refcount, 4)
        self.assertEqual(other_meminfo.refcount, 4)
        self.assertEqual(other_meminfo.data, _box.box_get_dataptr(other))
        # Both boxes share the same underlying data.
        self.assertEqual(other_meminfo.data, obj_meminfo.data)

        # Check dtor
        del other, other_meminfo
        self.assertEqual(obj_meminfo.refcount, 2)

        # Check attributes
        out_x, out_y, out_arr = retrieve_attributes(obj)
        self.assertEqual(out_x, 1)
        self.assertEqual(out_y, 2)
        self.assertIs(out_arr, arr)

        # Access attributes from python
        self.assertEqual(obj.x, 1)
        self.assertEqual(obj.y, 2)
        self.assertIs(obj.arr, arr)

        # Access methods from python
        self.assertEqual(obj.add(123), 123)
        self.assertEqual(obj.x, 1 + 123)
        self.assertEqual(obj.y, 2 + 123)

        # Setter from python
        obj.x = 333
        obj.y = 444
        obj.arr = newarr = np.arange(5, dtype=np.float32)
        self.assertEqual(obj.x, 333)
        self.assertEqual(obj.y, 444)
        self.assertIs(obj.arr, newarr)

    def test_jitclass_datalayout(self):
        """A boolean field round-trips through a jitclass instance."""
        spec = OrderedDict()
        # Boolean has different layout as value vs data
        spec['val'] = boolean

        @jitclass(spec)
        class Foo(object):

            def __init__(self, val):
                self.val = val

        self.assertTrue(Foo(True).val)
        self.assertFalse(Foo(False).val)

    @tag('important')
    def test_deferred_type(self):
        """
        A jitclass can refer to its own type through ``deferred_type``
        (here: a singly linked list node with an optional ``next``).
        """
        node_type = deferred_type()

        spec = OrderedDict()
        spec['data'] = float32
        spec['next'] = optional(node_type)

        @njit
        def get_data(node):
            return node.data

        @jitclass(spec)
        class LinkedNode(object):

            def __init__(self, data, next):
                self.data = data
                self.next = next

            def get_next_data(self):
                # use deferred type as argument
                return get_data(self.next)

            def append_to_tail(self, other):
                cur = self
                while cur.next is not None:
                    cur = cur.next
                cur.next = other

        # The deferred type is only resolved once the class exists.
        node_type.define(LinkedNode.class_type.instance_type)

        first = LinkedNode(123, None)
        self.assertEqual(first.data, 123)
        self.assertIsNone(first.next)

        second = LinkedNode(321, first)

        first_meminfo = _get_meminfo(first)
        second_meminfo = _get_meminfo(second)
        self.assertEqual(first_meminfo.refcount, 3)
        self.assertEqual(second.next.data, first.data)
        # Reading second.next must not leave an extra reference behind.
        self.assertEqual(first_meminfo.refcount, 3)
        self.assertEqual(second_meminfo.refcount, 2)

        # Test using deferred type as argument
        first_val = second.get_next_data()
        self.assertEqual(first_val, first.data)

        # Check setattr (issue #2606)
        self.assertIsNone(first.next)
        second.append_to_tail(LinkedNode(567, None))
        self.assertIsNotNone(first.next)
        self.assertEqual(first.next.data, 567)
        self.assertIsNone(first.next.next)
        second.append_to_tail(LinkedNode(678, None))
        self.assertIsNotNone(first.next.next)
        self.assertEqual(first.next.next.data, 678)

        # Check ownership
        self.assertEqual(first_meminfo.refcount, 3)
        del second, second_meminfo
        self.assertEqual(first_meminfo.refcount, 2)

    def test_c_structure(self):
        """
        The instance data can be read back through a ctypes structure
        declaring the same fields in the same order.
        """
        spec = OrderedDict()
        spec['a'] = int32
        spec['b'] = int16
        spec['c'] = float64

        @jitclass(spec)
        class Struct(object):

            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

        st = Struct(0xabcd, 0xef, 3.1415)

        class CStruct(ctypes.Structure):
            _fields_ = [
                ('a', ctypes.c_int32),
                ('b', ctypes.c_int16),
                ('c', ctypes.c_double),
            ]

        # Reinterpret the instance's data pointer as a CStruct.
        ptr = ctypes.c_void_p(_box.box_get_dataptr(st))
        cstruct = ctypes.cast(ptr, ctypes.POINTER(CStruct))[0]
        self.assertEqual(cstruct.a, st.a)
        self.assertEqual(cstruct.b, st.b)
        self.assertEqual(cstruct.c, st.c)

    def test_is(self):
        """Using ``is`` on jitclass instances raises a LoweringError."""
        Vector = self._make_Vector2()
        vec_a = Vector(1, 2)

        @njit
        def do_is(a, b):
            return a is b

        with self.assertRaises(LoweringError) as raises:
            # trigger compilation
            do_is(vec_a, vec_a)
        self.assertIn('no default `is` implementation', str(raises.exception))

    def test_isinstance(self):
        """``isinstance`` works between an instance and its jitclass."""
        Vector2 = self._make_Vector2()
        vec = Vector2(1, 2)
        self.assertIsInstance(vec, Vector2)

    def test_subclassing(self):
        """Deriving a Python class from a jitclass raises a TypeError."""
        Vector2 = self._make_Vector2()
        with self.assertRaises(TypeError) as raises:
            class SubV(Vector2):
                pass
        self.assertEqual(str(raises.exception),
                         "cannot subclass from a jitclass")

    def test_base_class(self):
        """Methods inherited from a plain Python base class are usable."""
        class Base(object):

            def what(self):
                return self.attr

        @jitclass([('attr', int32)])
        class Test(Base):

            def __init__(self, attr):
                self.attr = attr

        obj = Test(123)
        self.assertEqual(obj.what(), 123)

    def test_globals(self):
        """A class-level data attribute is rejected with a TypeError."""

        class Mine(object):
            constant = 123

            def __init__(self):
                pass

        with self.assertRaises(TypeError) as raises:
            jitclass(())(Mine)

        self.assertEqual(str(raises.exception),
                         "class members are not yet supported: constant")

    @tag('important')
    def test_user_getter_setter(self):
        """
        User-defined properties (getter and setter) work both from the
        interpreter and from nopython code.
        """
        @jitclass([('attr', int32)])
        class Foo(object):

            def __init__(self, attr):
                self.attr = attr

            @property
            def value(self):
                return self.attr + 1

            @value.setter
            def value(self, val):
                self.attr = val - 1

        foo = Foo(123)
        self.assertEqual(foo.attr, 123)
        # Getter
        self.assertEqual(foo.value, 123 + 1)
        # Setter
        foo.value = 789
        self.assertEqual(foo.attr, 789 - 1)
        self.assertEqual(foo.value, 789)

        # Test nopython mode usage of getter and setter
        @njit
        def bar(foo, val):
            a = foo.value
            foo.value = val
            b = foo.value
            c = foo.attr
            return a, b, c

        a, b, c = bar(foo, 567)
        self.assertEqual(a, 789)
        self.assertEqual(b, 567)
        self.assertEqual(c, 567 - 1)

    def test_user_deleter_error(self):
        """A property with a deleter is rejected with a TypeError."""
        class Foo(object):

            def __init__(self):
                pass

            @property
            def value(self):
                return 1

            @value.deleter
            def value(self):
                pass

        with self.assertRaises(TypeError) as raises:
            jitclass([])(Foo)
        self.assertEqual(str(raises.exception),
                         "deleter is not supported: value")

    def test_name_shadowing_error(self):
        """
        A spec field named like a property or a method of the class is
        rejected with a NameError.
        """
        class Foo(object):

            def __init__(self):
                pass

            @property
            def my_property(self):
                pass

            def my_method(self):
                pass

        with self.assertRaises(NameError) as raises:
            jitclass([('my_property', int32)])(Foo)
        self.assertEqual(str(raises.exception), 'name shadowing: my_property')

        with self.assertRaises(NameError) as raises:
            jitclass([('my_method', int32)])(Foo)
        self.assertEqual(str(raises.exception), 'name shadowing: my_method')

    def test_distinct_classes(self):
        """Two jitclasses sharing a name keep their own implementations."""
        # Different classes with the same names shouldn't confuse the compiler
        @jitclass([('x', int32)])
        class Foo(object):

            def __init__(self, x):
                self.x = x + 2

            def run(self):
                return self.x + 1

        FirstFoo = Foo

        @jitclass([('x', int32)])
        class Foo(object):

            def __init__(self, x):
                self.x = x - 2

            def run(self):
                return self.x - 1

        SecondFoo = Foo
        foo = FirstFoo(5)
        self.assertEqual(foo.x, 7)
        self.assertEqual(foo.run(), 8)
        foo = SecondFoo(5)
        self.assertEqual(foo.x, 3)
        self.assertEqual(foo.run(), 2)

    def test_parameterized(self):
        """
        The same Python class can be jitted repeatedly with specs derived
        from different value types.
        """
        class MyClass(object):

            def __init__(self, value):
                self.value = value

        def create_my_class(value):
            # The field type is inferred from the value being stored.
            cls = jitclass([('value', typeof(value))])(MyClass)
            return cls(value)

        a = create_my_class(123)
        self.assertEqual(a.value, 123)

        b = create_my_class(12.3)
        self.assertEqual(b.value, 12.3)

        c = create_my_class(np.array([123]))
        np.testing.assert_equal(c.value, [123])

        d = create_my_class(np.array([12.3]))
        np.testing.assert_equal(d.value, [12.3])

    @tag('important')
    def test_protected_attrs(self):
        """
        Single-underscore, double-underscore (name-mangled) and dunder
        attributes, properties and methods all work on a jitclass, and
        mangled names are not reachable from outside the class.
        """
        spec = {
            'value': int32,
            '_value': float32,
            '__value': int32,
            '__value__': int32,
        }

        @jitclass(spec)
        class MyClass(object):

            def __init__(self, value):
                self.value = value
                self._value = value / 2
                self.__value = value * 2
                self.__value__ = value - 1

            @property
            def private_value(self):
                return self.__value

            @property
            def _inner_value(self):
                return self._value

            @_inner_value.setter
            def _inner_value(self, v):
                self._value = v

            @property
            def __private_value(self):
                return self.__value

            @__private_value.setter
            def __private_value(self, v):
                self.__value = v

            def swap_private_value(self, new):
                old = self.__private_value
                self.__private_value = new
                return old

            def _protected_method(self, factor):
                return self._value * factor

            def __private_method(self, factor):
                return self.__value * factor

            def check_private_method(self, factor):
                return self.__private_method(factor)

        value = 123
        inst = MyClass(value)
        # test attributes
        self.assertEqual(inst.value, value)
        self.assertEqual(inst._value, value / 2)
        self.assertEqual(inst.private_value, value * 2)
        # test properties
        self.assertEqual(inst._inner_value, inst._value)
        freeze_inst_value = inst._value
        inst._inner_value -= 1
        self.assertEqual(inst._inner_value, freeze_inst_value - 1)

        self.assertEqual(inst.swap_private_value(321), value * 2)
        self.assertEqual(inst.swap_private_value(value * 2), 321)
        # test methods
        self.assertEqual(inst._protected_method(3), inst._value * 3)
        self.assertEqual(inst.check_private_method(3), inst.private_value * 3)
        # test special
        self.assertEqual(inst.__value__, value - 1)
        inst.__value__ -= 100
        self.assertEqual(inst.__value__, value - 101)

        # test errors
        @njit
        def access_dunder(inst):
            return inst.__value

        with self.assertRaises(errors.TypingError) as raises:
            access_dunder(inst)
        # It will appear as "_TestJitClass__value" because the `access_dunder`
        # is under the scope of 'TestJitClass'.
        self.assertIn('_TestJitClass__value', str(raises.exception))

        with self.assertRaises(AttributeError) as raises:
            access_dunder.py_func(inst)
        self.assertIn('_TestJitClass__value', str(raises.exception))

    @unittest.skipIf(sys.version_info < (3,), "Python 3-specific test")
    def test_annotations(self):
        """
        Methods with annotations should compile fine (issue #1911).
        """
        from .annotation_usecases import AnnotatedClass

        spec = {'x': int32}
        cls = jitclass(spec)(AnnotatedClass)

        obj = cls(5)
        self.assertEqual(obj.x, 5)
        self.assertEqual(obj.add(2), 7)

    def test_docstring(self):
        """Docstrings of the class, methods and properties are kept."""

        @jitclass([])
        class Apple(object):
            "Class docstring"
            def __init__(self):
                "init docstring"

            def foo(self):
                "foo method docstring"

            @property
            def aval(self):
                "aval property docstring"

        self.assertEqual(Apple.__doc__, 'Class docstring')
        self.assertEqual(Apple.__init__.__doc__, 'init docstring')
        self.assertEqual(Apple.foo.__doc__, 'foo method docstring')
        self.assertEqual(Apple.aval.__doc__, 'aval property docstring')

    def test_kwargs(self):
        """The constructor accepts keyword and ``**`` arguments."""
        spec = [('a', int32),
                ('b', float64)]

        @jitclass(spec)
        class TestClass(object):
            def __init__(self, x, y, z):
                self.a = x * y
                self.b = z

        x = 2
        y = 2
        z = 1.1
        kwargs = {'y': y, 'z': z}
        # Mix an explicit keyword with unpacked keywords.
        tc = TestClass(x=x, **kwargs)
        self.assertEqual(tc.a, x * y)
        self.assertEqual(tc.b, z)

    def test_generator_method(self):
        """A jitclass method may be a generator."""
        spec = []

        @jitclass(spec)
        class TestClass(object):
            def __init__(self):
                pass

            def gen(self, niter):
                for i in range(niter):
                    yield np.arange(i)

        def expected_gen(niter):
            for i in range(niter):
                yield np.arange(i)

        for niter in range(10):
            expected = list(expected_gen(niter))
            got = list(TestClass().gen(niter))
            # Compare lengths explicitly: zipping the generators directly
            # would stop at the shorter one and hide missing or extra items.
            self.assertEqual(len(expected), len(got))
            for expect, result in zip(expected, got):
                self.assertPreciseEqual(expect, result)


# Run this module's tests when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
