"""
Expose each GPU device directly
"""
from __future__ import print_function, absolute_import, division
import functools
from numba import servicelib
from .driver import hsa as driver, Context as _Context


class _culist(object):
    """A lazily populated list of GPU instances.

    The wrappers are created from ``driver.components`` on the first list
    operation and are cached for every later access.
    """

    def __init__(self):
        # None means "not discovered yet"; an empty list is a valid cached
        # result (the driver reported no components).
        self._lst = None

    @property
    def _gpus(self):
        """The cached list of CU objects, built on first access."""
        # Test against None, not truthiness: an empty list must stay cached
        # instead of re-querying the driver on every access.
        if self._lst is None:
            self._lst = self._init_gpus()
        return self._lst

    def _init_gpus(self):
        """Wrap each component reported by the driver in a CU object."""
        return [CU(com) for com in driver.components]

    def __getitem__(self, item):
        return self._gpus[item]

    def append(self, item):
        return self._gpus.append(item)

    def __len__(self):
        return len(self._gpus)

    def __nonzero__(self):
        return bool(self._gpus)

    # Python 3 spelling of the truth-value hook.
    __bool__ = __nonzero__

    def __iter__(self):
        return iter(self._gpus)

    def reset(self):
        """Call ``reset()`` on every GPU in the list."""
        for gpu in self:
            gpu.reset()

    @property
    def current(self):
        """Get the current GPU object, i.e. the top of ``_custack``.
        """
        return _custack.top


# Module-level list of GPU objects.  The class name is removed from the
# module namespace right after the single instance is created.
cus = _culist()
del _culist


class CU(object):
    """Wrapper around a driver component that lazily owns a context.

    Lookups of public attributes that are not found on the wrapper are
    forwarded to the wrapped component object.  The wrapper can be used as
    a context manager to make it the current device for the ``with`` body.
    """

    def __init__(self, cu):
        self._cu = cu
        # Created on demand by associate_context().
        self._context = None

    def __getattr__(self, key):
        """Redirect to self._cu
        """
        # __getattr__ only runs when normal lookup fails.  Refusing
        # underscore names keeps a missing ``_cu`` from recursing forever.
        if key.startswith('_'):
            raise AttributeError(key)
        return getattr(self._cu, key)

    def __repr__(self):
        return repr(self._cu)

    def associate_context(self):
        """Return the context of this GPU, creating it on first use.
        """
        # No context was created for this GPU
        if self._context is None:
            self._context = self._cu.create_context()

        return self._context

    def __enter__(self):
        """Make this GPU the current device and return it.

        Returning ``self`` lets ``with gpu as g:`` bind the device object.
        """
        self.associate_context()
        _custack.push(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Undo ``__enter__``: pop the context and the device stack."""
        assert _get_device() is self
        # NOTE(review): relies on the context object providing pop() --
        # confirm against the driver's context class.
        self._context.pop()
        _custack.pop()

    def reset(self):
        """Reset and drop the context, if one was created."""
        # Compare against None (as associate_context does) so that a
        # context object that happens to be falsy is still reset.
        if self._context is not None:
            self._context.reset()
            self._context = None


# Cache for the CPU context; filled in by get_cpu_context() on first call.
_cpu_context = None


def get_cpu_context():
    """Return the shared CPU context, building it on the first call.

    The context is created for the first driver agent that is not a
    component and is then cached in the module-level ``_cpu_context``.
    """
    global _cpu_context
    if _cpu_context is not None:
        return _cpu_context

    host_agents = [agent for agent in driver.agents if not agent.is_component]
    _cpu_context = _Context(host_agents[0])
    return _cpu_context


def get_gpu(i):
    """Return the GPU object at index *i* of the module-level ``cus`` list."""
    return cus[i]

def get_num_gpus():
    """Return the number of GPU objects in the module-level ``cus`` list."""
    return len(cus)


# Stack of CU objects; its top is the current device.  CU.__enter__ and
# CU.__exit__ push and pop it, and _get_device() seeds it when empty.
# NOTE(review): TLStack is presumably a thread-local stack -- confirm in
# numba.servicelib.
_custack = servicelib.TLStack()


def _get_device(devnum=0):
    """Return the device on top of the device stack.

    When the stack is empty, the GPU numbered *devnum* is pushed first and
    so becomes the current device.  When the stack already holds a device,
    *devnum* is ignored.
    """
    if _custack:
        return _custack.top

    _custack.push(get_gpu(devnum))
    return _custack.top


def get_context(devnum=0):
    """Return the HSA context of the current device.

    The device is looked up through ``_get_device``, so *devnum* only
    selects a device when no device is current yet.
    """
    device = _get_device(devnum=devnum)
    return device.associate_context()


def get_all_contexts():
    """Return a list with one HSA context per GPU, in device-number order.

    Each context is obtained (and created if necessary) from the GPU object
    itself.  Going through ``get_context(i)`` would be wrong here: it only
    honours the device number while the device stack is empty, so it would
    return the current device's context for every index.
    """
    return [get_gpu(i).associate_context() for i in range(get_num_gpus())]


def require_context(fn):
    """
    Decorator that makes sure an HSA context exists before *fn* runs.

    ``get_context()`` is called with its defaults ahead of every call and
    its result is discarded; *fn* then receives the caller's arguments
    unchanged and its return value is passed through.
    """

    def _require_cu_context(*args, **kwargs):
        get_context()
        result = fn(*args, **kwargs)
        return result

    # Copy fn's metadata (name, docstring, ...) onto the wrapper.
    return functools.wraps(fn)(_require_cu_context)


def reset():
    """Reset every GPU in ``cus``, then clear the device stack."""
    cus.reset()
    _custack.clear()


