diff --git a/mypy.ini b/mypy.ini index 916abc1b43d..a6000f8a9d5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -22,11 +22,11 @@ warn_unreachable = True ; miscellaneous strictness flags allow_redefinition = True -[mypy-torchvision.io._video_opt.*] +[mypy-torchvision.io.image.*] ignore_errors = True -[mypy-torchvision.io.*] +[mypy-torchvision.io.video.*] ignore_errors = True diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index da2fed01f29..382e06fb4f2 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -132,7 +132,7 @@ def __next__(self) -> Dict[str, Any]: raise StopIteration return {"data": frame, "pts": pts} - def __iter__(self) -> Iterator["VideoReader"]: + def __iter__(self) -> Iterator[Dict[str, Any]]: return self def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader": diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index 630cbe07781..45bec44ec61 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -1,7 +1,7 @@ import math import warnings from fractions import Fraction -from typing import List, Tuple +from typing import List, Tuple, Dict, Optional, Union import torch @@ -26,10 +26,9 @@ class Timebase: def __init__( self, - numerator, # type: int - denominator, # type: int - ): - # type: (...) -> None + numerator: int, + denominator: int, + ) -> None: self.numerator = numerator self.denominator = denominator @@ -56,7 +55,7 @@ class VideoMetaData: "audio_sample_rate", ] - def __init__(self): + def __init__(self) -> None: self.has_video = False self.video_timebase = Timebase(0, 1) self.video_duration = 0.0 @@ -67,8 +66,7 @@ def __init__(self): self.audio_sample_rate = 0.0 -def _validate_pts(pts_range): - # type: (List[int]) -> None +def _validate_pts(pts_range: Tuple[int, int]) -> None: if pts_range[1] > 0: assert ( @@ -80,8 +78,14 @@ def _validate_pts(pts_range): ) -def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration): - # type: (torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor) -> VideoMetaData +def _fill_info( + vtimebase: torch.Tensor, + vfps: torch.Tensor, + vduration: torch.Tensor, + atimebase: torch.Tensor, + asample_rate: torch.Tensor, + aduration: torch.Tensor, +) -> VideoMetaData: """ Build update VideoMetaData struct with info about the video """ @@ -106,8 +110,9 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration): return meta -def _align_audio_frames(aframes, aframe_pts, audio_pts_range): - # type: (torch.Tensor, torch.Tensor, List[int]) -> torch.Tensor +def _align_audio_frames( + aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int] +) -> torch.Tensor: start, end = aframe_pts[0], aframe_pts[-1] num_samples = aframes.size(0) step_per_aframe = float(end - start + 1) / float(num_samples) @@ -121,21 +126,21 @@ def _align_audio_frames(aframes, aframe_pts, audio_pts_range): def _read_video_from_file( - filename, - seek_frame_margin=0.25, - read_video_stream=True, - video_width=0, - video_height=0, - video_min_dimension=0, - video_max_dimension=0, - video_pts_range=(0, -1), - video_timebase=default_timebase, - read_audio_stream=True, - audio_samples=0, - audio_channels=0, - audio_pts_range=(0, -1), - audio_timebase=default_timebase, -): + filename: str, + seek_frame_margin: float = 0.25, + read_video_stream: bool = True, + video_width: int = 0, + video_height: int = 0, + video_min_dimension: int = 0, + video_max_dimension: int = 0, + video_pts_range: Tuple[int, int] = (0, -1), + video_timebase: Fraction = default_timebase, + read_audio_stream: bool = True, + audio_samples: int = 0, + audio_channels: int = 0, + audio_pts_range: Tuple[int, int] = (0, -1), + audio_timebase: Fraction = default_timebase, +) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]: """ Reads a video from a file, returning both the video frames as well as the audio frames @@ -217,7 +222,7 @@ def _read_video_from_file( return vframes, aframes, info -def _read_video_timestamps_from_file(filename): +def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]: """ Decode all video- and audio frames in the video. Only pts (presentation timestamp) is returned. The actual frame pixel data is not @@ -252,7 +257,7 @@ def _read_video_timestamps_from_file(filename): return vframe_pts, aframe_pts, info -def _probe_video_from_file(filename): +def _probe_video_from_file(filename: str) -> VideoMetaData: """ Probe a video file and return VideoMetaData with info about the video """ @@ -263,24 +268,23 @@ def _probe_video_from_file(filename): def _read_video_from_memory( - video_data, # type: torch.Tensor - seek_frame_margin=0.25, # type: float - read_video_stream=1, # type: int - video_width=0, # type: int - video_height=0, # type: int - video_min_dimension=0, # type: int - video_max_dimension=0, # type: int - video_pts_range=(0, -1), # type: List[int] - video_timebase_numerator=0, # type: int - video_timebase_denominator=1, # type: int - read_audio_stream=1, # type: int - audio_samples=0, # type: int - audio_channels=0, # type: int - audio_pts_range=(0, -1), # type: List[int] - audio_timebase_numerator=0, # type: int - audio_timebase_denominator=1, # type: int -): - # type: (...) -> Tuple[torch.Tensor, torch.Tensor] + video_data: torch.Tensor, + seek_frame_margin: float = 0.25, + read_video_stream: int = 1, + video_width: int = 0, + video_height: int = 0, + video_min_dimension: int = 0, + video_max_dimension: int = 0, + video_pts_range: Tuple[int, int] = (0, -1), + video_timebase_numerator: int = 0, + video_timebase_denominator: int = 1, + read_audio_stream: int = 1, + audio_samples: int = 0, + audio_channels: int = 0, + audio_pts_range: Tuple[int, int] = (0, -1), + audio_timebase_numerator: int = 0, + audio_timebase_denominator: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: """ Reads a video from memory, returning both the video frames as well as the audio frames @@ -370,7 +374,9 @@ def _read_video_from_memory( return vframes, aframes -def _read_video_timestamps_from_memory(video_data): +def _read_video_timestamps_from_memory( + video_data: torch.Tensor, +) -> Tuple[List[int], List[int], VideoMetaData]: """ Decode all frames in the video. Only pts (presentation timestamp) is returned. The actual frame pixel data is not copied. Thus, read_video_timestamps(...) @@ -407,8 +413,9 @@ def _read_video_timestamps_from_memory(video_data): return vframe_pts, aframe_pts, info -def _probe_video_from_memory(video_data): - # type: (torch.Tensor) -> VideoMetaData +def _probe_video_from_memory( + video_data: torch.Tensor, +) -> VideoMetaData: """ Probe a video in memory and return VideoMetaData with info about the video This function is torchscriptable @@ -421,7 +428,9 @@ def _probe_video_from_memory(video_data): return info -def _convert_to_sec(start_pts, end_pts, pts_unit, time_base): +def _convert_to_sec( + start_pts: Union[float, Fraction], end_pts: Union[float, Fraction], pts_unit: str, time_base: Fraction +) -> Tuple[Union[float, Fraction], Union[float, Fraction], str]: if pts_unit == "pts": start_pts = float(start_pts * time_base) end_pts = float(end_pts * time_base) @@ -429,7 +438,12 @@ def _convert_to_sec(start_pts, end_pts, pts_unit, time_base): return start_pts, end_pts, pts_unit -def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): +def _read_video( + filename: str, + start_pts: Union[float, Fraction] = 0, + end_pts: Optional[Union[float, Fraction]] = None, + pts_unit: str = "pts", +) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]: if end_pts is None: end_pts = float("inf") @@ -495,13 +509,16 @@ def get_pts(time_base): return vframes, aframes, _info -def _read_video_timestamps(filename, pts_unit="pts"): +def _read_video_timestamps( + filename: str, pts_unit: str = "pts" +) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]: if pts_unit == "pts": warnings.warn( "The pts_unit 'pts' gives wrong results and will be removed in a " + "follow-up version. Please use pts_unit 'sec'." ) + pts: Union[List[int], List[Fraction]] pts, _, info = _read_video_timestamps_from_file(filename) if pts_unit == "sec":