|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
7 | 7 | """
|
8 |
| -================================================== |
9 |
| -Basic Example to use TorchCodec to decode a video. |
10 |
| -================================================== |
| 8 | +======================================== |
| 9 | +Decoding a video with SimpleVideoDecoder |
| 10 | +======================================== |
11 | 11 |
|
12 |
| -A simple example showing how to decode the first few frames of a video using |
13 |
| -the :class:`~torchcodec.decoders.SimpleVideoDecoder` class. |
| 12 | +In this example, we'll learn how to decode a video using the |
| 13 | +:class:`~torchcodec.decoders.SimpleVideoDecoder` class. |
14 | 14 | """
|
15 | 15 |
|
16 | 16 | # %%
|
17 |
| -import inspect |
18 |
| -import os |
| 17 | +# First, a bit of boilerplate: we'll download a video from the web, and define a |
| 18 | +# plotting utility. You can ignore that part and jump right below to |
| 19 | +# :ref:`creating_decoder`. |
19 | 20 |
|
| 21 | +from typing import Optional |
| 22 | +import torch |
| 23 | +import requests |
| 24 | + |
| 25 | + |
| 26 | +# Video source: https://www.pexels.com/video/dog-eating-854132/ |
| 27 | +# License: CC0. Author: Coverr. |
| 28 | +url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4" |
| 29 | +response = requests.get(url) |
| 30 | +if response.status_code != 200: |
| 31 | + raise RuntimeError(f"Failed to download video. {response.status_code = }.") |
| 32 | + |
| 33 | +raw_video_bytes = response.content |
| 34 | + |
| 35 | + |
| 36 | +def plot(frames: torch.Tensor, title : Optional[str] = None): |
| 37 | + try: |
| 38 | + from torchvision.utils import make_grid |
| 39 | + from torchvision.transforms.v2.functional import to_pil_image |
| 40 | + import matplotlib.pyplot as plt |
| 41 | + except ImportError: |
| 42 | + print("Cannot plot, please run `pip install torchvision matplotlib`") |
| 43 | + return |
| 44 | + |
| 45 | + plt.rcParams["savefig.bbox"] = 'tight' |
| 46 | + fig, ax = plt.subplots() |
| 47 | + ax.imshow(to_pil_image(make_grid(frames))) |
| 48 | + ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) |
| 49 | + if title is not None: |
| 50 | + ax.set_title(title) |
| 51 | + plt.tight_layout() |
| 52 | + |
| 53 | + |
| 54 | +# %% |
| 55 | +# .. _creating_decoder: |
| 56 | +# |
| 57 | +# Creating a decoder |
| 58 | +# ------------------ |
| 59 | +# |
| 60 | +# We can now create a decoder from the raw (encoded) video bytes. You can of |
| 61 | +# course use a local video file and pass the path as input, rather than download |
| 62 | +# a video. |
20 | 63 | from torchcodec.decoders import SimpleVideoDecoder
|
21 | 64 |
|
| 65 | +# You can also pass a path to a local file! |
| 66 | +decoder = SimpleVideoDecoder(raw_video_bytes) |
| 67 | + |
22 | 68 | # %%
|
23 |
| -my_path = os.path.abspath(inspect.getfile(inspect.currentframe())) |
24 |
| -video_file_path = os.path.dirname(my_path) + "/../test/resources/nasa_13013.mp4" |
25 |
| -simple_decoder = SimpleVideoDecoder(video_file_path) |
| 69 | +# The has not yet been decoded by the decoder, but we already have access to |
| 70 | +# some metadata via the ``metadata`` attribute which is a |
| 71 | +# :class:`~torchcodec.decoders.VideoStreamMetadata` object. |
| 72 | +print(decoder.metadata) |
26 | 73 |
|
27 | 74 | # %%
|
28 |
| -# You can get the total frame count for the best video stream by calling len(). |
29 |
| -num_frames = len(simple_decoder) |
30 |
| -print(f"{video_file_path=} has {num_frames} frames") |
| 75 | +# Decoding frames by indexing the decoder |
| 76 | +# --------------------------------------- |
| 77 | + |
| 78 | +first_frame = decoder[0] # using a single int index |
| 79 | +every_twenty_frame = decoder[0 : -1 : 20] # using slices |
| 80 | + |
| 81 | +print(f"{first_frame.shape = }") |
| 82 | +print(f"{first_frame.dtype = }") |
| 83 | +print(f"{every_twenty_frame.shape = }") |
| 84 | +print(f"{every_twenty_frame.dtype = }") |
31 | 85 |
|
32 | 86 | # %%
|
33 |
| -# You can get the decoded frame by using the subscript operator. |
34 |
| -first_frame = simple_decoder[0] |
35 |
| -print(f"decoded frame has type {type(first_frame)}") |
| 87 | +# Indexing the decoder returns the frames as :class:`torch.Tensor` objects. |
| 88 | +# By default, the shape of the frames is ``(N, C, H, W)`` where N is the batch |
| 89 | +# size C the number of channels, H is the height, and W is the width of the |
| 90 | +# frames. The batch dimension N is only present when we're decoding more than |
| 91 | +# one frame. The dimension order can be changed to ``N, H, W, C`` using the |
| 92 | +# ``dimension_order`` parameter of |
| 93 | +# :class:`~torchcodec.decoders.SimpleVideoDecoder`. Frames are always of |
| 94 | +# ``torch.uint8`` dtype. |
| 95 | +# |
| 96 | + |
| 97 | +plot(first_frame, "First frame") |
| 98 | + |
| 99 | +# %% |
| 100 | +plot(every_twenty_frame, "Every 20 frame") |
| 101 | + |
| 102 | +# %% |
| 103 | +# Iterating over frames |
| 104 | +# --------------------- |
| 105 | +# |
| 106 | +# The decoder is a normal iterable object and can be iterated over like so: |
| 107 | + |
| 108 | +for frame in decoder: |
| 109 | + assert ( |
| 110 | + isinstance(frame, torch.Tensor) |
| 111 | + and frame.shape == (3, decoder.metadata.height, decoder.metadata.width) |
| 112 | + ) |
36 | 113 |
|
37 | 114 | # %%
|
38 |
| -# The shape of the decoded frame is (H, W, C) where H and W are the height |
39 |
| -# and width of the video frame. C is 3 because we have 3 channels red, green, |
40 |
| -# and blue. |
41 |
| -print(f"{first_frame.shape=}") |
| 115 | +# Retrieving pts and duration of frames |
| 116 | +# ------------------------------------- |
| 117 | +# |
| 118 | +# Indexing the decoder returns pure :class:`torch.Tensor` objects. Sometimes, it |
| 119 | +# can be useful to retrieve additional information about the frames, such as |
| 120 | +# their :term:`pts` (Presentation Time Stamp), and their duration. |
| 121 | +# This can be achieved using the |
| 122 | +# :meth:`~torchcodec.decoders.SimpleVideoDecoder.get_frame_at` and |
| 123 | +# :meth:`~torchcodec.decoders.SimpleVideoDecoder.get_frames_at` methods, which |
| 124 | +# will return a :class:`~torchcodec.decoders.Frame` and |
| 125 | +# :class:`~torchcodec.decoders.FrameBatch` objects respectively. |
| 126 | + |
| 127 | +last_frame = decoder.get_frame_at(len(decoder) - 1) |
| 128 | +print(f"{type(last_frame) = }") |
| 129 | +print(last_frame) |
| 130 | + |
| 131 | +# %% |
| 132 | +middle_frames = decoder.get_frames_at(start=10, stop=20, step=2) |
| 133 | +print(f"{type(middle_frames) = }") |
| 134 | +print(middle_frames) |
42 | 135 |
|
43 | 136 | # %%
|
44 |
| -# The dtype of the decoded frame is ``torch.uint8``. |
45 |
| -print(f"{first_frame.dtype=}") |
| 137 | +plot(last_frame.data, "Last frame") |
| 138 | +plot(middle_frames.data, "Middle frames") |
46 | 139 |
|
47 | 140 | # %%
|
48 |
| -# Negative indexes are supported. |
49 |
| -last_frame = simple_decoder[-1] |
50 |
| -print(f"{last_frame.shape=}") |
| 141 | +# Both :class:`~torchcodec.decoders.Frame` and |
| 142 | +# :class:`~torchcodec.decoders.FrameBatch` have a ``data`` field, which contains |
| 143 | +# the decoded tensor data. They also have the ``pts_seconds`` and |
| 144 | +# ``duration_seconds`` fields which are single ints for |
| 145 | +# :class:`~torchcodec.decoders.Frame`, and 1-D :class:`torch.Tensor` for |
| 146 | +# :class:`~torchcodec.decoders.FrameBatch` (one value per frame in the batch). |
| 147 | + |
| 148 | +# %% |
| 149 | +# Using time-based indexing |
| 150 | +# ------------------------- |
| 151 | +# |
| 152 | +# So far, we have retrieved frames based on their index. We can also retrieve |
| 153 | +# frames based on *when* they are displayed. The available method are |
| 154 | +# :meth:`~torchcodec.decoders.SimpleVideoDecoder.get_frame_displayed_at` and |
| 155 | +# :meth:`~torchcodec.decoders.SimpleVideoDecoder.get_frames_displayed_at`, which |
| 156 | +# also return :class:`~torchcodec.decoders.Frame` and |
| 157 | +# :class:`~torchcodec.decoders.FrameBatch` objects respectively. |
51 | 158 |
|
52 |
| -# TODO_BEFORE_RELEASE: add documentation for slices and metadata. |
| 159 | +frame_at_2_seconds = decoder.get_frame_displayed_at(seconds=2) |
| 160 | +print(f"{type(frame_at_2_seconds) = }") |
| 161 | +print(frame_at_2_seconds) |
| 162 | +plot(frame_at_2_seconds.data, "Frame displayed at 2 seconds") |
| 163 | +# TODO_BEFORE_RELEASE: illustrate get_frames_displayed_at |
0 commit comments