Skip to content

Commit f4337de

Browse files
Fix VideoDecoder device argument to accept torch.device (#607)
1 parent 0eb7eb0 commit f4337de

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pathlib import Path
99
from typing import Literal, Optional, Tuple, Union
1010

11-
from torch import device, Tensor
11+
from torch import device as torch_device, Tensor
1212

1313
from torchcodec import _core as core, Frame, FrameBatch
1414
from torchcodec.decoders._decoder_utils import (
@@ -71,7 +71,7 @@ def __init__(
7171
stream_index: Optional[int] = None,
7272
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
7373
num_ffmpeg_threads: int = 1,
74-
device: Optional[Union[str, device]] = "cpu",
74+
device: Optional[Union[str, torch_device]] = "cpu",
7575
seek_mode: Literal["exact", "approximate"] = "exact",
7676
):
7777
allowed_seek_modes = ("exact", "approximate")
@@ -93,6 +93,9 @@ def __init__(
9393
if num_ffmpeg_threads is None:
9494
raise ValueError(f"{num_ffmpeg_threads = } should be an int.")
9595

96+
if isinstance(device, torch_device):
97+
device = str(device)
98+
9699
core.add_video_stream(
97100
self._decoder,
98101
stream_index=stream_index,

test/decoders/test_decoders.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,11 @@ def test_getitem_slice(self, device, seek_mode):
288288
# See https://github.com/pytorch/torchcodec/issues/428
289289
assert_frames_equal(sliced, ref)
290290

291+
def test_device_instance(self):
292+
# Non-regression test for https://github.com/pytorch/torchcodec/issues/602
293+
decoder = VideoDecoder(NASA_VIDEO.path, device=torch.device("cpu"))
294+
assert isinstance(decoder.metadata, VideoStreamMetadata)
295+
291296
@pytest.mark.parametrize("device", cpu_and_cuda())
292297
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
293298
def test_getitem_fails(self, device, seek_mode):

0 commit comments

Comments
 (0)