We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0eb7eb0 commit f4337deCopy full SHA for f4337de
src/torchcodec/decoders/_video_decoder.py
@@ -8,7 +8,7 @@
8
from pathlib import Path
9
from typing import Literal, Optional, Tuple, Union
10
11
-from torch import device, Tensor
+from torch import device as torch_device, Tensor
12
13
from torchcodec import _core as core, Frame, FrameBatch
14
from torchcodec.decoders._decoder_utils import (
@@ -71,7 +71,7 @@ def __init__(
71
stream_index: Optional[int] = None,
72
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
73
num_ffmpeg_threads: int = 1,
74
- device: Optional[Union[str, device]] = "cpu",
+ device: Optional[Union[str, torch_device]] = "cpu",
75
seek_mode: Literal["exact", "approximate"] = "exact",
76
):
77
allowed_seek_modes = ("exact", "approximate")
@@ -93,6 +93,9 @@ def __init__(
93
if num_ffmpeg_threads is None:
94
raise ValueError(f"{num_ffmpeg_threads = } should be an int.")
95
96
+ if isinstance(device, torch_device):
97
+ device = str(device)
98
+
99
core.add_video_stream(
100
self._decoder,
101
stream_index=stream_index,
test/decoders/test_decoders.py
@@ -288,6 +288,11 @@ def test_getitem_slice(self, device, seek_mode):
288
# See https://github.com/pytorch/torchcodec/issues/428
289
assert_frames_equal(sliced, ref)
290
291
+ def test_device_instance(self):
292
+ # Non-regression test for https://github.com/pytorch/torchcodec/issues/602
293
+ decoder = VideoDecoder(NASA_VIDEO.path, device=torch.device("cpu"))
294
+ assert isinstance(decoder.metadata, VideoStreamMetadata)
295
296
@pytest.mark.parametrize("device", cpu_and_cuda())
297
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
298
def test_getitem_fails(self, device, seek_mode):
0 commit comments