This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Support for IterableDatasets #422

Closed
@miguelvr

Description

🚀 Feature

PyTorch introduced the IterableDataset in v1.2 to allow users to process streams of information. ClassyVision currently only supports map-style datasets; it would be nice to extend support to IterableDatasets, given that they are especially useful for processing video streams.

Motivation / Pitch

Map-style datasets assume each sample can be read completely independently of the others. In some situations, such as processing video streams, it is extremely expensive to open and close a video stream N times to read N frames. IterableDatasets allow a stream to be opened once, with results yielded as the data loader requests them, which is substantially more efficient.
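
To make that access pattern concrete, this is roughly the kind of generator an IterableDataset enables. It is illustrative only and assumes OpenCV's cv2.VideoCapture for decoding, which is not part of ClassyVision:

import cv2  # assumption: OpenCV is used here purely for illustration


def stream_frames(video_paths):
    """Open each video once and yield its frames one at a time.

    With a map-style dataset, reading N frames would mean opening (and seeking
    into) the file N times; here each stream is opened a single time and frames
    are yielded as the data loader asks for them.
    """
    for path in video_paths:
        capture = cv2.VideoCapture(path)
        try:
            while True:
                ok, frame = capture.read()
                if not ok:  # end of stream
                    break
                yield {"input": frame, "video_path": path}
        finally:
            capture.release()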

Additional context

There are a few important differences between map-style datasets and iterable datasets that break the current ClassyVision dataset paradigm:

  1. In an iterable dataset there is no __getitem__ method; it is replaced by the __iter__ method
  2. The __len__ method is optional in an iterable dataset
  3. Samplers do not work with IterableDatasets; sampling and shuffling have to be handled inside the dataset itself (see the minimal sketch after this list)
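
To make the contrast explicit, here is a minimal (purely illustrative) pair of datasets showing the two interfaces side by side:

from torch.utils.data import Dataset, IterableDataset


class MapStyleExample(Dataset):
    """Map-style: random access via __getitem__ and an explicit __len__."""

    def __init__(self, samples):
        self.samples = samples

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


class IterableExample(IterableDataset):
    """Iterable-style: only __iter__ is required; __len__ is optional, and
    shuffling/sharding must happen inside the dataset itself."""

    def __init__(self, samples):
        self.samples = samples

    def __iter__(self):
        return iter(self.samples)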

I've come up with a template dataset called ChunkDataset that hides some of this complexity away, which might help beginner users get started. Nevertheless, in order to get this working with my code, I had to subclass ClassificationTask to modify it and had to create an entirely new base class for this style of dataset (ClassyDataset is not compatible).

chunk.py

from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, IterableDataset

# assumed ClassyVision import locations for the helpers used below
from classy_vision.dataset.transforms import ClassyTransform
from classy_vision.generic.distributed_util import get_world_size
from classy_vision.generic.util import is_pos_int


class ChunkDataset(IterableDataset):
    def __init__(self, indices: List, process_fn: Callable):
        """
        Subclass of IterableDataset that serves as a base class for processing streams
        of data, such as audio, video or text.

        Args:
            indices (list): list of arguments to provide to process_fn
            process_fn (callable): function that processes the indices and returns an iterator
        """
        self.idxs = indices
        self._process_fn = process_fn
        self.epoch = 0
        self.shuffle = False

        # replacement for distributed sampler
        distributed = dist.is_available() and dist.is_initialized()
        if distributed:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            self.idxs = self.idxs[rank::num_replicas]

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.idxs), generator=g).tolist()
            idxs = [self.idxs[i] for i in indices]
        else:
            idxs = self.idxs

        return self._process_fn(idxs)

    @staticmethod
    def worker_init_fn(worker_id):
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset  # the dataset copy in this worker process
        n_workers = worker_info.num_workers
        dataset.idxs = dataset.idxs[worker_id::n_workers]

    def set_epoch(self, epoch):
        self.epoch = epoch
        return self

    def set_shuffle(self, shuffle: bool = True):
        self.shuffle = shuffle
        return self

    def __len__(self):
        raise NotImplementedError


class ClassyChunkDataset(IterableDataset):
    """
    Class representing a dataset abstraction to wrap a ChunkDataset.

    This class wraps a :class:`ChunkDataset` via the `dataset` attribute
    and configures the dataloaders needed to access the datasets.
    Transforms which need to be applied to the data should be specified in this class.
    ClassyChunkDataset can be instantiated from a configuration file as well.
    """

    def __init__(
        self,
        dataset: ChunkDataset,
        batchsize_per_replica: int,
        shuffle: bool,
        transform: Optional[Union[ClassyTransform, Callable]],
    ) -> None:
        """
        Constructor for a ClassyDataset.

        Args:
            batchsize_per_replica: Positive integer indicating batch size for each
                replica
            shuffle: Whether to shuffle between epochs
            transform: When set, transform to be applied to each sample
            num_samples: When set, this restricts the number of samples provided by
                the dataset
        """
        # Asserts:
        assert is_pos_int(
            batchsize_per_replica
        ), "batchsize_per_replica must be a positive int"
        assert isinstance(shuffle, bool), "shuffle must be a boolean"

        # Assignments:
        self.batchsize_per_replica = batchsize_per_replica
        self.shuffle = shuffle
        self.transform = transform
        self.dataset = dataset

        if self.shuffle:
            self.dataset = self.dataset.set_shuffle()

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ClassyChunkDataset":
        """Instantiates a ClassyChunkDataset from a configuration.

        Args:
            config: A configuration for the ClassyChunkDataset.

        Returns:
            A ClassyChunkDataset instance.
        """
        raise NotImplementedError

    @classmethod
    def parse_config(cls, config: Dict[str, Any]):
        """
        This function parses out common config options.

        Args:
            config: A dict with the following string keys -

                | *batchsize_per_replica* (int): Must be a positive int, batch size
                |    for each replica
                | *use_shuffle* (bool): Whether to enable shuffling for the dataset
                | *num_samples* (int, optional): When set, restricts the number of
                |    samples in a dataset
                | *transforms*: list of transform configurations to be applied in order

        Returns:
            A tuple containing the following variables -
                | *transform_config*: Config for the dataset transform. Can be passed to
                |    :func:`transforms.build_transform`
                | *batchsize_per_replica*: Batch size per replica
                | *shuffle*: Whether we should shuffle between epochs
                | *num_samples*: When set, restricts the number of samples in a dataset
        """
        batchsize_per_replica = config.get("batchsize_per_replica")
        shuffle = config.get("use_shuffle")
        num_samples = config.get("num_samples")
        transform_config = config.get("transforms")
        return transform_config, batchsize_per_replica, shuffle, num_samples

    def __iter__(self):
        for sample in self.dataset:
            if self.transform is not None:
                sample = self.transform(sample)
            yield sample

    def __len__(self):
        return len(self.dataset)

    def iterator(self, *args, **kwargs):
        """
        Returns an iterable which can be used to iterate over the data.

        Args:
            shuffle_seed (int, optional): Seed for the shuffle
            current_phase_id (int, optional): The epoch being fetched. Needed so that
                each epoch has a different shuffle order
        Returns:
            An iterable over the data
        """
        # TODO: Fix naming to be consistent (i.e. everyone uses epoch)
        epoch = kwargs.get("current_phase_id", 0)
        assert isinstance(epoch, int), "Epoch must be an int"

        self.dataset = self.dataset.set_epoch(epoch)

        return DataLoader(
            self,
            batch_size=self.batchsize_per_replica,
            num_workers=kwargs.get("num_workers", 0),
            pin_memory=kwargs.get("pin_memory", False),
            multiprocessing_context=kwargs.get("multiprocessing_context", None),
            worker_init_fn=self.worker_init_fn
        )

    def get_batchsize_per_replica(self):
        """
        Get the batch size per replica.

        Returns:
            The batch size for each replica.
        """
        return self.batchsize_per_replica

    def get_global_batchsize(self):
        """
        Get the global batch size, combined over all the replicas.

        Returns:
            The overall batch size of the dataset.
        """
        return self.get_batchsize_per_replica() * get_world_size()

    @staticmethod
    def worker_init_fn(worker_id):
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset.dataset  # unwrap to the inner ChunkDataset copy in this worker
        n_workers = worker_info.num_workers
        dataset.idxs = dataset.idxs[worker_id::n_workers]
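
For reference, this is roughly how the two classes are meant to be wired together. The video paths and the stream_frames process_fn are placeholders (see the earlier sketch), not part of ClassyVision:

# Illustrative usage only; stream_frames is a placeholder process_fn that
# opens each video once and yields sample dicts.
video_paths = ["/data/videos/a.mp4", "/data/videos/b.mp4"]

chunk_dataset = ChunkDataset(indices=video_paths, process_fn=stream_frames)

dataset = ClassyChunkDataset(
    dataset=chunk_dataset,
    batchsize_per_replica=8,
    shuffle=True,
    transform=None,  # or any ClassyTransform / callable applied per sample
)

# iterator() sets the epoch on the wrapped ChunkDataset and returns a
# DataLoader that shards work across workers via worker_init_fn.
loader = dataset.iterator(current_phase_id=0, num_workers=2)
for batch in loader:
    ...  # feed the batch to the task / model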
