Commit de3f755

Avoid schema enforcement from meta on Arrow data in P2P shuffling (#8235)
1 parent 310d2f0 commit de3f755
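
In short: read_from_disk no longer rebuilds each shard against a schema derived from the DataFrame meta (pyarrow_schema_dispatch(meta, preserve_index=True)); it keeps the schema of the Arrow tables read back from disk. A meta-derived schema can disagree with the data actually written, notably for object-dtype columns holding numpy arrays (the case the new test covers). A minimal illustration of such a mismatch, using plain pandas/pyarrow conversions as a stand-in for the dask dispatch rather than the code path touched here:

import numpy as np
import pandas as pd
import pyarrow as pa

# meta: empty frame with the same dtypes as the real data (object column "b") ...
meta = pd.DataFrame({"b": pd.Series([], dtype=object)})
# ... while a real partition holds numpy arrays inside that object column.
df = pd.DataFrame({"b": [np.asarray([1, 2, 3]), np.asarray([4, 5, 6])]})

schema_from_meta = pa.Schema.from_pandas(meta, preserve_index=True)
schema_from_data = pa.Table.from_pandas(df, preserve_index=True).schema

# The empty object column and the populated one are inferred as different
# Arrow types, so forcing the meta-derived schema onto the real shards can
# fail; keeping each table's own schema sidesteps this.
print(schema_from_meta.field("b").type)  # e.g. null
print(schema_from_data.field("b").type)  # e.g. list<item: int64>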

3 files changed: +41 -16 lines changed

distributed/shuffle/_arrow.py

Lines changed: 6 additions & 9 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
 from packaging.version import parse
 
@@ -80,15 +80,12 @@ def deserialize_table(buffer: bytes) -> pa.Table:
     return reader.read_all()
 
 
-def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]:
+def read_from_disk(path: Path) -> tuple[list[pa.Table], int]:
     import pyarrow as pa
 
-    from dask.dataframe.dispatch import pyarrow_schema_dispatch
-
     batch_size = parse_bytes("1 MiB")
     batch = []
     shards = []
-    schema = pyarrow_schema_dispatch(meta, preserve_index=True)
 
     with pa.OSFile(str(path), mode="rb") as f:
         size = f.seek(0, whence=2)
@@ -103,17 +100,17 @@ def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]:
 
             if offset - prev >= batch_size:
                 table = pa.concat_tables(batch)
-                shards.append(_copy_table(table, schema))
+                shards.append(_copy_table(table))
                 batch = []
                 prev = offset
         if batch:
             table = pa.concat_tables(batch)
-            shards.append(_copy_table(table, schema))
+            shards.append(_copy_table(table))
     return shards, size
 
 
-def _copy_table(table: pa.Table, schema: pa.Schema) -> pa.Table:
+def _copy_table(table: pa.Table) -> pa.Table:
     import pyarrow as pa
 
     arrs = [pa.concat_arrays(column.chunks) for column in table.columns]
-    return pa.table(data=arrs, schema=schema)
+    return pa.table(data=arrs, schema=table.schema)
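
For reference, the new _copy_table defragments each column and rebuilds the table with its own schema rather than one passed in from meta. A small standalone sketch of that behavior (pyarrow only; the example table is illustrative):

import pyarrow as pa

def _copy_table(table: pa.Table) -> pa.Table:
    # Collapse each (possibly chunked) column into one contiguous array and
    # rebuild the table, reusing the table's own schema.
    arrs = [pa.concat_arrays(column.chunks) for column in table.columns]
    return pa.table(data=arrs, schema=table.schema)

t = pa.concat_tables(
    [pa.table({"a": [1, 2]}), pa.table({"a": [3]})]
)  # column "a" now has two chunks
copied = _copy_table(t)
assert copied.schema == t.schema           # schema preserved as-is
assert copied.column("a").num_chunks == 1  # chunks were consolidated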

distributed/shuffle/_shuffle.py

Lines changed: 2 additions & 2 deletions
@@ -499,8 +499,8 @@ def _(partition_id: int, meta: pd.DataFrame) -> pd.DataFrame:
     def _get_assigned_worker(self, id: int) -> str:
         return self.worker_for[id]
 
-    def read(self, path: Path) -> tuple[Any, int]:
-        return read_from_disk(path, self.meta)
+    def read(self, path: Path) -> tuple[pa.Table, int]:
+        return read_from_disk(path)
 
 
 @dataclass(frozen=True)

distributed/shuffle/tests/test_shuffle.py

Lines changed: 33 additions & 5 deletions
@@ -1129,7 +1129,7 @@ def __init__(self, value: int) -> None:
 
     out = {}
     for k in range(npartitions):
-        shards, _ = read_from_disk(tmp_path / str(k), meta)
+        shards, _ = read_from_disk(tmp_path / str(k))
         out[k] = convert_shards(shards, meta)
 
     shuffled_df = pd.concat(df for df in out.values())
@@ -2100,7 +2100,7 @@ async def test_replace_stale_shuffle(c, s, a, b):
 
 
 @gen_cluster(client=True)
-async def test_handle_null_partitions_p2p_shuffling(c, s, *workers):
+async def test_handle_null_partitions_p2p_shuffling(c, s, a, b):
     data = [
         {"companies": [], "id": "a", "x": None},
         {"companies": [{"id": 3}, {"id": 5}], "id": "b", "x": None},
@@ -2113,8 +2113,8 @@ async def test_handle_null_partitions_p2p_shuffling(c, s, *workers):
     result = await c.compute(ddf)
     dd.assert_eq(result, df)
 
-    await c.close()
-    await asyncio.gather(*[check_worker_cleanup(w) for w in workers])
+    await check_worker_cleanup(a)
+    await check_worker_cleanup(b)
     await check_scheduler_cleanup(s)
 
 
@@ -2133,7 +2133,35 @@ def make_partition(i):
     result = await result
     expected = await expected
     dd.assert_eq(result, expected)
-    del result
+
+    await check_worker_cleanup(a)
+    await check_worker_cleanup(b)
+    await check_scheduler_cleanup(s)
+
+
+@gen_cluster(client=True)
+async def test_handle_object_columns_p2p(c, s, a, b):
+    with dask.config.set({"dataframe.convert-string": False}):
+        df = pd.DataFrame(
+            {
+                "a": [1, 2, 3],
+                "b": [
+                    np.asarray([1, 2, 3]),
+                    np.asarray([4, 5, 6]),
+                    np.asarray([7, 8, 9]),
+                ],
+                "c": ["foo", "bar", "baz"],
+            }
+        )
+
+        ddf = dd.from_pandas(
+            df,
+            npartitions=2,
+        )
+        shuffled = ddf.shuffle(on="a")
+
+        result = await c.compute(shuffled)
+        dd.assert_eq(result, df)
 
     await check_worker_cleanup(a)
     await check_worker_cleanup(b)