from __future__ import print_function, absolute_import

import numpy as np
from numba import roc
from numba.roc.hsadrv.error import HsaKernelLaunchError
import numba.unittest_support as unittest


class TestSimple(unittest.TestCase):

    def test_array_access(self):
        magic_token = 123

        @roc.jit
        def udt(output):
            output[0] = magic_token

        out = np.zeros(1, dtype=np.intp)
        udt[1, 1](out)

        self.assertEqual(out[0], magic_token)

    def test_array_access_2d(self):
        magic_token = 123

        @roc.jit
        def udt(output):
            for i in range(output.shape[0]):
                for j in range(output.shape[1]):
                    output[i, j] = magic_token

        out = np.zeros((10, 10), dtype=np.intp)
        udt[1, 1](out)
        np.testing.assert_equal(out, magic_token)

    def test_array_access_3d(self):
        magic_token = 123

        @roc.jit
        def udt(output):
            for i in range(output.shape[0]):
                for j in range(output.shape[1]):
                    for k in range(output.shape[2]):
                        output[i, j, k] = magic_token

        out = np.zeros((10, 10, 10), dtype=np.intp)
        udt[1, 1](out)
        np.testing.assert_equal(out, magic_token)

    def test_global_id(self):
        @roc.jit
        def udt(output):
            global_id = roc.get_global_id(0)
            output[global_id] = global_id

        # Allocate extra space to track bad indexing
        out = np.zeros(100 + 2, dtype=np.intp)
        udt[10, 10](out[1:-1])

        np.testing.assert_equal(out[1:-1], np.arange(100))

        self.assertEqual(out[0], 0)
        self.assertEqual(out[-1], 0)

    def test_local_id(self):
        @roc.jit
        def udt(output):
            global_id = roc.get_global_id(0)
            local_id = roc.get_local_id(0)
            output[global_id] = local_id

        # Allocate extra space to track bad indexing
        out = np.zeros(100 + 2, dtype=np.intp)
        udt[10, 10](out[1:-1])

        subarr = out[1:-1]

        for parted in np.split(subarr, 10):
            np.testing.assert_equal(parted, np.arange(10))

        self.assertEqual(out[0], 0)
        self.assertEqual(out[-1], 0)

    def test_group_id(self):
        @roc.jit
        def udt(output):
            global_id = roc.get_global_id(0)
            group_id = roc.get_group_id(0)
            output[global_id] = group_id + 1

        # Allocate extra space to track bad indexing
        out = np.zeros(100 + 2, dtype=np.intp)
        udt[10, 10](out[1:-1])

        subarr = out[1:-1]

        for i, parted in enumerate(np.split(subarr, 10), start=1):
            np.testing.assert_equal(parted, i)

        self.assertEqual(out[0], 0)
        self.assertEqual(out[-1], 0)


    def test_workdim(self):
        @roc.jit
        def udt(output):
            global_id = roc.get_global_id(0)
            workdim = roc.get_work_dim()
            output[global_id] = workdim

        out = np.zeros(10, dtype=np.intp)
        udt[1, 10](out)
        np.testing.assert_equal(out, 1)

        @roc.jit
        def udt2(output):
            g0 = roc.get_global_id(0)
            g1 = roc.get_global_id(1)
            output[g0, g1] = roc.get_work_dim()

        out = np.zeros((2, 5), dtype=np.intp)
        udt2[(1, 1), (2, 5)](out)
        np.testing.assert_equal(out, 2)

    def test_empty_kernel(self):
        @roc.jit
        def udt():
            pass

        udt[1, 1]()

    def test_workgroup_oversize(self):
        @roc.jit
        def udt():
            pass

        with self.assertRaises(HsaKernelLaunchError) as raises:
            udt[1, 2**30]()
        self.assertIn("Try reducing group-size", str(raises.exception))


if __name__ == '__main__':
    unittest.main()
