from __future__ import print_function, division, absolute_import

import sys

import dask
import pytest

from distributed.protocol import loads, dumps, msgpack, maybe_compress, to_serialize
from distributed.protocol.compression import compressions
from distributed.protocol.serialize import Serialize, Serialized, serialize, deserialize
from distributed.utils import nbytes
from distributed.utils_test import slow


def test_protocol():
    """Round-trip a handful of small messages through ``dumps`` and ``loads``."""
    messages = [1, "a", b"a", {"x": 1}, {b"x": 1}, {"x": b""}, {}]
    for original in messages:
        frames = dumps(original)
        assert loads(frames) == original


def test_compression_1():
    """Frames for the bytes of an all-ones array are smaller than the array.

    The message must also load back to the original bytes.
    """
    pytest.importorskip("lz4")
    np = pytest.importorskip("numpy")
    x = np.ones(1000000)
    message = {"x": Serialize(x.tobytes())}
    frames = dumps(message)
    total_size = sum(nbytes(frame) for frame in frames)
    assert total_size < x.nbytes
    y = loads(frames)
    assert {"x": x.tobytes()} == y


def test_compression_2():
    """Serialized bytes of random floats should not be reported as compressed.

    ``dumps`` is expected to return exactly two frames here (header and
    payload).  When the header frame is non-empty it is decoded with msgpack
    and its ``"compression"`` entry, if any, must be falsy.
    """
    pytest.importorskip("lz4")
    np = pytest.importorskip("numpy")
    x = np.random.random(10000)
    header, payload = dumps(x.tobytes())
    # ``raw=False`` replaces the ``encoding="utf8"`` keyword, which msgpack
    # deprecated in 0.5.2 and removed in 1.0 (passing it raises TypeError).
    assert not header or not msgpack.loads(header, raw=False).get("compression")


def test_compression_without_deserialization():
    """Every frame stays below 1000000 in length, before and after ``loads``.

    ``loads`` is called with ``deserialize=False`` and the frames kept on the
    resulting ``msg["x"]`` object are checked as well.
    """
    pytest.importorskip("lz4")
    np = pytest.importorskip("numpy")
    x = np.ones(1000000)

    frames = dumps({"x": Serialize(x)})
    frame_lengths = [len(frame) for frame in frames]
    assert all(length < 1000000 for length in frame_lengths)

    msg = loads(frames, deserialize=False)
    kept_lengths = [len(frame) for frame in msg["x"].frames]
    assert all(length < 1000000 for length in kept_lengths)


def test_small():
    """Tiny messages must serialize to fewer than 10 bytes in total."""
    for tiny in (b"", 1):
        total_size = sum(nbytes(frame) for frame in dumps(tiny))
        assert total_size < 10


def test_small_and_big():
    """A message mixing a small tuple and a 10000000-byte string round-trips."""
    big_value = b"0" * 10000000
    message = {"x": (1, 2, 3), "y": big_value}
    frames = dumps(message)
    restored = loads(frames)
    assert restored == message


def test_maybe_compress():
    """Exercise ``maybe_compress`` with ``bytes`` and ``memoryview`` inputs.

    With compression configured as ``None`` the small payload must come back
    as ``(None, payload)``.  For each of ``"zlib"`` and ``"lz4"`` that can be
    imported, the small payload must still come back uncompressed, while a
    10000-byte repetitive payload must be compressed and must decompress to
    the original bytes.  Compression modules that fail to import are skipped.
    """
    try_converters = [bytes, memoryview]
    try_compressions = ["zlib", "lz4"]

    small_payload = b"123"
    large_payload = b"0" * 10000

    with dask.config.set({"distributed.comm.compression": None}):
        for f in try_converters:
            assert maybe_compress(f(small_payload)) == (None, small_payload)

    for compression in try_compressions:
        try:
            __import__(compression)
        except ImportError:
            # This compression library is not installed; nothing to check.
            continue

        with dask.config.set({"distributed.comm.compression": compression}):
            for f in try_converters:
                assert maybe_compress(f(small_payload)) == (None, small_payload)

                rc, rd = maybe_compress(f(large_payload))
                # For some reason compressing memoryviews can force blosc...
                assert rc in (compression, "blosc")
                assert compressions[rc]["decompress"](rd) == large_payload


