from __future__ import print_function, division, absolute_import

import copy
import warnings
import numpy as np

import numba
from numba import unittest_support as unittest
from numba.transforms import find_setupwiths, with_lifting
from numba.withcontexts import bypass_context, call_context, objmode_context
from numba.bytecode import FunctionIdentity, ByteCode
from numba.interpreter import Interpreter
from numba import typing, errors
from numba.targets.registry import cpu_target
from numba.targets import cpu
from numba.compiler import compile_ir, DEFAULT_FLAGS
from numba import njit, typeof, objmode
from .support import MemoryLeak, TestCase, captured_stdout


# SciPy is an optional dependency: tests that need it are decorated with
# ``skip_unless_scipy`` and are skipped when it cannot be imported.
try:
    import scipy
except ImportError:
    scipy = None

_msg = "SciPy needed for test"
skip_unless_scipy = unittest.skipUnless(scipy is not None, _msg)


def get_func_ir(func):
    """Interpret the bytecode of the Python function *func* into Numba IR.

    Returns the ``FunctionIR`` produced by ``Interpreter.interpret``; no
    typing or lowering is performed.
    """
    identity = FunctionIdentity.from_function(func)
    bytecode = ByteCode(func_id=identity)
    return Interpreter(identity).interpret(bytecode)


def lift1():
    """Fixture: one ``bypass_context`` block between two prints.

    The tests expect exactly one with-block to be found and only "A" and
    "C" to be printed, i.e. the bypassed body is never executed.  Note that
    ``b`` is not defined in this module, so the body could not run anyway.
    """
    print("A")
    with bypass_context:
        print("B")
        b()
    print("C")


def lift2():
    """Fixture: two sequential ``bypass_context`` blocks that touch ``x``.

    The tests expect two with-blocks and stdout "A 1" then "D 3": the
    updates to ``x`` inside the bypassed bodies are not observed, only the
    two ``x += 1`` outside them.
    """
    x = 1
    print("A", x)
    x = 1
    with bypass_context:
        print("B", x)
        x += 100
        b()
    x += 1
    with bypass_context:
        print("C", x)
        b()
        x += 10
    x += 1
    print("D", x)


def lift3():
    """Fixture: a ``bypass_context`` nested inside another one.

    The tests expect a single (outer) with-block to be reported and stdout
    "A 1 100" then "D 2 101", i.e. neither bypassed body affects ``x``/``y``.
    """
    x = 1
    y = 100
    print("A", x, y)
    with bypass_context:
        print("B")
        b()
        x += 100
        # Nested block: expected to be lifted together with the outer one.
        with bypass_context:
            print("C")
            y += 100000
            b()
    x += 1
    y += 1
    print("D", x, y)


def lift4():
    """Fixture: bypassed blocks containing a loop (with a nested with) and a branch.

    The tests expect two top-level with-blocks and stdout "A 0" then "E 11":
    only the ``x += 10`` and final ``x += 1`` outside the blocks take effect.
    """
    x = 0
    print("A", x)
    x += 10
    with bypass_context:
        print("B")
        b()
        x += 1
        for i in range(10):
            # Nested inside the first block, so not counted separately.
            with bypass_context:
                print("C")
                b()
                x += i
    with bypass_context:
        print("D")
        b()
        if x:
            x *= 10
    x += 1
    print("E", x)


def lift5():
    """Fixture: no ``with`` statement at all; zero blocks, prints only "A"."""
    print("A")


def liftcall1():
    """Fixture: one ``call_context`` block whose body IS executed.

    Prints "A 1" then "B 2" and returns 2.
    """
    x = 1
    print("A", x)
    with call_context:
        x += 1
    print("B", x)
    return x


def liftcall2():
    """Fixture: two sequential ``call_context`` blocks updating ``x``.

    Prints "A 1", "B 2", "C 12" and returns 12.
    """
    x = 1
    print("A", x)
    with call_context:
        x += 1
    print("B", x)
    with call_context:
        x += 10
    print("C", x)
    return x


def liftcall3():
    """Fixture: ``call_context`` bodies containing a branch and a loop.

    Prints "A 1", "B 2", then "C 47" (2 + sum(range(10))) and returns 47.
    """
    x = 1
    print("A", x)
    with call_context:
        if x > 0:
            x += 1
    print("B", x)
    with call_context:
        for i in range(10):
            x += i
    print("C", x)
    return x


