
Added missing typing annotations in datasets/video_utils #4172


Merged: 14 commits, Nov 22, 2021
109 changes: 57 additions & 52 deletions torchvision/datasets/video_utils.py
@@ -2,7 +2,7 @@
import math
import warnings
from fractions import Fraction
-from typing import List
+from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast

import torch
from torchvision.io import (
@@ -14,8 +14,10 @@

from .utils import tqdm

+T = TypeVar("T")

-def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
+def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
"""convert pts between different time bases
Args:
pts: presentation timestamp, float
@@ -27,7 +29,7 @@ def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
return round_func(new_pts)
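
For a concrete sense of the conversion, here is a hedged sketch of pts_convert in use; the function body is collapsed in this view, and the new_pts computation below is an assumption consistent with the return statement shown:

```python
from fractions import Fraction
import math

# map pts 300 in a 1/30 s video timebase into a 1/44100 s audio timebase;
# the intermediate value is assumed to be pts * timebase_from / timebase_to
video_tb = Fraction(1, 30)
audio_tb = Fraction(1, 44100)
new_pts = Fraction(300, 1) * video_tb / audio_tb
print(math.floor(new_pts))  # 441000 audio ticks, i.e. 10 seconds
```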


-def unfold(tensor, size, step, dilation=1):
+def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
"""
similar to tensor.unfold, but with the dilation
and specialized for 1d tensors
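
The body of unfold is collapsed here; a hedged sketch of what a dilated 1-d unfold can look like with torch.as_strided (the stride arithmetic is an assumption matching the size/step/dilation parameters above):

```python
import torch

def unfold_sketch(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
    # each window starts `step` elements after the previous one and gathers
    # `size` elements spaced `dilation` apart
    assert tensor.dim() == 1
    o_stride = tensor.stride(0)
    new_stride = (step * o_stride, dilation * o_stride)
    n = (tensor.numel() - (dilation * (size - 1) + 1)) // step + 1
    return torch.as_strided(tensor, (max(n, 0), size), new_stride)

print(unfold_sketch(torch.arange(10), size=3, step=2, dilation=2))
# tensor([[0, 2, 4], [2, 4, 6], [4, 6, 8]])
```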
@@ -55,17 +57,17 @@ class _VideoTimestampsDataset:
pickled when forking.
"""

-def __init__(self, video_paths: List[str]):
+def __init__(self, video_paths: List[str]) -> None:
self.video_paths = video_paths

-def __len__(self):
+def __len__(self) -> int:
return len(self.video_paths)

-def __getitem__(self, idx):
+def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
return read_video_timestamps(self.video_paths[idx])


-def _collate_fn(x):
+def _collate_fn(x: T) -> T:
"""
Dummy collate function to be used with _VideoTimestampsDataset
"""
@@ -100,19 +102,19 @@ class VideoClips:

def __init__(
self,
-video_paths,
-clip_length_in_frames=16,
-frames_between_clips=1,
-frame_rate=None,
-_precomputed_metadata=None,
-num_workers=0,
-_video_width=0,
-_video_height=0,
-_video_min_dimension=0,
-_video_max_dimension=0,
-_audio_samples=0,
-_audio_channels=0,
-):
+video_paths: List[str],
+clip_length_in_frames: int = 16,
+frames_between_clips: int = 1,
+frame_rate: Optional[int] = None,
+_precomputed_metadata: Optional[Dict[str, Any]] = None,
+num_workers: int = 0,
+_video_width: int = 0,
+_video_height: int = 0,
+_video_min_dimension: int = 0,
+_video_max_dimension: int = 0,
+_audio_samples: int = 0,
+_audio_channels: int = 0,
+) -> None:

self.video_paths = video_paths
self.num_workers = num_workers
@@ -131,16 +133,16 @@ def __init__(
self._init_from_metadata(_precomputed_metadata)
self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)

-def _compute_frame_pts(self):
+def _compute_frame_pts(self) -> None:
self.video_pts = []
self.video_fps = []

# strategy: use a DataLoader to parallelize read_video_timestamps
# so need to create a dummy dataset first
import torch.utils.data

-dl = torch.utils.data.DataLoader(
-_VideoTimestampsDataset(self.video_paths),
+dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
+_VideoTimestampsDataset(self.video_paths), # type: ignore[arg-type]
batch_size=16,
num_workers=self.num_workers,
collate_fn=_collate_fn,
@@ -157,23 +159,23 @@ def _compute_frame_pts(self):
self.video_pts.extend(clips)
self.video_fps.extend(fps)

-def _init_from_metadata(self, metadata):
+def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
self.video_paths = metadata["video_paths"]
assert len(self.video_paths) == len(metadata["video_pts"])
self.video_pts = metadata["video_pts"]
assert len(self.video_paths) == len(metadata["video_fps"])
self.video_fps = metadata["video_fps"]

@property
-def metadata(self):
+def metadata(self) -> Dict[str, Any]:
_metadata = {
"video_paths": self.video_paths,
"video_pts": self.video_pts,
"video_fps": self.video_fps,
}
return _metadata
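
As a usage note, the metadata property pairs with the _precomputed_metadata argument above, so the expensive timestamp scan can run once and be reused; a hedged sketch with a hypothetical video_list:

```python
clips = VideoClips(video_list, clip_length_in_frames=16, frames_between_clips=8)
meta = clips.metadata  # {"video_paths": ..., "video_pts": ..., "video_fps": ...}

# later runs skip _compute_frame_pts entirely by reusing the cached metadata
clips_again = VideoClips(video_list, _precomputed_metadata=meta)
```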

-def subset(self, indices):
+def subset(self, indices: List[int]) -> "VideoClips":
video_paths = [self.video_paths[i] for i in indices]
video_pts = [self.video_pts[i] for i in indices]
video_fps = [self.video_fps[i] for i in indices]
@@ -198,29 +200,32 @@ def subset(self, indices):
)

@staticmethod
-def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
+def compute_clips_for_video(
+video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
+) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
if fps is None:
# if for some reason the video doesn't have fps (because doesn't have a video stream)
# set the fps to 1. The value doesn't matter, because video_pts is empty anyway
fps = 1
if frame_rate is None:
frame_rate = fps
total_frames = len(video_pts) * (float(frame_rate) / fps)
-idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
-video_pts = video_pts[idxs]
+_idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
+video_pts = video_pts[_idxs]
clips = unfold(video_pts, num_frames, step)
if not clips.numel():
warnings.warn(
"There aren't enough frames in the current video to get a clip for the given clip length and "
"frames between clips. The video (and potentially others) will be skipped."
)
-if isinstance(idxs, slice):
-idxs = [idxs] * len(clips)
+idxs: Union[List[slice], torch.Tensor]
+if isinstance(_idxs, slice):
+idxs = [_idxs] * len(clips)
else:
-idxs = unfold(idxs, num_frames, step)
+idxs = unfold(_idxs, num_frames, step)
return clips, idxs
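
A worked example of this static method on synthetic pts, matching the logic shown above:

```python
import torch
from torchvision.datasets.video_utils import VideoClips

video_pts = torch.arange(10)  # pts of 10 decoded frames at 30 fps
clips, idxs = VideoClips.compute_clips_for_video(
    video_pts, num_frames=2, step=2, fps=30, frame_rate=15
)
# resampling 30 -> 15 fps keeps every other frame ([0, 2, 4, 6, 8]);
# unfolding that with size=2, step=2 yields two clips
print(clips)  # tensor([[0, 2], [4, 6]])
print(idxs)   # [slice(None, None, 2), slice(None, None, 2)]
```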

-def compute_clips(self, num_frames, step, frame_rate=None):
+def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
"""
Compute all consecutive sequences of clips from video_pts.
Always returns clips of size `num_frames`, meaning that the
@@ -243,19 +248,19 @@ def compute_clips(self, num_frames, step, frame_rate=None):
clip_lengths = torch.as_tensor([len(v) for v in self.clips])
self.cumulative_sizes = clip_lengths.cumsum(0).tolist()

-def __len__(self):
+def __len__(self) -> int:
return self.num_clips()

-def num_videos(self):
+def num_videos(self) -> int:
return len(self.video_paths)

-def num_clips(self):
+def num_clips(self) -> int:
"""
Number of subclips that are available in the video list.
"""
return self.cumulative_sizes[-1]

-def get_clip_location(self, idx):
+def get_clip_location(self, idx: int) -> Tuple[int, int]:
"""
Converts a flattened representation of the indices into a video_idx, clip_idx
representation.
@@ -268,7 +273,7 @@ def get_clip_location(self, idx):
return video_idx, clip_idx
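
The body is partly collapsed above; assuming the usual bisect-over-cumulative-sizes pattern, a hedged sketch of the flat-index mapping:

```python
import bisect

cumulative_sizes = [5, 8, 12]  # clips per video: 5, 3, 4

def get_clip_location(idx: int):
    video_idx = bisect.bisect_right(cumulative_sizes, idx)
    clip_idx = idx if video_idx == 0 else idx - cumulative_sizes[video_idx - 1]
    return video_idx, clip_idx

print(get_clip_location(6))  # (1, 1): second clip of the second video
```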

@staticmethod
-def _resample_video_idx(num_frames, original_fps, new_fps):
+def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
step = float(original_fps) / new_fps
if step.is_integer():
# optimization: if step is integer, don't need to perform
@@ -279,7 +284,7 @@ def _resample_video_idx(num_frames, original_fps, new_fps):
idxs = idxs.floor().to(torch.int64)
return idxs
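
Two quick examples of this resampling, in a standalone sketch (the integer-step branch is collapsed above and assumed to return a plain slice, which is what the new Union[slice, torch.Tensor] return type suggests):

```python
import torch

def resample_sketch(num_frames: int, original_fps: int, new_fps: int):
    step = float(original_fps) / new_fps
    if step.is_integer():
        # cheap path: index with a slice instead of materializing indices
        return slice(None, None, int(step))
    idxs = torch.arange(num_frames, dtype=torch.float32) * step
    return idxs.floor().to(torch.int64)

print(resample_sketch(5, 30, 15))  # slice(None, None, 2)
print(resample_sketch(5, 30, 20))  # tensor([0, 1, 3, 4, 6])
```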

-def get_clip(self, idx):
+def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
"""
Gets a subclip from a list of videos.

@@ -320,22 +325,22 @@ def get_clip(self, idx):
end_pts = clip_pts[-1].item()
video, audio, info = read_video(video_path, start_pts, end_pts)
else:
-info = _probe_video_from_file(video_path)
-video_fps = info.video_fps
+_info = _probe_video_from_file(video_path)

Review thread on the _info renaming:

Collaborator: Since we now have allow_redefinition = True, this renaming is probably not needed anymore. Could you check?

Contributor (author): I did, and unfortunately it's still an issue.

Collaborator: It seems even with allow_redefinition, mypy is stricter than I thought: "Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition." [emphasis mine] Since that is not the case here, we need to keep your fix.
+video_fps = _info.video_fps
audio_fps = None

-video_start_pts = clip_pts[0].item()
-video_end_pts = clip_pts[-1].item()
+video_start_pts = cast(int, clip_pts[0].item())
+video_end_pts = cast(int, clip_pts[-1].item())

audio_start_pts, audio_end_pts = 0, -1
audio_timebase = Fraction(0, 1)
-video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
-if info.has_audio:
-audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
+video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
+if _info.has_audio:
+audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
-audio_fps = info.audio_sample_rate
-video, audio, info = _read_video_from_file(
+audio_fps = _info.audio_sample_rate
+video, audio, _ = _read_video_from_file(
video_path,
video_width=self._video_width,
video_height=self._video_height,
@@ -362,7 +367,7 @@ def get_clip(self, idx):
assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
return video, audio, info, video_idx

-def __getstate__(self):
+def __getstate__(self) -> Dict[str, Any]:
video_pts_sizes = [len(v) for v in self.video_pts]
# To be back-compatible, we convert data to dtype torch.long as needed
# because for empty list, in legacy implementation, torch.as_tensor will
@@ -371,10 +376,10 @@ def __getstate__(self):
video_pts = [x.to(torch.int64) for x in self.video_pts]
# video_pts can be an empty list if no frames have been decoded
if video_pts:
-video_pts = torch.cat(video_pts)
+video_pts = torch.cat(video_pts) # type: ignore[assignment]
# avoid bug in https://github.com/pytorch/pytorch/issues/32351
# TODO: Revert it once the bug is fixed.
-video_pts = video_pts.numpy()
+video_pts = video_pts.numpy() # type: ignore[attr-defined]

# make a copy of the fields of self
d = self.__dict__.copy()
@@ -390,7 +395,7 @@ def __getstate__(self):
d["_version"] = 2
return d
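
A short usage sketch of the serialization path, reusing the hypothetical clips object from the earlier example:

```python
import pickle

state = pickle.dumps(clips)     # __getstate__ flattens video_pts into one array
restored = pickle.loads(state)  # __setstate__ rebuilds the per-video pts lists
assert restored.num_clips() == clips.num_clips()
```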

-def __setstate__(self, d):
+def __setstate__(self, d: Dict[str, Any]) -> None:
# for backwards-compatibility
if "_version" not in d:
self.__dict__ = d