def test_maybe_compress_sample():
    """Random ``u1`` bytes must be returned unchanged with no compression format.

    The payload is 10000 values drawn with ``np.random.randint(0, 255)``.
    The test is skipped when numpy or lz4 cannot be imported.
    """
    np = pytest.importorskip("numpy")
    # Only the availability of lz4 matters; the module object itself is unused.
    pytest.importorskip("lz4")
    payload = np.random.randint(0, 255, size=10000).astype("u1").tobytes()
    fmt, compressed = maybe_compress(payload)
    assert fmt is None
    assert compressed == payload


def test_large_bytes():
    """Large ``bytes``/``bytearray`` values round-trip with small leading frames.

    The first two frames must each be shorter than 1000, and loading with
    ``deserialize=False`` must still compare equal to the original message.
    """
    for tp in (bytes, bytearray):
        large_value = tp(b"0" * 1000000)
        msg = {"x": large_value, "y": 1}
        frames = dumps(msg)
        assert loads(frames) == msg
        for index in (0, 1):
            assert len(frames[index]) < 1000

        undecoded = loads(frames, deserialize=False)
        assert undecoded == msg


@slow
def test_large_messages():
    """Round-trip a message that references a 200000000-element ``u1`` array twice.

    The array is wrapped in ``Serialize`` inside both a list and a nested
    dict, next to small byte strings.  The test is skipped when numpy, psutil
    or lz4 cannot be imported, when total memory is below 8e9 bytes, or when
    running on Python 2.
    """
    np = pytest.importorskip("numpy")
    psutil = pytest.importorskip("psutil")
    pytest.importorskip("lz4")
    if psutil.virtual_memory().total < 8e9:
        # Skip explicitly instead of returning, so the run does not report a
        # pass for a test that never executed.
        pytest.skip("insufficient memory")

    if sys.version_info.major == 2:
        # A test function must not return a value; report a skip instead.
        pytest.skip("not run on Python 2")

    x = np.random.randint(0, 255, size=200000000, dtype="u1")

    msg = {
        "x": [Serialize(x), b"small_bytes"],
        "y": {"a": Serialize(x), "b": b"small_bytes"},
    }

    b = dumps(msg)
    msg2 = loads(b)
    assert msg["x"][1] == msg2["x"][1]
    assert msg["y"]["b"] == msg2["y"]["b"]
    assert (msg["x"][0].data == msg2["x"][0]).all()
    assert (msg["y"]["a"].data == msg2["y"]["a"]).all()


def test_large_messages_map():
    """Round-trip a dict of 100000 integer keys mapped to short strings.

    The test is skipped when psutil cannot be imported or when total memory
    is below 8e9 bytes.
    """
    # Skip rather than error when psutil is missing, as test_large_messages does.
    psutil = pytest.importorskip("psutil")

    if psutil.virtual_memory().total < 8e9:
        pytest.skip("insufficient memory")

    x = {i: "mystring_%d" % i for i in range(100000)}

    b = dumps(x)
    x2 = loads(b)
    assert x == x2


def test_loads_deserialize_False():
    """``loads(..., deserialize=False)`` leaves ``Serialize`` values as ``Serialized``.

    The ``Serialized`` object's header and frames must deserialize to the
    original value.
    """
    frames = dumps({"data": Serialize(123), "status": "OK"})

    decoded = loads(frames)
    assert decoded == {"data": 123, "status": "OK"}

    undecoded = loads(frames, deserialize=False)
    assert undecoded["status"] == "OK"
    wrapped = undecoded["data"]
    assert isinstance(wrapped, Serialized)

    result = deserialize(wrapped.header, wrapped.frames)
    assert result == 123


