
Commit 78d2acb

NicolasHug authored and facebook-github-bot committed
Revamp notebook example (#110)
Summary: This PR re-writes our notebook-style example to cover all methods, attributes and related classes of `SimpleVideoDecoder`. And also to show a cute puppy, ~~which will hopefully account for 95% of the success of this project~~. The example is a bit verbose, by design. I think it's fine, we'll add a much more compact and to-the-point example in the README.

The example currently renders as follows:

![image](https://github.com/user-attachments/assets/8f399f39-f456-4b26-87ed-6b0a4b86ba12)

Pull Request resolved: #110

Reviewed By: ahmadsharif1

Differential Revision: D60373589

Pulled By: NicolasHug

fbshipit-source-id: 5b8ce234d7a687b98bdb69c1272f12c4983f797f
1 parent fb12e58 commit 78d2acb
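
For quick orientation, here is a compact sketch of the API surface the revamped example walks through. It only uses calls that appear in the diff of `examples/basic_example.py` below; the `"video.mp4"` path is a placeholder, not a file from this repo.

```python
# Compact sketch of what the revamped example covers (placeholder path).
from torchcodec.decoders import SimpleVideoDecoder

decoder = SimpleVideoDecoder("video.mp4")  # also accepts raw encoded bytes
print(decoder.metadata)                    # VideoStreamMetadata
first_frame = decoder[0]                   # indexing returns uint8 torch.Tensors
last_frame = decoder.get_frame_at(len(decoder) - 1)      # Frame: .data, .pts_seconds, .duration_seconds
frame_at_2s = decoder.get_frame_displayed_at(seconds=2)  # time-based access
```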

File tree

4 files changed: +148 −31 lines

.flake8

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [flake8]
 max-line-length = 120
-ignore = E203, E402, W503, W504, F821, E501, B, C4, EXE
+ignore = E203, E402, W503, W504, F821, E501, B, C4, EXE, E251, E202
 per-file-ignores =
     __init__.py: F401, F403, F405
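
The two new ignores are most likely tied to the revamped example below, which prints values with the f-string "debug" specifier; on some Python/pycodestyle combinations the spaces around `=` and before the closing brace in that form get reported as E251 ("unexpected spaces around keyword / parameter equals") and E202 ("whitespace before closing bracket"). A tiny illustration, noting this motivation is an inference, not something stated in the commit:

```python
# The style that the relaxed ignore list is (presumably) meant to allow:
status_code = 200
print(f"Failed to download video. {status_code = }.")  # spaces around "=" and before "}"
```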

examples/basic_example.py

Lines changed: 137 additions & 26 deletions
@@ -5,48 +5,159 @@
 # LICENSE file in the root directory of this source tree.
 
 """
-==================================================
-Basic Example to use TorchCodec to decode a video.
-==================================================
+========================================
+Decoding a video with SimpleVideoDecoder
+========================================
 
-A simple example showing how to decode the first few frames of a video using
-the :class:`~torchcodec.decoders.SimpleVideoDecoder` class.
+In this example, we'll learn how to decode a video using the
+:class:`~torchcodec.decoders.SimpleVideoDecoder` class.
 """
 
 # %%
-import inspect
-import os
+# First, a bit of boilerplate: we'll download a video from the web, and define a
+# plotting utility. You can ignore that part and jump right below to
+# :ref:`creating_decoder`.
 
+from typing import Optional
+import torch
+import requests
+
+
+# Video source: https://www.pexels.com/video/dog-eating-854132/
+# License: CC0. Author: Coverr.
+url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
+response = requests.get(url)
+if response.status_code != 200:
+    raise RuntimeError(f"Failed to download video. {response.status_code = }.")
+
+raw_video_bytes = response.content
+
+
+def plot(frames: torch.Tensor, title : Optional[str] = None):
+    try:
+        from torchvision.utils import make_grid
+        from torchvision.transforms.v2.functional import to_pil_image
+        import matplotlib.pyplot as plt
+    except ImportError:
+        print("Cannot plot, please run `pip install torchvision matplotlib`")
+        return
+
+    plt.rcParams["savefig.bbox"] = 'tight'
+    fig, ax = plt.subplots()
+    ax.imshow(to_pil_image(make_grid(frames)))
+    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+    if title is not None:
+        ax.set_title(title)
+    plt.tight_layout()
+
+
+# %%
+# .. _creating_decoder:
+#
+# Creating a decoder
+# ------------------
+#
+# We can now create a decoder from the raw (encoded) video bytes. You can of
+# course use a local video file and pass the path as input, rather than download
+# a video.
 from torchcodec.decoders import SimpleVideoDecoder
 
+# You can also pass a path to a local file!
+decoder = SimpleVideoDecoder(raw_video_bytes)
+
 # %%
-my_path = os.path.abspath(inspect.getfile(inspect.currentframe()))
-video_file_path = os.path.dirname(my_path) + "/../test/resources/nasa_13013.mp4"
-simple_decoder = SimpleVideoDecoder(video_file_path)
+# The video has not yet been decoded by the decoder, but we already have access
+# to some metadata via the ``metadata`` attribute, which is a
+# :class:`~torchcodec.decoders.VideoStreamMetadata` object.
+print(decoder.metadata)
 
 # %%
-# You can get the total frame count for the best video stream by calling len().
-num_frames = len(simple_decoder)
-print(f"{video_file_path=} has {num_frames} frames")
+# Decoding frames by indexing the decoder
+# ---------------------------------------
+
+first_frame = decoder[0]  # using a single int index
+every_twenty_frame = decoder[0 : -1 : 20]  # using slices
+
+print(f"{first_frame.shape = }")
+print(f"{first_frame.dtype = }")
+print(f"{every_twenty_frame.shape = }")
+print(f"{every_twenty_frame.dtype = }")
 
 # %%
-# You can get the decoded frame by using the subscript operator.
-first_frame = simple_decoder[0]
-print(f"decoded frame has type {type(first_frame)}")
+# Indexing the decoder returns the frames as :class:`torch.Tensor` objects.
+# By default, the shape of the frames is ``(N, C, H, W)``, where N is the batch
+# size, C is the number of channels, H is the height, and W is the width of the
+# frames. The batch dimension N is only present when we're decoding more than
+# one frame. The dimension order can be changed to ``(N, H, W, C)`` using the
+# ``dimension_order`` parameter of
+# :class:`~torchcodec.decoders.SimpleVideoDecoder`. Frames are always of
+# ``torch.uint8`` dtype.
+#
+
+plot(first_frame, "First frame")
+
+# %%
+plot(every_twenty_frame, "Every 20th frame")
+
+# %%
+# Iterating over frames
+# ---------------------
+#
+# The decoder is a normal iterable object and can be iterated over like so:
+
+for frame in decoder:
+    assert (
+        isinstance(frame, torch.Tensor)
+        and frame.shape == (3, decoder.metadata.height, decoder.metadata.width)
+    )
 
 # %%
-# The shape of the decoded frame is (H, W, C) where H and W are the height
-# and width of the video frame. C is 3 because we have 3 channels red, green,
-# and blue.
-print(f"{first_frame.shape=}")
+# Retrieving pts and duration of frames
+# -------------------------------------
+#
+# Indexing the decoder returns pure :class:`torch.Tensor` objects. Sometimes, it
+# can be useful to retrieve additional information about the frames, such as
+# their :term:`pts` (Presentation Time Stamp), and their duration.
+# This can be achieved using the
+# :meth:`~torchcodec.decoders.SimpleVideoDecoder.get_frame_at` and
+# :meth:`~torchcodec.decoders.SimpleVideoDecoder.get_frames_at` methods, which
+# will return :class:`~torchcodec.decoders.Frame` and
+# :class:`~torchcodec.decoders.FrameBatch` objects, respectively.
+
+last_frame = decoder.get_frame_at(len(decoder) - 1)
+print(f"{type(last_frame) = }")
+print(last_frame)
+
+# %%
+middle_frames = decoder.get_frames_at(start=10, stop=20, step=2)
+print(f"{type(middle_frames) = }")
+print(middle_frames)
 
 # %%
-# The dtype of the decoded frame is ``torch.uint8``.
-print(f"{first_frame.dtype=}")
+plot(last_frame.data, "Last frame")
+plot(middle_frames.data, "Middle frames")
 
 # %%
-# Negative indexes are supported.
-last_frame = simple_decoder[-1]
-print(f"{last_frame.shape=}")
+# Both :class:`~torchcodec.decoders.Frame` and
+# :class:`~torchcodec.decoders.FrameBatch` have a ``data`` field, which contains
+# the decoded tensor data. They also have the ``pts_seconds`` and
+# ``duration_seconds`` fields, which are single floats for
+# :class:`~torchcodec.decoders.Frame`, and 1-D :class:`torch.Tensor` for
+# :class:`~torchcodec.decoders.FrameBatch` (one value per frame in the batch).
+
+# %%
+# Using time-based indexing
+# -------------------------
+#
+# So far, we have retrieved frames based on their index. We can also retrieve
+# frames based on *when* they are displayed. The available methods are
+# :meth:`~torchcodec.decoders.SimpleVideoDecoder.get_frame_displayed_at` and
+# :meth:`~torchcodec.decoders.SimpleVideoDecoder.get_frames_displayed_at`, which
+# also return :class:`~torchcodec.decoders.Frame` and
+# :class:`~torchcodec.decoders.FrameBatch` objects, respectively.
 
-# TODO_BEFORE_RELEASE: add documentation for slices and metadata.
+frame_at_2_seconds = decoder.get_frame_displayed_at(seconds=2)
+print(f"{type(frame_at_2_seconds) = }")
+print(frame_at_2_seconds)
+plot(frame_at_2_seconds.data, "Frame displayed at 2 seconds")
+# TODO_BEFORE_RELEASE: illustrate get_frames_displayed_at
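
The file ends with a TODO to illustrate `get_frames_displayed_at`. A minimal sketch of what that section might look like, written as a continuation of the example above; the method's parameter names are not shown anywhere in this diff, so `start_seconds`/`stop_seconds` below are assumptions, not the documented signature.

```python
# Hypothetical continuation of examples/basic_example.py (not part of this commit).
# Assumption: get_frames_displayed_at takes a time range in seconds, mirroring
# get_frames_at(start, stop, step); the parameter names are guesses.
frames_in_range = decoder.get_frames_displayed_at(start_seconds=2, stop_seconds=3)
print(f"{type(frames_in_range) = }")  # expected to be a FrameBatch, per the docstring above
plot(frames_in_range.data, "Frames displayed between 2 and 3 seconds")
```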

pyproject.toml

Lines changed: 6 additions & 0 deletions
@@ -29,3 +29,9 @@ first_party_detection = false
 
 [tool.black]
 target-version = ["py38"]
+
+[tool.ufmt]
+
+excludes = [
+    "examples",
+]

src/torchcodec/decoders/_simple_video_decoder.py

Lines changed: 4 additions & 4 deletions
@@ -14,8 +14,8 @@
 from torchcodec.decoders import _core as core
 
 
-def _frame_str(self):
-    # Utility to replace Frame and FrameBatch __str__ method. This prints the
+def _frame_repr(self):
+    # Utility to replace Frame and FrameBatch __repr__ method. This prints the
     # shape of the .data tensor rather than printing the (potentially very long)
     # data tensor itself.
     s = self.__class__.__name__ + ":\n"
@@ -46,7 +46,7 @@ def __iter__(self) -> Iterator[Union[Tensor, float]]:
             yield getattr(self, field.name)
 
     def __repr__(self):
-        return _frame_str(self)
+        return _frame_repr(self)
 
 
 @dataclass
@@ -65,7 +65,7 @@ def __iter__(self) -> Iterator[Union[Tensor, float]]:
             yield getattr(self, field.name)
 
     def __repr__(self):
-        return _frame_str(self)
+        return _frame_repr(self)
 
 
 _ERROR_REPORTING_INSTRUCTIONS = """
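
For context on the rename, here is a standalone sketch of the repr pattern these lines touch: a shared helper that prints the shape of the `.data` tensor instead of the full tensor. The field layout matches the `Frame` fields described in the example (`data`, `pts_seconds`, `duration_seconds`); the loop body is an illustration, not the library's exact code.

```python
# Illustrative sketch of the _frame_repr pattern (not torchcodec's exact code).
import dataclasses
import torch


def _frame_repr(self):
    # Print each field, but only the shape of tensor fields, since the raw
    # data tensor can be very long.
    s = self.__class__.__name__ + ":\n"
    for field in dataclasses.fields(self):
        value = getattr(self, field.name)
        if isinstance(value, torch.Tensor):
            s += f"  {field.name} (shape): {tuple(value.shape)}\n"
        else:
            s += f"  {field.name}: {value}\n"
    return s


@dataclasses.dataclass
class Frame:
    data: torch.Tensor
    pts_seconds: float
    duration_seconds: float

    def __repr__(self):
        return _frame_repr(self)


print(Frame(data=torch.zeros(3, 360, 640, dtype=torch.uint8),
            pts_seconds=0.0, duration_seconds=0.04))
```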
