from __future__ import print_function, division, absolute_import

import subprocess
from time import sleep

import pytest
pytest.importorskip('mpi4py')

import requests

from distributed import Client
from distributed.utils import tmpfile
from distributed.metrics import time
from distributed.utils_test import popen
from distributed.utils_test import loop  # noqa: F401


@pytest.mark.parametrize('nanny', ['--nanny', '--no-nanny'])
def test_basic(loop, nanny):
    with tmpfile() as fn:
        with popen(['mpirun', '--np', '4', 'dask-mpi', '--scheduler-file', fn, nanny],
                   stdin=subprocess.DEVNULL):
            with Client(scheduler_file=fn) as c:

                start = time()
                while len(c.scheduler_info()['workers']) != 3:
                    assert time() < start + 10
                    sleep(0.2)

                assert c.submit(lambda x: x + 1, 10, workers=1).result() == 11


def test_no_scheduler(loop):
    with tmpfile() as fn:
        with popen(['mpirun', '--np', '2', 'dask-mpi', '--scheduler-file', fn],
                   stdin=subprocess.DEVNULL):
            with Client(scheduler_file=fn) as c:

                start = time()
                while len(c.scheduler_info()['workers']) != 1:
                    assert time() < start + 10
                    sleep(0.2)

                assert c.submit(lambda x: x + 1, 10).result() == 11
                with popen(['mpirun', '--np', '1', 'dask-mpi',
                            '--scheduler-file', fn, '--no-scheduler']):

                    start = time()
                    while len(c.scheduler_info()['workers']) != 2:
                        assert time() < start + 10
                        sleep(0.2)


def test_bokeh(loop):
    with tmpfile() as fn:
        with popen(['mpirun', '--np', '2', 'dask-mpi', '--scheduler-file', fn,
                    '--bokeh-port', '59583', '--bokeh-worker-port', '59584'],
                   stdin=subprocess.DEVNULL):

            for port in [59853, 59584]:
                start = time()
                while True:
                    try:
                        response = requests.get('http://localhost:%d/status/' % port)
                        assert response.ok
                        break
                    except Exception:
                        sleep(0.1)
                        assert time() < start + 20

    with pytest.raises(Exception):
        requests.get('http://localhost:59583/status/')
