diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index 0202ca25f7ad8e..bb688ce6bf22c7 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -73,6 +73,10 @@ def __len__(self): return self.cumulative_sizes[-1] def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) if dataset_idx == 0: sample_idx = idx