diff --git a/mypy.ini b/mypy.ini
index bac2124f878..d00bbe17156 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -8,7 +8,7 @@ pretty = True
 
 ;ignore_errors = True
 
-[mypy-torchvision.io.*]
+[mypy-torchvision.io._video_opt.*]
 
 ignore_errors = True
 
@@ -51,3 +51,7 @@ ignore_missing_imports = True
 [mypy-accimage.*]
 
 ignore_missing_imports = True
+
+[mypy-av.*]
+
+ignore_missing_imports = True
diff --git a/torchvision/io/image.py b/torchvision/io/image.py
index 8d5da4899ca..35f9971f37c 100644
--- a/torchvision/io/image.py
+++ b/torchvision/io/image.py
@@ -1,9 +1,8 @@
 import torch
-from torch import nn, Tensor
 
 import os
 import os.path as osp
-import importlib
+import importlib.machinery
 
 _HAS_IMAGE_OPT = False
 
@@ -15,7 +14,7 @@
         importlib.machinery.EXTENSION_SUFFIXES
     )
 
-    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
+    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
     ext_specs = extfinder.find_spec("image")
     if ext_specs is not None:
         torch.ops.load_library(ext_specs.origin)
@@ -24,8 +23,7 @@
     pass
 
 
-def decode_png(input):
-    # type: (Tensor) -> Tensor
+def decode_png(input: torch.Tensor) -> torch.Tensor:
     """
     Decodes a PNG image into a 3 dimensional RGB Tensor.
     The values of the output tensor are uint8 between 0 and 255.
@@ -37,7 +35,7 @@ def decode_png(input):
     Returns:
         output (Tensor[image_width, image_height, 3])
     """
-    if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1:
+    if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1:  # type: ignore[attr-defined]
         raise ValueError("Expected a non empty 1-dimensional tensor.")
 
     if not input.dtype == torch.uint8:
@@ -46,8 +44,7 @@ def decode_png(input):
     return output
 
 
-def read_png(path):
-    # type: (str) -> Tensor
+def read_png(path: str) -> torch.Tensor:
     """
     Reads a PNG image into a 3 dimensional RGB Tensor.
     The values of the output tensor are uint8 between 0 and 255.
@@ -68,8 +65,7 @@ def read_png(path):
     return decode_png(data)
 
 
-def decode_jpeg(input):
-    # type: (Tensor) -> Tensor
+def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
     """
     Decodes a JPEG image into a 3 dimensional RGB Tensor.
     The values of the output tensor are uint8 between 0 and 255.
@@ -79,7 +75,7 @@
     Returns:
         output (Tensor[image_width, image_height, 3])
     """
-    if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1:
+    if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1:  # type: ignore[attr-defined]
         raise ValueError("Expected a non empty 1-dimensional tensor.")
 
     if not input.dtype == torch.uint8:
@@ -89,8 +85,7 @@
     return output
 
 
-def read_jpeg(path):
-    # type: (str) -> Tensor
+def read_jpeg(path: str) -> torch.Tensor:
     """
     Reads a JPEG image into a 3 dimensional RGB Tensor.
     The values of the output tensor are uint8 between 0 and 255.
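The image.py hunks only change signatures and add mypy suppressions; runtime behavior is untouched. Below is a minimal usage sketch of the annotated API, assuming a placeholder file "sample.png" and a torchvision build whose image extension loaded successfully (NumPy is used only to read the raw bytes as a writable uint8 array):

    import numpy as np
    import torch
    from torchvision.io.image import decode_png, read_png

    # decode_png expects a non-empty 1-D uint8 tensor of raw PNG bytes,
    # which is exactly what the isinstance/numel/ndim/dtype checks enforce.
    data = torch.from_numpy(np.fromfile("sample.png", dtype=np.uint8))
    img = decode_png(data)         # torch.Tensor, uint8 RGB
    img2 = read_png("sample.png")  # same result via the path-based helper

For mypy the inline annotations are equivalent to the old "# type:" comments, but unlike the comments they are also visible to IDEs and to typing.get_type_hints at runtime.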
diff --git a/torchvision/io/video.py b/torchvision/io/video.py
index 5c8529a7b5d..eb6d0f98d62 100644
--- a/torchvision/io/video.py
+++ b/torchvision/io/video.py
@@ -2,7 +2,7 @@
 import math
 import re
 import warnings
-from typing import List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -35,12 +35,12 @@
 )
 
 
-def _check_av_available():
+def _check_av_available() -> None:
     if isinstance(av, Exception):
         raise av
 
 
-def _av_available():
+def _av_available() -> bool:
     return not isinstance(av, Exception)
 
 
@@ -49,7 +49,13 @@ def _av_available():
 _GC_COLLECTION_INTERVAL = 10
 
 
-def write_video(filename, video_array, fps: Union[int, float], video_codec="libx264", options=None):
+def write_video(
+    filename: str,
+    video_array: torch.Tensor,
+    fps: float,
+    video_codec: str = "libx264",
+    options: Optional[Dict[str, Any]] = None,
+) -> None:
     """
     Writes a 4d tensor in [T, H, W, C] format in a video file
 
@@ -89,8 +95,13 @@ def write_video(filename, video_array, fps: Union[int, float], video_codec="libx
 
 
 def _read_from_stream(
-    container, start_offset, end_offset, pts_unit, stream, stream_name
-):
+    container: "av.container.Container",
+    start_offset: float,
+    end_offset: float,
+    pts_unit: str,
+    stream: "av.stream.Stream",
+    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
+) -> List["av.frame.Frame"]:
     global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
     _CALLED_TIMES += 1
     if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
@@ -166,7 +177,9 @@ def _read_from_stream(
     return result
 
 
-def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
+def _align_audio_frames(
+    aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
+) -> torch.Tensor:
     start, end = audio_frames[0].pts, audio_frames[-1].pts
     total_aframes = aframes.shape[1]
     step_per_aframe = (end - start + 1) / total_aframes
@@ -179,7 +192,9 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
     return aframes[:, s_idx:e_idx]
 
 
-def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
+def read_video(
+    filename: str, start_pts: int = 0, end_pts: Optional[float] = None, pts_unit: str = "pts"
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
     """
     Reads a video from a file, returning both the video frames as well as
     the audio frames
@@ -260,16 +275,16 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
             # TODO raise a warning?
            pass
 
-    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
-    aframes = [frame.to_ndarray() for frame in audio_frames]
+    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
+    aframes_list = [frame.to_ndarray() for frame in audio_frames]
 
-    if vframes:
-        vframes = torch.as_tensor(np.stack(vframes))
+    if vframes_list:
+        vframes = torch.as_tensor(np.stack(vframes_list))
     else:
         vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
 
-    if aframes:
-        aframes = np.concatenate(aframes, 1)
+    if aframes_list:
+        aframes = np.concatenate(aframes_list, 1)
         aframes = torch.as_tensor(aframes)
         aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
     else:
@@ -278,7 +293,7 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
     return vframes, aframes, info
 
 
-def _can_read_timestamps_from_packets(container):
+def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
     extradata = container.streams[0].codec_context.extradata
     if extradata is None:
         return False
@@ -287,7 +302,7 @@ def _can_read_timestamps_from_packets(container):
     return False
 
 
-def _decode_video_timestamps(container):
+def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
     if _can_read_timestamps_from_packets(container):
         # fast path
         return [x.pts for x in container.demux(video=0) if x.pts is not None]
@@ -295,7 +310,7 @@ def _decode_video_timestamps(container):
         return [x.pts for x in container.decode(video=0) if x.pts is not None]
 
 
-def read_video_timestamps(filename, pts_unit="pts"):
+def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
     """
     List the video frames timestamps.
 
@@ -313,7 +328,7 @@ def read_video_timestamps(filename, pts_unit="pts"):
     pts : List[int] if pts_unit = 'pts'
         List[Fraction] if pts_unit = 'sec'
         presentation timestamps for each one of the frames in the video.
-    video_fps : int
+    video_fps : float, optional
         the frame rate for the video
     """
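A similar sketch for the annotated video API; "clip.mp4" and "out.mp4" are placeholder paths, and PyAV must be installed for these calls to succeed:

    import torch
    from torchvision.io.video import read_video, read_video_timestamps, write_video

    # Matches Tuple[List[int], Optional[float]]: integer pts for pts_unit="pts",
    # and a frame rate that is None when the file has no video stream.
    pts, fps = read_video_timestamps("clip.mp4", pts_unit="pts")

    # Matches Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: vframes is a
    # uint8 [T, H, W, C] tensor, aframes is [K, L], info holds e.g. "video_fps".
    vframes, aframes, info = read_video("clip.mp4", pts_unit="pts")

    # fps is now annotated as plain float; an int argument still type-checks
    # because mypy promotes int to float.
    write_video("out.mp4", vframes, fps=float(info.get("video_fps", 30.0)))

Narrowing fps from Union[int, float] to float is therefore not a breaking change for callers, and the vframes_list/aframes_list renames exist only so each variable keeps a single static type instead of being rebound from a list of ndarrays to a torch.Tensor.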