
Commit afd8bc1

appease mypy
1 parent 1c025f1 commit afd8bc1

File tree

4 files changed: +13 -11 lines changed


mypy.ini

Lines changed: 4 additions & 0 deletions
@@ -125,3 +125,7 @@ ignore_missing_imports = True
 [mypy-h5py.*]
 
 ignore_missing_imports = True
+
+[mypy-rarfile.*]
+
+ignore_missing_imports = True
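
rarfile ships without type stubs, so any module that imports it trips mypy's missing-import check; the new [mypy-rarfile.*] override silences exactly that, mirroring the existing h5py entry. A minimal sketch of the situation (the exact error wording varies across mypy versions):

import rarfile  # without the override, mypy reports roughly:
                # "Cannot find implementation or library stub for module named 'rarfile'"


def open_archive(path: str) -> "rarfile.RarFile":
    # With ignore_missing_imports = True, mypy treats the module as Any-typed
    # instead of erroring out on the import.
    return rarfile.RarFile(path)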

torchvision/prototype/datasets/_builtin/hmdb51.py

Lines changed: 2 additions & 3 deletions
@@ -58,7 +58,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
 
     def _is_split_number(self, data: Tuple[str, Any], *, split_number: str) -> bool:
         path = pathlib.Path(data[0])
-        return self._SPLIT_FILE_PATTERN.match(path.name)["split_number"] == split_number  # type: ignore[union-attr]
+        return self._SPLIT_FILE_PATTERN.match(path.name)["split_number"] == split_number  # type: ignore[index]
 
     _SPLIT_ID_TO_NAME = {
         "1": "train",
@@ -111,7 +111,6 @@ def _generate_categories(self, root: pathlib.Path) -> List[str]:
 
         dp = resources[0].load(root)
         categories = {
-            self._SPLIT_FILE_PATTERN.match(pathlib.Path(path).name)["category"]  # type: ignore[union-attr]
-            for path, _ in dp
+            self._SPLIT_FILE_PATTERN.match(pathlib.Path(path).name)["category"] for path, _ in dp  # type: ignore[index]
         }
         return sorted(categories)
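
Both hunks touch the same typing wrinkle: re.Pattern.match returns Optional[re.Match[str]], and subscripting that optional is what mypy reports under the index error code (attribute access such as .group(...) is what gets union-attr), hence the updated ignore comments. A minimal sketch, with a hypothetical pattern standing in for _SPLIT_FILE_PATTERN:

import re

# Hypothetical pattern standing in for _SPLIT_FILE_PATTERN.
PATTERN = re.compile(r"(?P<category>\w+)_test_split(?P<split_number>\d)\.txt")


def split_number(name: str) -> str:
    match = PATTERN.match(name)  # inferred as Optional[Match[str]]
    # match.group("split_number")  -> flagged as `union-attr` (attribute access on the Optional)
    # match["split_number"]        -> flagged as `index` (subscripting the Optional)
    # Narrowing, or a targeted `type: ignore[...]` as in the diff, keeps mypy quiet:
    assert match is not None
    return match["split_number"]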

torchvision/prototype/datasets/_builtin/ucf101.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
         resources = self.resources(config)
 
         dp = resources[0].load(root)
-        dp = Filter(dp, path_comparator("name", "classInd.txt"))
+        dp: IterDataPipe[Tuple[str, BinaryIO]] = Filter(dp, path_comparator("name", "classInd.txt"))
         dp = CSVParser(dp, dialect="ucf101")
         _, categories = zip(*dp)
         return cast(Tuple[str, ...], categories)
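
The added annotation spells out the element type of the filtered pipe, (path, file handle) pairs, presumably so mypy has something concrete to check the CSV-parsing and unpacking stages against rather than a looser type inferred from the loader. A toy sketch of the idea, with ClassIndFiles as a hypothetical stand-in for the filtered datapipe:

from typing import BinaryIO, Iterator, Tuple

from torch.utils.data import IterDataPipe


class ClassIndFiles(IterDataPipe[Tuple[str, BinaryIO]]):
    """Toy stand-in for the filtered pipe: yields ("classInd.txt", file handle) pairs."""

    def __iter__(self) -> Iterator[Tuple[str, BinaryIO]]:
        raise NotImplementedError


# With an explicit element type on the pipe, mypy knows exactly what flows into
# the downstream parsing and unpacking steps.
dp: IterDataPipe[Tuple[str, BinaryIO]] = ClassIndFiles()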

torchvision/prototype/datasets/utils/_video.py

Lines changed: 6 additions & 7 deletions
@@ -1,5 +1,5 @@
 import random
-from typing import Any, Dict, Iterator, BinaryIO, Optional, Tuple
+from typing import Any, Dict, Iterator, Optional, Tuple
 
 import av
 import numpy as np
@@ -15,7 +15,7 @@ def __init__(self, datapipe: IterDataPipe, *, inline: bool = True) -> None:
         self.datapipe = datapipe
         self._inline = inline
 
-    def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+    def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
         raise NotImplementedError
 
     def _find_encoded_video(self, id: Tuple[Any, ...], obj: Any) -> Optional[Tuple[Any, ...]]:
@@ -65,13 +65,12 @@ def __iter__(self) -> Iterator[Any]:
                 raise ValueError("more than one encoded video")
             id, video = ids_and_videos[0]
 
-            buffer = ReadOnlyTensorBuffer(video)
-            for data in self._decode(buffer, video.meta.copy()):
+            for data in self._decode(ReadOnlyTensorBuffer(video), video.meta.copy()):
                 yield self._integrate_data(sample, id, data)
 
 
 class KeyframeDecoder(_VideoDecoder):
-    def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+    def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
         with av.open(buffer, metadata_errors="ignore") as container:
             stream = container.streams.video[0]
             stream.codec_context.skip_frame = "NONKEY"
@@ -92,7 +91,7 @@ def __init__(self, datapipe: IterDataPipe, *, num_samples: int = 1, inline: bool
         super().__init__(datapipe, inline=inline)
         self.num_sampler = num_samples
 
-    def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+    def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
         with av.open(buffer, metadata_errors="ignore") as container:
             stream = container.streams.video[0]
             # duration is given in time_base units as int
@@ -147,7 +146,7 @@ def _unfold(self, tensor: torch.Tensor, dilation: int = 1) -> torch.Tensor:
         new_size = (0, self.num_frames_per_clip)
         return torch.as_strided(tensor, new_size, new_stride)
 
-    def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+    def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
         with av.open(buffer, metadata_errors="ignore") as container:
             stream = container.streams.video[0]
             time_base = stream.time_base
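
The common thread here is the _decode signature: typing.BinaryIO is a regular class, not a structural protocol, so a custom file-like object such as ReadOnlyTensorBuffer is not accepted where BinaryIO is declared, even though it only needs to behave like a readable stream. Narrowing the parameter to the concrete buffer class (and dropping the now-unused BinaryIO import) is presumably what clears the arg-type complaints at the call site. A minimal sketch of the mismatch, where TensorBuffer, decode_nominal and decode_concrete are hypothetical stand-ins:

from typing import BinaryIO


class TensorBuffer:
    """Hypothetical stand-in for ReadOnlyTensorBuffer: file-like, but not an IO subclass."""

    def read(self, size: int = -1) -> bytes:
        return b""

    def seek(self, offset: int, whence: int = 0) -> int:
        return 0


def decode_nominal(buffer: BinaryIO) -> None: ...


def decode_concrete(buffer: TensorBuffer) -> None: ...


buf = TensorBuffer()
# decode_nominal(buf)   # mypy: arg-type - TensorBuffer is not a subclass of typing.BinaryIO
decode_concrete(buf)     # typing the parameter as the concrete buffer class sidesteps that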