def test_loads_without_deserialization_avoids_compression():
    """Frames stay under 10000 bytes in total through a non-deserializing load.

    The message is dumped, loaded with ``deserialize=False``, dumped again and
    finally loaded normally; the final result must equal the original data.
    """
    pytest.importorskip("lz4")
    b = b"0" * 100000

    msg = {"x": 1, "data": to_serialize(b)}
    frames = dumps(msg)
    dumped_size = sum(nbytes(frame) for frame in frames)
    assert dumped_size < 10000

    msg2 = loads(frames, deserialize=False)
    kept_size = sum(nbytes(frame) for frame in msg2["data"].frames)
    assert kept_size < 10000

    redumped = dumps(msg2)
    final = loads(redumped)

    assert final == {"x": 1, "data": b"0" * 100000}


def eq_frames(a, b):
    """Return whether two frames are equal.

    When ``a`` contains ``b"headers"`` both frames are decoded with msgpack
    (``use_list=False``) and the decoded values are compared; otherwise the
    frames are compared directly.
    """
    if b"headers" not in a:
        return a == b
    decoded_a = msgpack.loads(a, use_list=False)
    decoded_b = msgpack.loads(b, use_list=False)
    return decoded_a == decoded_b


def test_dumps_loads_Serialize():
    """A ``Serialize``-wrapped value survives dump, load, re-dump and re-load.

    Loading with ``deserialize=False`` must yield a ``Serialized`` object that
    shares at least one frame object with the dumped frames, and dumping it
    again must produce frames equal (per ``eq_frames``) to the first dump.
    """
    frames = dumps({"x": 1, "data": Serialize(123)})
    assert len(frames) > 2

    result = loads(frames)
    assert result == {"x": 1, "data": 123}

    result2 = loads(frames, deserialize=False)
    assert result2["x"] == 1
    wrapped = result2["data"]
    assert isinstance(wrapped, Serialized)
    assert any(a is b for a in wrapped.frames for b in frames)

    frames2 = dumps(result2)
    pairwise_equal = map(eq_frames, frames, frames2)
    assert all(pairwise_equal)

    result3 = loads(frames2)
    assert result == result3


def test_dumps_loads_Serialized():
    """An already-``Serialized`` value survives dump, load, re-dump and re-load.

    Loading with ``deserialize=False`` must compare equal to the original
    message, and dumping that again must produce frames equal (per
    ``eq_frames``) to the first dump.
    """
    serialized_parts = serialize(123)
    msg = {"x": 1, "data": Serialized(*serialized_parts)}
    frames = dumps(msg)
    assert len(frames) > 2

    result = loads(frames)
    assert result == {"x": 1, "data": 123}

    result2 = loads(frames, deserialize=False)
    assert result2 == msg

    frames2 = dumps(result2)
    pairwise_equal = map(eq_frames, frames, frames2)
    assert all(pairwise_equal)

    result3 = loads(frames2)
    assert result == result3


@pytest.mark.skipif(sys.version_info[0] < 3, reason="NumPy doesnt use memoryviews")
def test_maybe_compress_memoryviews():
    """``maybe_compress`` on an int64 array's buffer compresses it.

    When ``blosc`` can be imported the expected format is ``"blosc"`` with a
    payload under a tenth of the array size; otherwise ``"lz4"`` with a
    payload under three quarters of the array size.
    """
    np = pytest.importorskip("numpy")
    pytest.importorskip("lz4")
    x = np.arange(1000000, dtype="int64")
    compression, payload = maybe_compress(x.data)
    try:
        import blosc  # noqa: F401
    except ImportError:
        expected_format, size_bound = "lz4", x.nbytes * 0.75
    else:
        expected_format, size_bound = "blosc", x.nbytes / 10
    assert compression == expected_format
    assert len(payload) < size_bound