def liftcall4():
    """Fixture: ``call_context`` nested inside another ``call_context``.

    Compiling this with ``njit`` is expected to fail with a TypingError
    mentioning "re-entrant".
    """
    with call_context:
        with call_context:
            pass


def lift_undefiend():
    """Fixture: ``with`` on a name that is not defined anywhere.

    (The misspelled function name is kept because tests reference it.)
    ``with_lifting`` is expected to raise a CompilerError containing
    "Undefined variable used as context manager".
    """
    with undefined_global_var:
        pass


# A plain ``object()`` (no ``__enter__``/``__exit__``) used as the context
# manager in lift_invalid below.
bogus_contextmanager = object()


def lift_invalid():
    """Fixture: ``with`` on an object that is not a supported context manager.

    ``with_lifting`` is expected to raise a CompilerError containing
    "Unsupported context manager in use".
    """
    with bogus_contextmanager:
        pass


class TestWithFinding(TestCase):
    """Check how many with-blocks ``find_setupwiths`` reports per fixture.

    Nested with-blocks are expected to be counted as part of their enclosing
    block (see lift3 and lift4).
    """

    def check_num_of_with(self, func, expect_count):
        """Assert ``find_setupwiths`` finds *expect_count* blocks in *func*."""
        blocks = get_func_ir(func).blocks
        found = find_setupwiths(blocks)
        self.assertEqual(len(found), expect_count)

    def test_lift1(self):
        self.check_num_of_with(lift1, 1)

    def test_lift2(self):
        self.check_num_of_with(lift2, 2)

    def test_lift3(self):
        self.check_num_of_with(lift3, 1)

    def test_lift4(self):
        self.check_num_of_with(lift4, 2)

    def test_lift5(self):
        self.check_num_of_with(lift5, 0)


class BaseTestWithLifting(TestCase):
    """Shared fixtures for the with-lifting tests.

    Each test gets a fresh typing context and CPU target context, plus
    helpers to run ``with_lifting`` on a function's IR and to compile IR.
    """

    def setUp(self):
        super(BaseTestWithLifting, self).setUp()
        typing_context = typing.Context()
        self.typingctx = typing_context
        self.targetctx = cpu.CPUContext(typing_context)
        self.flags = DEFAULT_FLAGS

    def check_extracted_with(self, func, expect_count, expected_stdout):
        """Lift the with-blocks of *func*, compile the rest and run it.

        Asserts that *expect_count* blocks were extracted and that calling
        the compiled entry point prints exactly *expected_stdout*.
        """
        source_ir = get_func_ir(func)
        lifted_ir, lifted_withs = with_lifting(
            source_ir, self.typingctx, self.targetctx, self.flags,
            locals={},
        )
        self.assertEqual(len(lifted_withs), expect_count)
        compiled = self.compile_ir(lifted_ir)

        with captured_stdout() as captured:
            compiled.entry_point()

        self.assertEqual(captured.getvalue(), expected_stdout)

    def compile_ir(self, the_ir, args=(), return_type=None):
        """Compile *the_ir* with this test's contexts and flags."""
        tyctx, tgtctx, compile_flags = (
            self.typingctx, self.targetctx, self.flags)
        # Register the contexts so nested @jit / @overload calls resolve
        # against them.
        with cpu_target.nested_context(tyctx, tgtctx):
            return compile_ir(tyctx, tgtctx, the_ir, args,
                              return_type, compile_flags, locals={})


class TestLiftByPass(BaseTestWithLifting):
    """``bypass_context`` bodies are lifted out and skipped at runtime.

    Each test checks the number of extracted with-blocks and that only the
    prints outside the bypassed regions reach stdout.
    """

    def test_lift1(self):
        self.check_extracted_with(lift1, 1, "A\nC\n")

    def test_lift2(self):
        self.check_extracted_with(lift2, 2, "A 1\nD 3\n")

    def test_lift3(self):
        self.check_extracted_with(lift3, 1, "A 1 100\nD 2 101\n")

    def test_lift4(self):
        self.check_extracted_with(lift4, 2, "A 0\nE 11\n")

    def test_lift5(self):
        self.check_extracted_with(lift5, 0, "A\n")


