From d7412db4b68e06f3de1781e08bb577692ea90f33 Mon Sep 17 00:00:00 2001 From: SoheilStar <75124326+soheil-star01@users.noreply.github.com> Date: Sat, 29 Mar 2025 02:19:19 +0200 Subject: [PATCH 1/2] Fix: Convert torch.device to string in VideoDecoder init --- src/torchcodec/decoders/_video_decoder.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 081f332b4..4ca9b3a69 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Literal, Optional, Tuple, Union -from torch import device, Tensor +from torch import device as torch_device, Tensor from torchcodec import Frame, FrameBatch from torchcodec.decoders import _core as core @@ -72,7 +72,7 @@ def __init__( stream_index: Optional[int] = None, dimension_order: Literal["NCHW", "NHWC"] = "NCHW", num_ffmpeg_threads: int = 1, - device: Optional[Union[str, device]] = "cpu", + device: Optional[Union[str, torch_device]] = "cpu", seek_mode: Literal["exact", "approximate"] = "exact", ): allowed_seek_modes = ("exact", "approximate") @@ -94,6 +94,9 @@ def __init__( if num_ffmpeg_threads is None: raise ValueError(f"{num_ffmpeg_threads = } should be an int.") + if isinstance(device, torch_device): + device = str(device) + core.add_video_stream( self._decoder, stream_index=stream_index, From 0a0b482843d299d9afa37fbe5d6f125f945e6847 Mon Sep 17 00:00:00 2001 From: SoheilStar <75124326+soheil-star01@users.noreply.github.com> Date: Tue, 1 Apr 2025 21:19:03 +0300 Subject: [PATCH 2/2] Test: Add non-regression test for torch.device() in VideoDecoder --- test/decoders/test_decoders.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index cc47e116d..4115553f3 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -285,6 +285,11 @@ def test_getitem_slice(self, device, seek_mode): # See https://github.com/pytorch/torchcodec/issues/428 assert_frames_equal(sliced, ref) + def test_device_instance(self): + # Non-regression test for https://github.com/pytorch/torchcodec/issues/602 + decoder = VideoDecoder(NASA_VIDEO.path, device=torch.device("cpu")) + assert isinstance(decoder.metadata, VideoStreamMetadata) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_fails(self, device, seek_mode):