
Commit 045dc64

fjetter and crusaderky authored
Ensure large payload can be serialized and sent over comms (#8507)
Co-authored-by: crusaderky <[email protected]>
1 parent 72f297a · commit 045dc64

4 files changed: +66 -15 lines

distributed/comm/tcp.py

Lines changed: 13 additions & 12 deletions
@@ -44,14 +44,14 @@
 logger = logging.getLogger(__name__)
 
 
-# Workaround for OpenSSL 1.0.2.
-# Can drop with OpenSSL 1.1.1 used by Python 3.10+.
-# ref: https://bugs.python.org/issue42853
-if sys.version_info < (3, 10):
-    OPENSSL_MAX_CHUNKSIZE = 256 ** ctypes.sizeof(ctypes.c_int) // 2 - 1
-else:
-    OPENSSL_MAX_CHUNKSIZE = 256 ** ctypes.sizeof(ctypes.c_size_t) - 1
-
+# We must not load more than this into a buffer at a time
+# It's currently unclear why that is
+# see
+# - https://github.com/dask/distributed/pull/5854
+# - https://bugs.python.org/issue42853
+# - https://github.com/dask/distributed/pull/8507
+
+C_INT_MAX = 256 ** ctypes.sizeof(ctypes.c_int) // 2 - 1
 MAX_BUFFER_SIZE = MEMORY_LIMIT / 2
 

@@ -286,8 +286,8 @@ async def write(self, msg, serializers=None, on_error="message"):
                 2,
                 range(
                     0,
-                    each_frame_nbytes + OPENSSL_MAX_CHUNKSIZE,
-                    OPENSSL_MAX_CHUNKSIZE,
+                    each_frame_nbytes + C_INT_MAX,
+                    C_INT_MAX,
                 ),
             ):
                 chunk = each_frame[i:j]

@@ -360,7 +360,7 @@ async def read_bytes_rw(stream: IOStream, n: int) -> memoryview:
 
     for i, j in sliding_window(
         2,
-        range(0, n + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE),
+        range(0, n + C_INT_MAX, C_INT_MAX),
     ):
         chunk = buf[i:j]
         actual = await stream.read_into(chunk)  # type: ignore[arg-type]

@@ -432,7 +432,8 @@ class TLS(TCP):
     A TLS-specific version of TCP.
     """
 
-    max_shard_size = min(OPENSSL_MAX_CHUNKSIZE, TCP.max_shard_size)
+    # Workaround for OpenSSL 1.0.2 (can drop with OpenSSL 1.1.1)
+    max_shard_size = min(C_INT_MAX, TCP.max_shard_size)
 
     def _read_extra(self):
         TCP._read_extra(self)

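For orientation, here is a minimal sketch (not part of this commit) of the chunking pattern used in the tcp.py hunks above: sliding_window (from toolz/tlz, which distributed already uses) pairs consecutive offsets so that every slice passed to the stream is at most C_INT_MAX bytes. The iter_chunks helper name below is hypothetical and purely illustrative.

import ctypes

from toolz import sliding_window

# Largest positive C int (2**31 - 1 on typical platforms), mirroring the
# constant introduced in the diff above.
C_INT_MAX = 256 ** ctypes.sizeof(ctypes.c_int) // 2 - 1


def iter_chunks(frame, chunk_size=C_INT_MAX):
    """Yield slices of *frame*, each at most *chunk_size* bytes long."""
    view = memoryview(frame)
    # range(...) yields 0, chunk_size, 2 * chunk_size, ...; sliding_window(2, ...)
    # turns that into consecutive (start, stop) pairs covering the whole buffer.
    for start, stop in sliding_window(2, range(0, len(view) + chunk_size, chunk_size)):
        chunk = view[start:stop]
        if chunk:
            yield chunk

For example, a 5 GiB frame would be yielded as three pieces, each no larger than 2**31 - 1 bytes.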
distributed/protocol/tests/test_protocol.py

Lines changed: 27 additions & 0 deletions
@@ -208,3 +208,30 @@ def test_fallback_to_pickle():
     assert L[0].count(b"__Pickled__") == 1
     assert L[0].count(b"__Serialized__") == 1
     assert loads(L) == {np.int64(1): {2: "a"}, 3: ("b", "c"), 4: "d"}
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("typ", [bytes, str, "ext"])
+def test_large_payload(typ):
+    """See also: test_core.py::test_large_payload"""
+    critical_size = 2**31 + 1  # >2 GiB
+    if typ == bytes:
+        large_payload = critical_size * b"0"
+        expected = large_payload
+    elif typ == str:
+        large_payload = critical_size * "0"
+        expected = large_payload
+    # Testing array and map dtypes is practically not possible since we'd have
+    # to create an actual list or dict object of critical size (i.e. not the
+    # content but the container itself). These are so large that msgpack is
+    # running forever
+    # elif typ == "array":
+    #     large_payload = [b"0"] * critical_size
+    #     expected = tuple(large_payload)
+    # elif typ == "map":
+    #     large_payload = {x: b"0" for x in range(critical_size)}
+    #     expected = large_payload
+    elif typ == "ext":
+        large_payload = msgpack.ExtType(1, b"0" * critical_size)
+        expected = large_payload
+    assert loads(dumps(large_payload)) == expected

distributed/protocol/utils.py

Lines changed: 1 addition & 3 deletions
@@ -12,9 +12,7 @@
 BIG_BYTES_SHARD_SIZE = dask.utils.parse_bytes(dask.config.get("distributed.comm.shard"))
 
 
-msgpack_opts = {
-    ("max_%s_len" % x): 2**31 - 1 for x in ["str", "bin", "array", "map", "ext"]
-}
+msgpack_opts = {}
 msgpack_opts["strict_map_key"] = False
 msgpack_opts["raw"] = False
 

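As context for the utils.py change above: msgpack_opts is the set of keyword arguments distributed feeds to msgpack when decoding messages. Dropping the explicit 2**31 - 1 per-type length caps defers to msgpack's own defaults, which (assuming msgpack-python >= 1.0) scale with the size of the input buffer rather than stopping at 2 GiB; that reading of the library defaults is an assumption here, and the >2 GiB case itself is exercised by the new tests. A small, self-contained round-trip with the remaining options looks roughly like the following (tiny stand-in payload, illustrative only, not code from the commit).

import msgpack

# The two options kept in distributed/protocol/utils.py after this commit.
msgpack_opts = {
    "strict_map_key": False,  # allow non-str/bytes map keys such as ints
    "raw": False,             # decode msgpack strings to Python str, not bytes
}

payload = {1: "a", "x": b"0" * 10}
packed = msgpack.packb(payload, use_bin_type=True)
unpacked = msgpack.unpackb(packed, **msgpack_opts)
assert unpacked == payload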
distributed/tests/test_core.py

Lines changed: 25 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 import asyncio
 import contextlib
+import logging
 import os
 import random
 import socket

@@ -1481,3 +1482,27 @@ def sync_handler(val):
         assert ledger == list(range(n))
     finally:
         await comm.close()
+
+
+@pytest.mark.slow
+@gen_test(timeout=180)
+async def test_large_payload(caplog):
+    """See also: protocol/tests/test_protocol.py::test_large_payload"""
+    critical_size = 2**31 + 1  # >2 GiB
+    data = b"0" * critical_size
+
+    async with Server({"echo": echo_serialize}) as server:
+        await server.listen(0)
+        comm = await connect(server.address)
+
+        # FIXME https://github.com/dask/distributed/issues/8465
+        # At debug level, messages are dumped into the log. By default, pytest captures
+        # all logs, which would make this test extremely expensive to run.
+        with caplog.at_level(logging.INFO, logger="distributed.core"):
+            # Note: if we wrap data in to_serialize, it will be sent as a buffer, which
+            # is not encoded by msgpack.
+            await comm.write({"op": "echo", "x": data})
+            response = await comm.read()
+
+        assert response["result"] == data
+        await comm.close()