class TestLiftCall(BaseTestWithLifting):
    """``call_context`` bodies are lifted out but still executed."""

    def check_same_semantic(self, func):
        """Ensure ``njit(func)`` behaves like the plain Python *func*.

        Calls both versions with no arguments and compares their stdout
        and their return values.
        """
        jitted = njit(func)
        with captured_stdout() as got:
            got_result = jitted()

        with captured_stdout() as expect:
            expect_result = func()

        self.assertEqual(got.getvalue(), expect.getvalue())
        # Return values must agree too; comparing stdout alone would miss a
        # lifted body that prints correctly but returns a stale value.
        self.assertEqual(got_result, expect_result)

    def test_liftcall1(self):
        self.check_extracted_with(liftcall1, expect_count=1,
                                  expected_stdout="A 1\nB 2\n")
        self.check_same_semantic(liftcall1)

    def test_liftcall2(self):
        self.check_extracted_with(liftcall2, expect_count=2,
                                  expected_stdout="A 1\nB 2\nC 12\n")
        self.check_same_semantic(liftcall2)

    def test_liftcall3(self):
        self.check_extracted_with(liftcall3, expect_count=2,
                                  expected_stdout="A 1\nB 2\nC 47\n")
        self.check_same_semantic(liftcall3)

    def test_liftcall4(self):
        with self.assertRaises(errors.TypingError) as raises:
            njit(liftcall4)()
        # Known limitation: a lifted with-body may not itself contain
        # another lifted context manager (nested call_context).  Sequential
        # blocks in one function are fine (see liftcall2).
        self.assertIn("re-entrant", str(raises.exception))


def _expect_typing_error(fn, message):
    """Wrap test method *fn* so it passes only if it raises a TypingError.

    The error text must contain *message*.  ``functools.wraps`` keeps the
    original test's name and docstring for reporting.
    """
    import functools

    @functools.wraps(fn)
    def core(self, *args, **kwargs):
        with self.assertRaises(errors.TypingError) as raises:
            fn(self, *args, **kwargs)
        self.assertIn(message, str(raises.exception))
    return core


def expected_failure_for_list_arg(fn):
    """Mark a test as expected to fail with 'Does not support list type'."""
    return _expect_typing_error(fn, 'Does not support list type')


def expected_failure_for_function_arg(fn):
    """Mark a test as expected to fail with 'Does not support function type'."""
    return _expect_typing_error(fn, 'Does not support function type')


