from __future__ import print_function, absolute_import
import copy
from collections import namedtuple
import re

import numpy as np

from numba.typing.templates import ConcreteTemplate
from numba import types, compiler
from .hlc import hlc
from .hsadrv import devices, driver, enums, drvapi
from .hsadrv.error import HsaKernelLaunchError
from . import gcn_occupancy
from numba.roc.hsadrv.driver import hsa, dgpu_present
from .hsadrv import devicearray
from numba.typing.templates import AbstractTemplate
from numba import ctypes_support as ctypes
from numba import config
from numba.compiler_lock import global_compiler_lock

@global_compiler_lock
def compile_hsa(pyfunc, return_type, args, debug):
    # First compilation will trigger the initialization of the HSA backend.
    from .descriptor import HSATargetDesc

    typingctx = HSATargetDesc.typingctx
    targetctx = HSATargetDesc.targetctx
    # TODO handle debug flag
    flags = compiler.Flags()
    # Do not compile (generate native code), just lower (to LLVM)
    flags.set('no_compile')
    flags.set('no_cpython_wrapper')
    flags.unset('nrt')
    # Run compilation pipeline
    cres = compiler.compile_extra(typingctx=typingctx,
                                  targetctx=targetctx,
                                  func=pyfunc,
                                  args=args,
                                  return_type=return_type,
                                  flags=flags,
                                  locals={})

    # Linking depending libraries
    # targetctx.link_dependencies(cres.llvm_module, cres.target_context.linking)
    library = cres.library
    library.finalize()

    return cres


def compile_kernel(pyfunc, args, debug=False):
    cres = compile_hsa(pyfunc, types.void, args, debug=debug)
    func = cres.library.get_function(cres.fndesc.llvm_func_name)
    kernel = cres.target_context.prepare_hsa_kernel(func, cres.signature.args)
    hsakern = HSAKernel(llvm_module=kernel.module,
                        name=kernel.name,
                        argtypes=cres.signature.args)
    return hsakern


def compile_device(pyfunc, return_type, args, debug=False):
    cres = compile_hsa(pyfunc, return_type, args, debug=debug)
    func = cres.library.get_function(cres.fndesc.llvm_func_name)
    cres.target_context.mark_hsa_device(func)
    devfn = DeviceFunction(cres)

    class device_function_template(ConcreteTemplate):
        key = devfn
        cases = [cres.signature]

    cres.typing_context.insert_user_function(devfn, device_function_template)
    libs = [cres.library]
    cres.target_context.insert_user_function(devfn, cres.fndesc, libs)
    return devfn


def compile_device_template(pyfunc):
    """Compile a DeviceFunctionTemplate
    """
    from .descriptor import HSATargetDesc

    dft = DeviceFunctionTemplate(pyfunc)

    class device_function_template(AbstractTemplate):
        key = dft

        def generic(self, args, kws):
            assert not kws
            return dft.compile(args)

    typingctx = HSATargetDesc.typingctx
    typingctx.insert_user_function(dft, device_function_template)
    return dft


class DeviceFunctionTemplate(object):
    """Unmaterialized device function
    """
    def __init__(self, pyfunc, debug=False):
        self.py_func = pyfunc
        self.debug = debug
        # self.inline = inline
        self._compileinfos = {}

    def compile(self, args):
        """Compile the function for the given argument types.

        Each signature is compiled once by caching the compiled function inside
        this object.
        """
        if args not in self._compileinfos:
            cres = compile_hsa(self.py_func, None, args, debug=self.debug)
            func = cres.library.get_function(cres.fndesc.llvm_func_name)
            cres.target_context.mark_hsa_device(func)
            first_definition = not self._compileinfos
            self._compileinfos[args] = cres
            libs = [cres.library]

            if first_definition:
                # First definition
                cres.target_context.insert_user_function(self, cres.fndesc,
                                                         libs)
            else:
                cres.target_context.add_user_function(self, cres.fndesc, libs)

        else:
            cres = self._compileinfos[args]

        return cres.signature


class DeviceFunction(object):
    def __init__(self, cres):
        self.cres = cres


def _ensure_list(val):
    if not isinstance(val, (tuple, list)):
        return [val]
    else:
        return list(val)


def _ensure_size_or_append(val, size):
    n = len(val)
    for _ in range(n, size):
        val.append(1)


