diff --git a/test/test_video_reader.py b/test/test_video_reader.py index c3b0487f153..867923d10d0 100644 --- a/test/test_video_reader.py +++ b/test/test_video_reader.py @@ -1225,7 +1225,7 @@ def test_invalid_file(self): @pytest.mark.parametrize("test_video", test_videos.keys()) @pytest.mark.parametrize("backend", ["video_reader", "pyav"]) - @pytest.mark.parametrize("start_offset", [0, 1000]) + @pytest.mark.parametrize("start_offset", [0, 500]) @pytest.mark.parametrize("end_offset", [3000, None]) def test_audio_present_pts(self, test_video, backend, start_offset, end_offset): """Test if audio frames are returned with pts unit.""" diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index 5357d25ea62..055b195a8f4 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -423,16 +423,6 @@ def _probe_video_from_memory( return info -def _convert_to_sec( - start_pts: Union[float, Fraction], end_pts: Union[float, Fraction], pts_unit: str, time_base: Fraction -) -> Tuple[Union[float, Fraction], Union[float, Fraction], str]: - if pts_unit == "pts": - start_pts = float(start_pts * time_base) - end_pts = float(end_pts * time_base) - pts_unit = "sec" - return start_pts, end_pts, pts_unit - - def _read_video( filename: str, start_pts: Union[float, Fraction] = 0, @@ -452,38 +442,28 @@ def _read_video( has_video = info.has_video has_audio = info.has_audio - video_pts_range = (0, -1) - video_timebase = default_timebase - audio_pts_range = (0, -1) - audio_timebase = default_timebase - time_base = default_timebase - - if has_video: - video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator) - time_base = video_timebase - - if has_audio: - audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator) - time_base = time_base if time_base else audio_timebase - - # video_timebase is the default time_base - start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(start_pts, end_pts, pts_unit, time_base) def get_pts(time_base): - start_offset = start_pts_sec - end_offset = end_pts_sec + start_offset = start_pts + end_offset = end_pts if pts_unit == "sec": - start_offset = int(math.floor(start_pts_sec * (1 / time_base))) + start_offset = int(math.floor(start_pts * (1 / time_base))) if end_offset != float("inf"): - end_offset = int(math.ceil(end_pts_sec * (1 / time_base))) + end_offset = int(math.ceil(end_pts * (1 / time_base))) if end_offset == float("inf"): end_offset = -1 return start_offset, end_offset + video_pts_range = (0, -1) + video_timebase = default_timebase if has_video: + video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator) video_pts_range = get_pts(video_timebase) + audio_pts_range = (0, -1) + audio_timebase = default_timebase if has_audio: + audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator) audio_pts_range = get_pts(audio_timebase) vframes, aframes, info = _read_video_from_file( diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 479fdfc1ddf..d026e754546 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -287,13 +287,6 @@ def read_video( with av.open(filename, metadata_errors="ignore") as container: if container.streams.audio: audio_timebase = container.streams.audio[0].time_base - time_base = _video_opt.default_timebase - if container.streams.video: - time_base = container.streams.video[0].time_base - elif container.streams.audio: - time_base = container.streams.audio[0].time_base - # video_timebase is the default time_base - start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, pts_unit, time_base) if container.streams.video: video_frames = _read_from_stream( container,