class TestLiftObj(MemoryLeak, TestCase):
    """Tests for lifting ``with objmode_context`` regions into object mode.

    Most tests compile a function containing an objmode block with ``njit``
    and compare its return value and stdout against plain Python execution;
    the rest check the error raised for unsupported patterns.
    ``NumbaWarning`` is escalated to an error for every test.
    """

    def setUp(self):
        super(TestLiftObj, self).setUp()
        # Snapshot the global warning filters and restore them on cleanup.
        # warnings.resetwarnings() (used previously) discards *every* filter,
        # including ones installed by the runner or -W options, so all later
        # tests ran with an empty filter list.
        warnings_ctx = warnings.catch_warnings()
        warnings_ctx.__enter__()
        # addCleanup also runs if the rest of setUp fails.
        self.addCleanup(warnings_ctx.__exit__, None, None, None)
        warnings.simplefilter("error", errors.NumbaWarning)

    def tearDown(self):
        # Warning filters are restored by the cleanup registered in setUp.
        super(TestLiftObj, self).tearDown()

    def assert_equal_return_and_stdout(self, pyfunc, *args):
        """Check ``njit(pyfunc)`` matches *pyfunc* on return value and stdout.

        *args* are deep-copied separately for each run so in-place mutation
        by one execution cannot influence the other.
        """
        py_args = copy.deepcopy(args)
        c_args = copy.deepcopy(args)
        cfunc = njit(pyfunc)

        with captured_stdout() as stream:
            expect_res = pyfunc(*py_args)
            expect_out = stream.getvalue()

        # avoid compiling during stdout-capturing for easier print-debugging
        cfunc.compile(tuple(map(typeof, c_args)))
        with captured_stdout() as stream:
            got_res = cfunc(*c_args)
            got_out = stream.getvalue()

        self.assertEqual(expect_out, got_out)
        self.assertPreciseEqual(expect_res, got_res)

    def test_lift_objmode_basic(self):
        def bar(ival):
            print("ival =", {'ival': ival // 2})

        def foo(ival):
            ival += 1
            with objmode_context:
                bar(ival)
            return ival + 1

        def foo_nonglobal(ival):
            ival += 1
            with numba.objmode:
                bar(ival)
            return ival + 1

        self.assert_equal_return_and_stdout(foo, 123)
        self.assert_equal_return_and_stdout(foo_nonglobal, 123)

    def test_lift_objmode_array_in(self):
        def bar(arr):
            print({'arr': arr // 2})
            # arr is modified. the effect is visible outside.
            arr *= 2

        def foo(nelem):
            arr = np.arange(nelem).astype(np.int64)
            with objmode_context:
                # arr is modified inplace inside bar()
                bar(arr)
            return arr + 1

        nelem = 10
        self.assert_equal_return_and_stdout(foo, nelem)

    def test_lift_objmode_define_new_unused(self):
        def bar(y):
            print(y)

        def foo(x):
            with objmode_context():
                y = 2 + x           # defined but unused outside
                a = np.arange(y)    # defined but unused outside
                bar(a)
            return x

        arg = 123
        self.assert_equal_return_and_stdout(foo, arg)

    def test_lift_objmode_return_simple(self):
        def inverse(x):
            print(x)
            return 1 / x

        def foo(x):
            with objmode_context(y="float64"):
                y = inverse(x)
            return x, y

        def foo_nonglobal(x):
            with numba.objmode(y="float64"):
                y = inverse(x)
            return x, y

        arg = 123
        self.assert_equal_return_and_stdout(foo, arg)
        self.assert_equal_return_and_stdout(foo_nonglobal, arg)

    def test_lift_objmode_return_array(self):
        def inverse(x):
            print(x)
            return 1 / x

        def foo(x):
            with objmode_context(y="float64[:]", z="int64"):
                y = inverse(x)
                z = int(y[0])
            return x, y, z

        arg = np.arange(1, 10, dtype=np.float64)
        self.assert_equal_return_and_stdout(foo, arg)

    @expected_failure_for_list_arg
    def test_lift_objmode_using_list(self):
        def foo(x):
            with objmode_context(y="float64[:]"):
                print(x)
                x[0] = 4
                print(x)
                y = [1, 2, 3] + x
                y = np.asarray([1 / i for i in y])
            return x, y

        arg = [1, 2, 3]
        self.assert_equal_return_and_stdout(foo, arg)

    def test_lift_objmode_var_redef(self):
        def foo(x):
            for x in range(x):
                pass
            if x:
                x += 1
            with objmode_context(x="intp"):
                print(x)
                x -= 1
                print(x)
                for i in range(x):
                    x += i
                    print(x)
            return x

        arg = 123
        self.assert_equal_return_and_stdout(foo, arg)

    @expected_failure_for_list_arg
    def test_case01_mutate_list_ahead_of_ctx(self):
        def foo(x, z):
            x[2] = z

            with objmode_context():
                # should print [1, 2, 15] but prints [1, 2, 3]
                print(x)

            with objmode_context():
                x[2] = 2 * z
                # should print [1, 2, 30] but prints [1, 2, 15]
                print(x)

            return x

        self.assert_equal_return_and_stdout(foo, [1, 2, 3], 15)

    def test_case02_mutate_array_ahead_of_ctx(self):
        def foo(x, z):
            x[2] = z

            with objmode_context():
                # should print [1, 2, 15]
                print(x)

            with objmode_context():
                x[2] = 2 * z
                # should print [1, 2, 30]
                print(x)

            return x

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x, 15)

    @expected_failure_for_list_arg
    def test_case03_create_and_mutate(self):
        def foo(x):
            with objmode_context(y='List(int64)'):
                y = [1, 2, 3]
            with objmode_context():
                y[2] = 10
            return y
        self.assert_equal_return_and_stdout(foo, 1)

    def test_case04_bogus_variable_type_info(self):

        def foo(x):
            # A type annotation for a variable that is not an output of the
            # block (``k`` is never assigned) is rejected as a TypingError.
            with objmode_context(k="float64[:]"):
                print(x)
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(errors.TypingError) as raises:
            cfoo(x)
        self.assertIn(
            "Invalid type annotation on non-outgoing variables",
            str(raises.exception),
            )

    def test_case05_bogus_type_info(self):
        def foo(x):
            # should specifying the wrong type info be considered valid?
            # z is complex.
            # Note: for now, we will coerce for scalar and raise for array
            with objmode_context(z="float64[:]"):
                z = x + 1.j
            return z

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(TypeError) as raises:
            cfoo(x)
        self.assertIn(
            ("can't unbox array from PyObject into native value."
             "  The object maybe of a different type"),
            str(raises.exception),
        )

    def test_case06_double_objmode(self):
        def foo(x):
            # would nested ctx in the same scope ever make sense? Is this
            # pattern useful?
            with objmode_context():
                # with npmmode_context(): not implemented yet
                with objmode_context():
                    print(x)
            return x

        with self.assertRaises(errors.TypingError) as raises:
            njit(foo)(123)
        # Check that an error occurred in with-lifting in objmode
        pat = (r"During: resolving callee type: "
               r"type\(ObjModeLiftedWith\(<.*>\)\)")
        # assertRegexpMatches is a deprecated alias (removed in Python 3.12);
        # prefer assertRegex where it exists, keeping Python 2 support.
        assert_regex = (getattr(self, 'assertRegex', None)
                        or self.assertRegexpMatches)
        assert_regex(str(raises.exception), pat)

    def test_case07_mystery_key_error(self):
        # ``t`` is created inside the objmode block and used afterwards
        # without a type annotation; this is reported as a TypingError
        # (historically it surfaced as a mysterious KeyError).
        def foo(x):
            with objmode_context():
                t = {'a': x}
            return x, t
        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(errors.TypingError) as raises:
            cfoo(x)
        self.assertIn(
            "missing type annotation on outgoing variables",
            str(raises.exception),
            )

    def test_case08_raise_from_external(self):
        # On the first iteration only d['0'] exists, so ``d['2']`` raises a
        # KeyError in object mode; it must propagate out of the jitted
        # function (historically this segfaulted).
        d = dict()

        def foo(x):
            for i in range(len(x)):
                with objmode_context():
                    k = str(i)
                    v = x[i]
                    d[k] = v
                    print(d['2'])
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(KeyError) as raises:
            cfoo(x)
        self.assertEqual(str(raises.exception), "'2'")

    def test_case09_explicit_raise(self):
        def foo(x):
            with objmode_context():
                raise ValueError()
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(errors.CompilerError) as raises:
            cfoo(x)
        self.assertIn(
            ('unsupported controlflow due to return/raise statements inside '
             'with block'),
            str(raises.exception),
        )

    @expected_failure_for_list_arg
    def test_case10_mutate_across_contexts(self):
        # This shouldn't work due to using List as input.
        def foo(x):
            with objmode_context(y='List(int64)'):
                y = [1, 2, 3]
            with objmode_context():
                y[2] = 10
            return y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case10_mutate_array_across_contexts(self):
        # Sub-case of case-10.
        def foo(x):
            with objmode_context(y='int64[:]'):
                y = np.asarray([1, 2, 3], dtype='int64')
            with objmode_context():
                # Note: `y` is not an output.
                y[2] = 10
            return y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case11_define_function_in_context(self):
        # should this work? no, `make_function` opcode not supported
        def foo(x):
            with objmode_context():
                def bar(y):
                    return y + 1
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(errors.TypingError) as raises:
            cfoo(x)
        self.assertIn(
            'op code: make_function',
            str(raises.exception),
        )

    def test_case12_njit_inside_a_objmode_ctx(self):
        # Create and call an njit function from inside an objmode block.
        # NOTE(review): an older comment said this worked locally but not
        # under this test harness; the test is not marked as an expected
        # failure, so presumably that is resolved -- confirm.
        def bar(y):
            return y + 1

        def foo(x):
            with objmode_context(y='int64[:]'):
                y = njit(bar)(x).astype('int64')
            return x + y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case13_branch_to_objmode_ctx(self):
        # Checks for warning in dataflow.py due to mishandled stack offset
        # dataflow.py:57: RuntimeWarning: inconsistent stack offset ...
        def foo(x, wobj):
            if wobj:
                with objmode_context(y='int64[:]'):
                    y = (x + 1).astype('int64')
            else:
                y = x + 2

            return x + y

        x = np.array([1, 2, 3], dtype='int64')

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", RuntimeWarning)
            self.assert_equal_return_and_stdout(foo, x, True)
        # Assert no warnings from dataflow.py
        for each in w:
            self.assertFalse(each.filename.endswith('dataflow.py'),
                             msg='there were warnings in dataflow.py')

    def test_case14_return_direct_from_objmode_ctx(self):
        # Returning from inside an objmode block is rejected with a
        # CompilerError (historically an AssertionError: "ending offset is
        # not a label").
        def foo(x):
            with objmode_context(x='int64[:]'):
                return x
        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(errors.CompilerError) as raises:
            cfoo(x)
        self.assertIn(
            ('unsupported controlflow due to return/raise statements inside '
             'with block'),
            str(raises.exception),
        )

    # No easy way to handle this yet.
    @unittest.expectedFailure
    def test_case15_close_over_objmode_ctx(self):
        # Fails with Unsupported constraint encountered: enter_with $phi8.1
        def foo(x):
            j = 10

            def bar(x):
                with objmode_context(x='int64[:]'):
                    print(x)
                    return x + j
            return bar(x) + 2
        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    @skip_unless_scipy
    def test_case16_scipy_call_in_objmode_ctx(self):
        from scipy import sparse as sp

        def foo(x):
            with objmode_context(k='int64'):
                print(x)
                spx = sp.csr_matrix(x)
                k = spx[0, 0]
            return k
        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case17_print_own_bytecode(self):
        import dis

        def foo(x):
            with objmode_context():
                dis.dis(foo)
        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    @expected_failure_for_function_arg
    def test_case18_njitfunc_passed_to_objmode_ctx(self):
        def foo(func, x):
            with objmode_context():
                func(x[0])

        x = np.array([1, 2, 3])
        fn = njit(lambda z: z + 5)
        self.assert_equal_return_and_stdout(foo, fn, x)

    def test_case19_recursion(self):
        def foo(x):
            with objmode_context():
                if x == 0:
                    return 7
            ret = foo(x - 1)
            return ret
        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(errors.CompilerError) as raises:
            cfoo(x)
        msg = "Does not support with-context that contain branches"
        self.assertIn(msg, str(raises.exception))

    @unittest.expectedFailure
    def test_case20_rng_works_ok(self):
        def foo(x):
            np.random.seed(0)
            y = np.random.rand()
            with objmode_context(z="float64"):
                # It's known that the random state does not sync
                z = np.random.rand()
            return x + z + y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case21_rng_seed_works_ok(self):
        def foo(x):
            np.random.seed(0)
            y = np.random.rand()
            with objmode_context(z="float64"):
                # Similar to test_case20_rng_works_ok but call seed
                np.random.seed(0)
                z = np.random.rand()
            return x + z + y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_example01(self):
        # Example from _ObjModeContextType.__doc__
        def bar(x):
            return np.asarray(list(reversed(x.tolist())))

        @njit
        def foo():
            x = np.arange(5)
            with objmode(y='intp[:]'):  # annotate return type
                # this region is executed by object-mode.
                y = x + bar(x)
            return y

        self.assertPreciseEqual(foo(), foo.py_func())
        self.assertIs(objmode, objmode_context)


class TestBogusContext(BaseTestWithLifting):
    """``with_lifting`` must reject context managers it cannot handle."""

    def _assert_lifting_rejected(self, func, expected_msg):
        """Run ``with_lifting`` on *func* and expect a CompilerError.

        The error text must contain *expected_msg*.
        """
        func_ir = get_func_ir(func)
        with self.assertRaises(errors.CompilerError) as ctx:
            with_lifting(func_ir, self.typingctx, self.targetctx, self.flags,
                         locals={})
        self.assertIn(expected_msg, str(ctx.exception))

    def test_undefined_global(self):
        self._assert_lifting_rejected(
            lift_undefiend, "Undefined variable used as context manager")

    def test_invalid(self):
        self._assert_lifting_rejected(
            lift_invalid, "Unsupported context manager in use")


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