class HSAKernelBase(object):
    """Define interface for configurable kernels
    """

    def __init__(self):
        self.global_size = (1,)
        self.local_size = (1,)
        self.stream = None

    def copy(self):
        return copy.copy(self)

    def configure(self, global_size, local_size=None, stream=None):
        """Configure the OpenCL kernel
        local_size can be None
        """
        global_size = _ensure_list(global_size)

        if local_size is not None:
            local_size = _ensure_list(local_size)
            size = max(len(global_size), len(local_size))
            _ensure_size_or_append(global_size, size)
            _ensure_size_or_append(local_size, size)

        clone = self.copy()
        clone.global_size = tuple(global_size)
        clone.local_size = tuple(local_size) if local_size else None
        clone.stream = stream

        return clone

    def forall(self, nelem, local_size=64, stream=None):
        """Simplified configuration for 1D kernel launch
        """
        return self.configure(nelem, min(nelem, local_size), stream=stream)

    def __getitem__(self, args):
        """Mimick CUDA python's square-bracket notation for configuration.
        This assumes a the argument to be:
            `griddim, blockdim, stream`
        The blockdim maps directly to local_size.
        The actual global_size is computed by multiplying the local_size to
        griddim.
        """
        griddim = _ensure_list(args[0])
        blockdim = _ensure_list(args[1])
        size = max(len(griddim), len(blockdim))
        _ensure_size_or_append(griddim, size)
        _ensure_size_or_append(blockdim, size)
        # Compute global_size
        gs = [g * l for g, l in zip(griddim, blockdim)]
        return self.configure(gs, blockdim, *args[2:])


_CacheEntry = namedtuple("_CachedEntry", ['symbol', 'executable',
                                          'kernarg_region'])


class _CachedProgram(object):
    def __init__(self, entry_name, binary):
        self._entry_name = entry_name
        self._binary = binary
        # key: hsa context
        self._cache = {}

    def get(self):
        ctx = devices.get_context()
        result = self._cache.get(ctx)
        # The program does not exist as GCN yet.
        if result is None:

            # generate GCN
            symbol = '{0}'.format(self._entry_name)
            agent = ctx.agent

            ba = bytearray(self._binary)
            bblob = ctypes.c_byte * len(self._binary)
            bas = bblob.from_buffer(ba)

            code_ptr = drvapi.hsa_code_object_t()
            driver.hsa.hsa_code_object_deserialize(
                    ctypes.addressof(bas),
                    len(self._binary),
                    None,
                    ctypes.byref(code_ptr)
                    )

            code = driver.CodeObject(code_ptr)

            ex = driver.Executable()
            ex.load(agent, code)
            ex.freeze()
            symobj = ex.get_symbol(agent, symbol)
            regions = agent.regions.globals
            for reg in regions:
                if reg.host_accessible:
                    if reg.supports(enums.HSA_REGION_GLOBAL_FLAG_KERNARG):
                        kernarg_region = reg
                        break
            assert kernarg_region is not None

            # Cache the GCN program
            result = _CacheEntry(symbol=symobj, executable=ex,
                                 kernarg_region=kernarg_region)
            self._cache[ctx] = result

        return ctx, result


class HSAKernel(HSAKernelBase):
    """
    A HSA kernel object
    """
    def __init__(self, llvm_module, name, argtypes):
        super(HSAKernel, self).__init__()
        self._llvm_module = llvm_module
        self.assembly, self.binary = self._generateGCN()
        self.entry_name = name
        self.argument_types = tuple(argtypes)
        self._argloc = []
        # cached program
        self._cacheprog = _CachedProgram(entry_name=self.entry_name,
                                         binary=self.binary)
        self._parse_kernel_resource()

    def _parse_kernel_resource(self):
        """
        Temporary workaround for register limit
        """
        m = re.search(r"\bwavefront_sgpr_count\s*=\s*(\d+)", self.assembly)
        self._wavefront_sgpr_count = int(m.group(1))
        m = re.search(r"\bworkitem_vgpr_count\s*=\s*(\d+)", self.assembly)
        self._workitem_vgpr_count = int(m.group(1))

    def _sentry_resource_limit(self):
        # only check resource factprs if either sgpr or vgpr is non-zero
        #if (self._wavefront_sgpr_count > 0 or self._workitem_vgpr_count > 0):
        group_size = np.prod(self.local_size)
        limits = gcn_occupancy.get_limiting_factors(
            group_size=group_size,
            vgpr_per_workitem=self._workitem_vgpr_count,
            sgpr_per_wave=self._wavefront_sgpr_count)
        if limits.reasons:
            fmt = 'insufficient resources to launch kernel due to:\n{}'
            msg = fmt.format('\n'.join(limits.suggestions))
            raise HsaKernelLaunchError(msg)

    def _generateGCN(self):
        hlcmod = hlc.Module()
        hlcmod.load_llvm(str(self._llvm_module))
        return hlcmod.generateGCN()

    def bind(self):
        """
        Bind kernel to device
        """
        ctx, entry = self._cacheprog.get()
        if entry.symbol.kernarg_segment_size > 0:
            sz = ctypes.sizeof(ctypes.c_byte) *\
                entry.symbol.kernarg_segment_size
            kernargs = entry.kernarg_region.allocate(sz)
        else:
            kernargs = None

        return ctx, entry.symbol, kernargs, entry.kernarg_region

    def __call__(self, *args):
        self._sentry_resource_limit()

        ctx, symbol, kernargs, kernarg_region = self.bind()

        # Unpack pyobject values into ctypes scalar values
        expanded_values = []

        # contains lambdas to execute on return
        retr = []
        for ty, val in zip(self.argument_types, args):
            _unpack_argument(ty, val, expanded_values, retr)

        # Insert kernel arguments
        base = 0
        for av in expanded_values:
            # Adjust for alignemnt
            align = ctypes.sizeof(av)
            pad = _calc_padding_for_alignment(align, base)
            base += pad
            # Move to offset
            offseted = kernargs.value + base
            asptr = ctypes.cast(offseted, ctypes.POINTER(type(av)))
            # Assign value
            asptr[0] = av
            # Increment offset
            base += align

        # Actual Kernel launch
        qq = ctx.default_queue

        if self.stream is None:
            hsa.implicit_sync()

        # Dispatch
        signal = None
        if self.stream is not None:
            signal = hsa.create_signal(1)
            qq.insert_barrier(self.stream._get_last_signal())

        qq.dispatch(symbol, kernargs, workgroup_size=self.local_size,
                    grid_size=self.global_size, signal=signal)

        if self.stream is not None:
            self.stream._add_signal(signal)

        # retrieve auto converted arrays
        for wb in retr:
            wb()

        # Free kernel region
        if kernargs is not None:
            if self.stream is None:
                kernarg_region.free(kernargs)
            else:
                self.stream._add_callback(lambda: kernarg_region.free(kernargs))


