Skip to content

Commit 9750de7

Browse files
committed
Use validator from torch main
1 parent 719616a commit 9750de7

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchdata/datapipes/iter/transform/callable.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union
99

1010
from torch.utils.data import functional_datapipe, IterDataPipe
11-
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
11+
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, validate_input_col
1212

1313
T_co = TypeVar("T_co", covariant=True)
1414

@@ -103,6 +103,10 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
103103
Args:
104104
datapipe: Source IterDataPipe
105105
fn: the function to be applied to each element in the DataPipe, the output must be a Sequence
106+
input_col: Index or indices of data which ``fn`` is applied, such as:
107+
- ``None`` as default to apply ``fn`` to the data directly.
108+
- Integer(s) is/are used for list/tuple.
109+
- Key(s) is/are used for dict.
106110
107111
Example:
108112
>>> from torchdata.datapipes.iter import IterableWrapper
@@ -122,6 +126,7 @@ def __init__(self, datapipe: IterDataPipe, fn: Callable, input_col=None) -> None
122126
_check_unpickable_fn(fn)
123127
self.fn = fn # type: ignore[assignment]
124128
self.input_col = input_col
129+
validate_input_col(fn, input_col)
125130

126131
def _apply_fn(self, data):
127132
if self.input_col is None:

0 commit comments

Comments
 (0)