Skip to content

Commit 16f963e

Browse files
ejguan authored and facebook-github-bot committed
Raise warning for unpickable local function (#80232)
Summary: X-link: pytorch/pytorch#80232 Pull Request resolved: #547 Fixes #538 - Improve the validation function to raise warning about unpickable function when either lambda or local function is provided to DataPipe. - The inner function from functools.partial object is extracted as well for validation - Mimic the behavior of pickle module for local lambda function: It would only raise Error for the local function rather than lambda function. So, we will raise warning about local function not lambda function. ```py >>> import pickle >>> def fn(): ... lf = lambda x: x ... pickle.dumps(lf) >>> pickle.dumps(fn) AttributeError: Can't pickle local object 'fn.<locals>.<lambda>' ``` This Diff also fixes the Error introduced by pytorch/pytorch#79344 Differential Revision: D37417556 fbshipit-source-id: c17e475bde703f2f55af140f2155d41efb45f049
1 parent 75f31dc commit 16f963e

File tree

6 files changed

+22
-22
lines changed

6 files changed

+22
-22
lines changed

test/test_serialization.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -374,15 +374,13 @@ def test_serializable_with_dill(self):
374374
else:
375375
dp_no_attribute_error = (iterdp.OnDiskCacheHolder,)
376376
try:
377-
with warnings.catch_warnings(record=True) as wa:
377+
with self.assertWarnsRegex(UserWarning, r"^Local function is not supported by pickle"):
378378
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
379-
self.assertEqual(len(wa), 1)
380-
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
381-
if isinstance(datapipe, dp_no_attribute_error):
379+
if isinstance(datapipe, dp_no_attribute_error):
380+
_ = pickle.dumps(datapipe)
381+
else:
382+
with self.assertRaises(AttributeError):
382383
_ = pickle.dumps(datapipe)
383-
else:
384-
with self.assertRaises(AttributeError):
385-
_ = pickle.dumps(datapipe)
386384
except Exception as e:
387385
print(f"{dpipe} is failing.")
388386
raise e

torchdata/datapipes/iter/transform/callable.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Callable, Iterator, List, TypeVar
88

99
from torch.utils.data import functional_datapipe, IterDataPipe
10-
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
10+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
1111

1212
T_co = TypeVar("T_co", covariant=True)
1313

@@ -59,7 +59,7 @@ def __init__(
5959
) -> None:
6060
self.datapipe = datapipe
6161

62-
_check_lambda_fn(fn)
62+
_check_unpickable_fn(fn)
6363
self.fn = fn # type: ignore[assignment]
6464

6565
assert batch_size > 0, "Batch size is required to be larger than 0!"
@@ -118,7 +118,7 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
118118
def __init__(self, datapipe: IterDataPipe, fn: Callable, input_col=None) -> None:
119119
self.datapipe = datapipe
120120

121-
_check_lambda_fn(fn)
121+
_check_unpickable_fn(fn)
122122
self.fn = fn # type: ignore[assignment]
123123
self.input_col = input_col
124124

torchdata/datapipes/iter/util/cacheholder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
raise
2828

2929

30-
from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE
30+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE
3131

3232
from torch.utils.data.graph import traverse
3333
from torchdata.datapipes import functional_datapipe
@@ -184,7 +184,8 @@ def __init__(
184184
):
185185
self.source_datapipe = source_datapipe
186186

187-
_check_lambda_fn(filepath_fn)
187+
if filepath_fn is not None:
188+
_check_unpickable_fn(filepath_fn)
188189
filepath_fn = _generator_to_list(filepath_fn) if inspect.isgeneratorfunction(filepath_fn) else filepath_fn
189190

190191
if hash_dict is not None and hash_type not in ("sha256", "md5"):

torchdata/datapipes/iter/util/combining.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Callable, Iterator, Optional, TypeVar
1010

1111
from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
12-
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
12+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
1313

1414
T_co = TypeVar("T_co", covariant=True)
1515

@@ -64,14 +64,14 @@ def __init__(
6464
raise TypeError(f"ref_datapipe must be a IterDataPipe, but its type is {type(ref_datapipe)} instead.")
6565
self.source_datapipe = source_datapipe
6666
self.ref_datapipe = ref_datapipe
67-
_check_lambda_fn(key_fn)
67+
_check_unpickable_fn(key_fn)
6868
self.key_fn = key_fn
6969
if ref_key_fn is not None:
70-
_check_lambda_fn(ref_key_fn)
70+
_check_unpickable_fn(ref_key_fn)
7171
self.ref_key_fn = key_fn if ref_key_fn is None else ref_key_fn
7272
self.keep_key = keep_key
7373
if merge_fn is not None:
74-
_check_lambda_fn(merge_fn)
74+
_check_unpickable_fn(merge_fn)
7575
self.merge_fn = merge_fn
7676
if buffer_size is not None and buffer_size <= 0:
7777
raise ValueError("'buffer_size' is required to be either None or a positive integer.")
@@ -185,10 +185,10 @@ def __init__(
185185
raise TypeError(f"map_datapipe must be a MapDataPipe, but its type is {type(map_datapipe)} instead.")
186186
self.source_iterdatapipe: IterDataPipe = source_iterdatapipe
187187
self.map_datapipe: MapDataPipe = map_datapipe
188-
_check_lambda_fn(key_fn)
188+
_check_unpickable_fn(key_fn)
189189
self.key_fn: Callable = key_fn
190190
if merge_fn is not None:
191-
_check_lambda_fn(merge_fn)
191+
_check_unpickable_fn(merge_fn)
192192
self.merge_fn: Optional[Callable] = merge_fn
193193
self.length: int = -1
194194

torchdata/datapipes/iter/util/converter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Callable, Dict, Optional
1010

1111
from torch.utils.data import IterDataPipe, MapDataPipe
12-
from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE
12+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE
1313

1414
if DILL_AVAILABLE:
1515
import dill
@@ -52,7 +52,8 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
5252
if not isinstance(datapipe, IterDataPipe):
5353
raise TypeError(f"IterToMapConverter can only apply on IterDataPipe, but found {type(datapipe)}")
5454
self.datapipe = datapipe
55-
_check_lambda_fn(key_value_fn)
55+
if key_value_fn is not None:
56+
_check_unpickable_fn(key_value_fn)
5657
self.key_value_fn = key_value_fn # type: ignore[assignment]
5758
self._map = None
5859
self._length = -1

torchdata/datapipes/iter/util/paragraphaggregator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import Callable, Iterator, List, Tuple, TypeVar
88

9-
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
9+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
1010

1111
from torchdata.datapipes import functional_datapipe
1212
from torchdata.datapipes.iter import IterDataPipe
@@ -44,7 +44,7 @@ class ParagraphAggregatorIterDataPipe(IterDataPipe[Tuple[str, str]]):
4444

4545
def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Callable = _default_line_join) -> None:
4646
self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe
47-
_check_lambda_fn(joiner)
47+
_check_unpickable_fn(joiner)
4848
self.joiner: Callable = joiner
4949
self.buffer: List = []
5050

0 commit comments

Comments (0)