from __future__ import print_function

from numba.compiler import Pipeline
from numba import jit, generated_jit, types, objmode
from numba.ir import FunctionIR
from .support import TestCase


class TestCustomPipeline(TestCase):
    """Verify that a user-supplied ``pipeline_class`` is used by ``jit``,
    ``generated_jit`` and ``objmode``-lifted code.

    Each test gets its own ``CustomPipeline`` subclass whose compile hooks
    record what they were asked to compile, then delegate to ``Pipeline``.
    """

    def setUp(self):
        super(TestCustomPipeline, self).setUp()

        class CustomPipeline(Pipeline):
            # Shared (class-level) record of every object passed to the
            # compile hooks below.  Because this class is re-created in each
            # setUp, every test starts with a new, empty list.
            custom_pipeline_cache = []

            def compile_extra(self, func):
                # Log the Python function, then compile it as usual.
                self.custom_pipeline_cache.append(func)
                return super(CustomPipeline, self).compile_extra(func)

            def compile_ir(self, func_ir, *args, **kwargs):
                # Log the FunctionIR, then compile it as usual.
                self.custom_pipeline_cache.append(func_ir)
                return super(CustomPipeline, self).compile_ir(
                    func_ir, *args, **kwargs)

        self.pipeline_class = CustomPipeline

    def test_jit_custom_pipeline(self):
        # @jit should route compilation of the decorated function through
        # the custom pipeline's compile_extra hook.
        pipeline = self.pipeline_class
        log = pipeline.custom_pipeline_cache
        self.assertListEqual(log, [])

        @jit(pipeline_class=pipeline)
        def foo(x):
            return x

        self.assertEqual(foo(4), 4)
        self.assertListEqual(log, [foo.py_func])

    def test_generated_jit_custom_pipeline(self):
        # @generated_jit should compile the implementation returned by the
        # generator function (not the generator itself) via the pipeline.
        pipeline = self.pipeline_class
        log = pipeline.custom_pipeline_cache
        self.assertListEqual(log, [])

        def identity(x):
            return x

        @generated_jit(pipeline_class=pipeline)
        def foo(x):
            if isinstance(x, types.Integer):
                return identity

        self.assertEqual(foo(5), 5)
        self.assertListEqual(log, [identity])

    def test_objmode_custom_pipeline(self):
        # A function with an objmode block should hit compile_extra for the
        # outer function and compile_ir for the lifted region.
        pipeline = self.pipeline_class
        log = pipeline.custom_pipeline_cache
        self.assertListEqual(log, [])

        @jit(pipeline_class=pipeline)
        def foo(x):
            with objmode(x="intp"):
                x += int(0x1)
            return x

        value = 123
        self.assertEqual(foo(value), value + 1)

        # Exactly two compilations should have been recorded.
        self.assertEqual(len(log), 2)
        outer_entry = log[0]
        lifted_entry = log[1]
        # Entry 0: the original Python function behind `foo`.
        self.assertIs(outer_entry, foo.py_func)
        # Entry 1: the FunctionIR of the object-mode-lifted region.
        self.assertIsInstance(lifted_entry, FunctionIR)

