Skip to content

Validation of fn on input_col for Datapipes #755

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion torchdata/datapipes/iter/transform/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union

from torch.utils.data import functional_datapipe, IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, validate_input_col

T_co = TypeVar("T_co", covariant=True)

Expand Down Expand Up @@ -113,6 +113,10 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
Args:
datapipe: Source IterDataPipe
fn: the function to be applied to each element in the DataPipe, the output must be a Sequence
input_col: Index or indices of data which ``fn`` is applied, such as:
- ``None`` as default to apply ``fn`` to the data directly.
- Integer(s) is/are used for list/tuple.
- Key(s) is/are used for dict.

Example:
>>> from torchdata.datapipes.iter import IterableWrapper
Expand All @@ -139,6 +143,7 @@ def __init__(self, datapipe: IterDataPipe, fn: Callable = None, input_col=None)
_check_unpickable_fn(fn)
self.fn = fn # type: ignore[assignment]
self.input_col = input_col
validate_input_col(fn, input_col)

def _apply_fn(self, data):
if self.input_col is None:
Expand Down