def _unpack_argument(ty, val, kernelargs, retr):
    """
    Convert arguments to ctypes and append to kernelargs
    """
    if isinstance(ty, types.Array):
        c_intp = ctypes.c_ssize_t
        # if a dgpu is present, move the data to the device.
        if dgpu_present:
            devary, conv = devicearray.auto_device(val, devices.get_context())
            if conv:
                retr.append(lambda: devary.copy_to_host(val))
            data = devary.device_ctypes_pointer
        else:
            data = ctypes.c_void_p(val.ctypes.data)


        meminfo = parent = ctypes.c_void_p(0)
        nitems = c_intp(val.size)
        itemsize = c_intp(val.dtype.itemsize)
        kernelargs.append(meminfo)
        kernelargs.append(parent)
        kernelargs.append(nitems)
        kernelargs.append(itemsize)
        kernelargs.append(data)
        for ax in range(val.ndim):
            kernelargs.append(c_intp(val.shape[ax]))
        for ax in range(val.ndim):
            kernelargs.append(c_intp(val.strides[ax]))

    elif isinstance(ty, types.Integer):
        cval = getattr(ctypes, "c_%s" % ty)(val)
        kernelargs.append(cval)

    elif ty == types.float64:
        cval = ctypes.c_double(val)
        kernelargs.append(cval)

    elif ty == types.float32:
        cval = ctypes.c_float(val)
        kernelargs.append(cval)

    elif ty == types.boolean:
        cval = ctypes.c_uint8(int(val))
        kernelargs.append(cval)

    elif ty == types.complex64:
        kernelargs.append(ctypes.c_float(val.real))
        kernelargs.append(ctypes.c_float(val.imag))

    elif ty == types.complex128:
        kernelargs.append(ctypes.c_double(val.real))
        kernelargs.append(ctypes.c_double(val.imag))

    else:
        raise NotImplementedError(ty, val)


def _calc_padding_for_alignment(align, base):
    """
    Returns byte padding required to move the base pointer into proper alignment
    """
    rmdr = int(base) % align
    if rmdr == 0:
        return 0
    else:
        return align - rmdr


class AutoJitHSAKernel(HSAKernelBase):
    def __init__(self, func):
        super(AutoJitHSAKernel, self).__init__()
        self.py_func = func
        self.definitions = {}

        from .descriptor import HSATargetDesc

        self.typingctx = HSATargetDesc.typingctx

    def __call__(self, *args):
        kernel = self.specialize(*args)
        cfg = kernel.configure(self.global_size, self.local_size, self.stream)
        cfg(*args)

    def specialize(self, *args):
        argtypes = tuple([self.typingctx.resolve_argument_type(a)
                          for a in args])
        kernel = self.definitions.get(argtypes)
        if kernel is None:
            kernel = compile_kernel(self.py_func, argtypes)
            self.definitions[argtypes] = kernel
        return kernel

