diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py
index f0f19e332ed..efa3836c8d1 100644
--- a/torchvision/datasets/video_utils.py
+++ b/torchvision/datasets/video_utils.py
@@ -2,7 +2,7 @@
 import math
 import warnings
 from fractions import Fraction
-from typing import List
+from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast
 
 import torch
 from torchvision.io import (
@@ -14,8 +14,10 @@
 
 from .utils import tqdm
 
+T = TypeVar("T")
 
-def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
+
+def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
     """convert pts between different time bases
     Args:
         pts: presentation timestamp, float
@@ -27,7 +29,7 @@ def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
     return round_func(new_pts)
 
 
-def unfold(tensor, size, step, dilation=1):
+def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
     """
     similar to tensor.unfold, but with the dilation
     and specialized for 1d tensors
@@ -55,17 +57,17 @@ class _VideoTimestampsDataset:
     pickled when forking.
     """
 
-    def __init__(self, video_paths: List[str]):
+    def __init__(self, video_paths: List[str]) -> None:
         self.video_paths = video_paths
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.video_paths)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
         return read_video_timestamps(self.video_paths[idx])
 
 
-def _collate_fn(x):
+def _collate_fn(x: T) -> T:
     """
     Dummy collate function to be used with _VideoTimestampsDataset
     """
@@ -100,19 +102,19 @@ class VideoClips:
 
     def __init__(
         self,
-        video_paths,
-        clip_length_in_frames=16,
-        frames_between_clips=1,
-        frame_rate=None,
-        _precomputed_metadata=None,
-        num_workers=0,
-        _video_width=0,
-        _video_height=0,
-        _video_min_dimension=0,
-        _video_max_dimension=0,
-        _audio_samples=0,
-        _audio_channels=0,
-    ):
+        video_paths: List[str],
+        clip_length_in_frames: int = 16,
+        frames_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        num_workers: int = 0,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _video_max_dimension: int = 0,
+        _audio_samples: int = 0,
+        _audio_channels: int = 0,
+    ) -> None:
         self.video_paths = video_paths
         self.num_workers = num_workers
 
@@ -131,7 +133,7 @@ def __init__(
             self._init_from_metadata(_precomputed_metadata)
         self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
 
-    def _compute_frame_pts(self):
+    def _compute_frame_pts(self) -> None:
         self.video_pts = []
         self.video_fps = []
 
@@ -139,8 +141,8 @@ def _compute_frame_pts(self):
         # so need to create a dummy dataset first
         import torch.utils.data
 
-        dl = torch.utils.data.DataLoader(
-            _VideoTimestampsDataset(self.video_paths),
+        dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
+            _VideoTimestampsDataset(self.video_paths),  # type: ignore[arg-type]
             batch_size=16,
             num_workers=self.num_workers,
             collate_fn=_collate_fn,
@@ -157,7 +159,7 @@ def _compute_frame_pts(self):
                 self.video_pts.extend(clips)
                 self.video_fps.extend(fps)
 
-    def _init_from_metadata(self, metadata):
+    def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
         self.video_paths = metadata["video_paths"]
         assert len(self.video_paths) == len(metadata["video_pts"])
         self.video_pts = metadata["video_pts"]
@@ -165,7 +167,7 @@ def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
         self.video_fps = metadata["video_fps"]
 
     @property
-    def metadata(self):
+    def metadata(self) -> Dict[str, Any]:
         _metadata = {
             "video_paths": self.video_paths,
             "video_pts": self.video_pts,
@@ -173,7 +175,7 @@ def metadata(self):
         }
         return _metadata
 
-    def subset(self, indices):
+    def subset(self, indices: List[int]) -> "VideoClips":
         video_paths = [self.video_paths[i] for i in indices]
         video_pts = [self.video_pts[i] for i in indices]
         video_fps = [self.video_fps[i] for i in indices]
@@ -198,7 +200,9 @@ def subset(self, indices):
         )
 
     @staticmethod
-    def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
+    def compute_clips_for_video(
+        video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
+    ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
         if fps is None:
             # if for some reason the video doesn't have fps (because doesn't have a video stream)
             # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
@@ -206,21 +210,22 @@ def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
         if frame_rate is None:
             frame_rate = fps
         total_frames = len(video_pts) * (float(frame_rate) / fps)
-        idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
-        video_pts = video_pts[idxs]
+        _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
+        video_pts = video_pts[_idxs]
         clips = unfold(video_pts, num_frames, step)
         if not clips.numel():
             warnings.warn(
                 "There aren't enough frames in the current video to get a clip for the given clip length and "
                 "frames between clips. The video (and potentially others) will be skipped."
             )
-        if isinstance(idxs, slice):
-            idxs = [idxs] * len(clips)
+        idxs: Union[List[slice], torch.Tensor]
+        if isinstance(_idxs, slice):
+            idxs = [_idxs] * len(clips)
         else:
-            idxs = unfold(idxs, num_frames, step)
+            idxs = unfold(_idxs, num_frames, step)
         return clips, idxs
 
-    def compute_clips(self, num_frames, step, frame_rate=None):
+    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
         """
         Compute all consecutive sequences of clips from video_pts.
         Always returns clips of size `num_frames`, meaning that the
@@ -243,19 +248,19 @@ def compute_clips(self, num_frames, step, frame_rate=None):
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_clips()
 
-    def num_videos(self):
+    def num_videos(self) -> int:
         return len(self.video_paths)
 
-    def num_clips(self):
+    def num_clips(self) -> int:
         """
         Number of subclips that are available in the video list.
         """
         return self.cumulative_sizes[-1]
 
-    def get_clip_location(self, idx):
+    def get_clip_location(self, idx: int) -> Tuple[int, int]:
         """
         Converts a flattened representation of the indices
         into a video_idx, clip_idx representation.
@@ -268,7 +273,7 @@ def get_clip_location(self, idx):
         return video_idx, clip_idx
 
     @staticmethod
-    def _resample_video_idx(num_frames, original_fps, new_fps):
+    def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
         step = float(original_fps) / new_fps
         if step.is_integer():
             # optimization: if step is integer, don't need to perform
@@ -279,7 +284,7 @@ def _resample_video_idx(num_frames, original_fps, new_fps):
         idxs = idxs.floor().to(torch.int64)
         return idxs
 
-    def get_clip(self, idx):
+    def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
         """
         Gets a subclip from a list of videos.
 
@@ -320,22 +325,22 @@ def get_clip(self, idx):
             end_pts = clip_pts[-1].item()
             video, audio, info = read_video(video_path, start_pts, end_pts)
         else:
-            info = _probe_video_from_file(video_path)
-            video_fps = info.video_fps
+            _info = _probe_video_from_file(video_path)
+            video_fps = _info.video_fps
             audio_fps = None
 
-            video_start_pts = clip_pts[0].item()
-            video_end_pts = clip_pts[-1].item()
+            video_start_pts = cast(int, clip_pts[0].item())
+            video_end_pts = cast(int, clip_pts[-1].item())
 
             audio_start_pts, audio_end_pts = 0, -1
             audio_timebase = Fraction(0, 1)
-            video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
-            if info.has_audio:
-                audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
+            video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
+            if _info.has_audio:
+                audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
                 audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
                 audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
-                audio_fps = info.audio_sample_rate
-            video, audio, info = _read_video_from_file(
+                audio_fps = _info.audio_sample_rate
+            video, audio, _ = _read_video_from_file(
                 video_path,
                 video_width=self._video_width,
                 video_height=self._video_height,
@@ -362,7 +367,7 @@ def get_clip(self, idx):
         assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
         return video, audio, info, video_idx
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         video_pts_sizes = [len(v) for v in self.video_pts]
         # To be back-compatible, we convert data to dtype torch.long as needed
         # because for empty list, in legacy implementation, torch.as_tensor will
@@ -371,10 +376,10 @@ def __getstate__(self):
         video_pts = [x.to(torch.int64) for x in self.video_pts]
         # video_pts can be an empty list if no frames have been decoded
         if video_pts:
-            video_pts = torch.cat(video_pts)
+            video_pts = torch.cat(video_pts)  # type: ignore[assignment]
            # avoid bug in https://github.com/pytorch/pytorch/issues/32351
            # TODO: Revert it once the bug is fixed.
-            video_pts = video_pts.numpy()
+            video_pts = video_pts.numpy()  # type: ignore[attr-defined]
 
         # make a copy of the fields of self
         d = self.__dict__.copy()
@@ -390,7 +395,7 @@ def __getstate__(self):
         d["_version"] = 2
         return d
 
-    def __setstate__(self, d):
+    def __setstate__(self, d: Dict[str, Any]) -> None:
         # for backwards-compatibility
         if "_version" not in d:
             self.__dict__ = d
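
For reviewers who want to try the annotated `pts_convert` helper, here is a small usage sketch (not part of the patch). The timebase values are made-up examples, but the call matches the signature annotated above: exact `Fraction` arithmetic, then a rounding function producing an `int`.

```python
import math
from fractions import Fraction

from torchvision.datasets.video_utils import pts_convert

# Hypothetical timebases: convert 90 kHz video ticks to 44.1 kHz audio ticks.
video_timebase = Fraction(1, 90000)
audio_timebase = Fraction(1, 44100)

# 900900 video ticks = 10.01 s; rescaled exactly, then floored to an int.
audio_pts = pts_convert(900900, video_timebase, audio_timebase, math.floor)
print(audio_pts)  # 441441, i.e. 10.01 s expressed in audio ticks
```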
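The annotations also document the public `VideoClips` surface. A minimal consumption sketch under the new types (the video paths are placeholders; any decodable files would do):

```python
from torchvision.datasets.video_utils import VideoClips

# Placeholder paths, assumed to exist on disk.
video_paths = ["videos/a.mp4", "videos/b.mp4"]

clips = VideoClips(
    video_paths,
    clip_length_in_frames=16,
    frames_between_clips=8,
    num_workers=2,
)

# Matches the annotated return type:
# Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]
video, audio, info, video_idx = clips.get_clip(0)
print(video.shape, clips.num_clips(), clips.num_videos())
```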