From 2042349c3adf04c847d2a297d323964b15144ba7 Mon Sep 17 00:00:00 2001 From: Robert Date: Sun, 28 Aug 2022 11:23:04 +0800 Subject: [PATCH] Use validator from torch main --- torchdata/datapipes/iter/transform/callable.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 34c2a43c2..9bd00cbf6 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -8,7 +8,7 @@ from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union from torch.utils.data import functional_datapipe, IterDataPipe -from torch.utils.data.datapipes.utils.common import _check_unpickable_fn +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, validate_input_col T_co = TypeVar("T_co", covariant=True) @@ -113,6 +113,10 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]): Args: datapipe: Source IterDataPipe fn: the function to be applied to each element in the DataPipe, the output must be a Sequence + input_col: Index or indices of data which ``fn`` is applied, such as: + - ``None`` as default to apply ``fn`` to the data directly. + - Integer(s) is/are used for list/tuple. + - Key(s) is/are used for dict. Example: >>> from torchdata.datapipes.iter import IterableWrapper @@ -139,6 +143,7 @@ def __init__(self, datapipe: IterDataPipe, fn: Callable = None, input_col=None) _check_unpickable_fn(fn) self.fn = fn # type: ignore[assignment] self.input_col = input_col + validate_input_col(fn, input_col) def _apply_fn(self, data): if self.input_col is None: