Skip to content

Commit 86df1a0

Browse files
bushshrubfacebook-github-bot
authored andcommitted
Validation of fn on input_col for Datapipes (#755)
Summary: Fixes #362 ### Changes - Uses validator from torch main to apply checks - Apply check on `FlatMapperIterDataPipe` Pull Request resolved: #755 Reviewed By: ejguan Differential Revision: D39239201 Pulled By: NivekT fbshipit-source-id: 4a258da510c605abbda4116f0e6794db5432d1f9
1 parent 61bab05 commit 86df1a0

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchdata/datapipes/iter/transform/callable.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union
99

1010
from torch.utils.data import functional_datapipe, IterDataPipe
11-
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
11+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, validate_input_col
1212

1313
T_co = TypeVar("T_co", covariant=True)
1414

@@ -113,6 +113,10 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
113113
Args:
114114
datapipe: Source IterDataPipe
115115
fn: the function to be applied to each element in the DataPipe, the output must be a Sequence
116+
input_col: Index or indices of data which ``fn`` is applied, such as:
117+
- ``None`` as default to apply ``fn`` to the data directly.
118+
- Integer(s) is/are used for list/tuple.
119+
- Key(s) is/are used for dict.
116120
117121
Example:
118122
>>> from torchdata.datapipes.iter import IterableWrapper
@@ -139,6 +143,7 @@ def __init__(self, datapipe: IterDataPipe, fn: Callable = None, input_col=None)
139143
_check_unpickable_fn(fn)
140144
self.fn = fn # type: ignore[assignment]
141145
self.input_col = input_col
146+
validate_input_col(fn, input_col)
142147

143148
def _apply_fn(self, data):
144149
if self.input_col is None:

0 commit comments

Comments
 (0)