diff --git a/torchvision/datasets/hmdb51.py b/torchvision/datasets/hmdb51.py
index bd06ea6e5c9..3cb4547b82e 100644
--- a/torchvision/datasets/hmdb51.py
+++ b/torchvision/datasets/hmdb51.py
@@ -1,8 +1,11 @@
 import glob
 import os
+from pathlib import Path
+
+import torch
 
-from .utils import list_dir
 from .folder import make_dataset
+from .utils import list_dir
 from .video_utils import VideoClips
 from .vision import VisionDataset
 
@@ -99,7 +102,11 @@ def _select_fold(self, video_list, annotation_path, fold, train):
                 data = [x[0] for x in data if int(x[1]) == target_tag]
                 selected_files.extend(data)
         selected_files = set(selected_files)
-        indices = [i for i in range(len(video_list)) if os.path.basename(video_list[i]) in selected_files]
+        indices = []
+        for i in range(len(video_list)):
+            path = Path(video_list[i])
+            if str(path.relative_to(path.parent.parent)) in selected_files:
+                indices.append(i)
         return indices
 
     def __len__(self):
@@ -110,6 +117,10 @@ def __getitem__(self, idx):
         label = self.samples[self.indices[video_idx]][1]
 
         if self.transform is not None:
-            video = self.transform(video)
+            transformed_video = []
+            for counter, image in enumerate(video):
+                image = self.transform(image)
+                transformed_video.append(image)
+            video = torch.stack(transformed_video)
 
         return video, audio, label
diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py
index 07db91cc195..8066edeca8c 100644
--- a/torchvision/datasets/kinetics.py
+++ b/torchvision/datasets/kinetics.py
@@ -1,5 +1,7 @@
+import torch
+
-from .utils import list_dir
 from .folder import make_dataset
+from .utils import list_dir
 from .video_utils import VideoClips
 from .vision import VisionDataset
 
@@ -74,6 +76,10 @@ def __getitem__(self, idx):
         label = self.samples[video_idx][1]
 
         if self.transform is not None:
-            video = self.transform(video)
+            transformed_video = []
+            for counter, image in enumerate(video):
+                image = self.transform(image)
+                transformed_video.append(image)
+            video = torch.stack(transformed_video)
 
         return video, audio, label
diff --git a/torchvision/datasets/ucf101.py b/torchvision/datasets/ucf101.py
index 43d8124bd4b..e8b4ecd6e8e 100644
--- a/torchvision/datasets/ucf101.py
+++ b/torchvision/datasets/ucf101.py
@@ -1,8 +1,11 @@
 import glob
 import os
+from pathlib import Path
+
+import torch
 
-from .utils import list_dir
 from .folder import make_dataset
+from .utils import list_dir
 from .video_utils import VideoClips
 from .vision import VisionDataset
 
@@ -91,9 +94,12 @@ def _select_fold(self, video_list, annotation_path, fold, train):
            data = [x[0] for x in data]
            selected_files.extend(data)
        selected_files = set(selected_files)
-        indices = [i for i in range(len(video_list)) if video_list[i][len(self.root) + 1:] in selected_files]
+        indices = []
+        for i in range(len(video_list)):
+            path = Path(video_list[i])
+            if str(path.relative_to(path.parent.parent)) in selected_files:
+                indices.append(i)
         return indices
-
     def __len__(self):
         return self.video_clips.num_clips()
 
@@ -102,6 +108,10 @@ def __getitem__(self, idx):
         label = self.samples[self.indices[video_idx]][1]
 
         if self.transform is not None:
-            video = self.transform(video)
+            transformed_video = []
+            for counter, image in enumerate(video):
+                image = self.transform(image)
+                transformed_video.append(image)
+            video = torch.stack(transformed_video)
 
         return video, audio, label