From 96b0935c92c6a795ab1a0641dfcd8ad28ca3b054 Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Wed, 7 Oct 2020 11:46:33 -0500 Subject: [PATCH 1/9] gitignore now supports IPYNB aux files --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 6d649a7c019..4ed0749da06 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ gen.yml .mypy_cache .vscode/ *.orig +*-checkpoint.ipynb \ No newline at end of file From 523a79f02bb3dc0ba9ffe49781881c39a3172da5 Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Wed, 7 Oct 2020 11:57:38 -0500 Subject: [PATCH 2/9] Adding the reference ipython notebook --- references/videoAPI/VideoAPI_Reference.ipynb | 484 +++++++++++++++++++ 1 file changed, 484 insertions(+) create mode 100644 references/videoAPI/VideoAPI_Reference.ipynb diff --git a/references/videoAPI/VideoAPI_Reference.ipynb b/references/videoAPI/VideoAPI_Reference.ipynb new file mode 100644 index 00000000000..8a00b8459d4 --- /dev/null +++ b/references/videoAPI/VideoAPI_Reference.ipynb @@ -0,0 +1,484 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Welcome to torchvision's new video API\n", + "\n", + "Here, we're going to examine the capabilities of the new video API, together with the examples on how to build datasets and more. \n", + "\n", + "### Table of contents\n", + "1. Introduction: building a new video object and examining the properties\n", + "2. Building a sample `read_video` function\n", + "3. Building an example dataset (can be applied to e.g. kinetics400)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Introduction: building a new video object and examining the properties\n", + "\n", + "First we select a video to test the object out. For the sake of argument we're using one from Kinetics400 dataset. To create it, we need to define the path and the stream we want to use. See inline comments for description. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch, torchvision\n", + "\"\"\"\n", + "chosen video statistics:\n", + "WUzgd7C1pWA.mp4\n", + " - source: kinetics-400\n", + " - video: H-264 - MPEG-4 AVC (part 10) (avc1)\n", + " - fps: 29.97\n", + " - audio: MPEG AAC audio (mp4a)\n", + " - sample rate: 48K Hz\n", + "\"\"\"\n", + "video_path = \"../../test/assets/videos/WUzgd7C1pWA.mp4\"\n", + "\n", + "\"\"\"\n", + "streams are defined in a similar fashion as torch devices. We encode them as strings in a form\n", + "of `stream_type:stream_id` where stream_type is a string and stream_id a long int. \n", + "\n", + "The constructor accepts passing a stream_type only, in which case the stream is auto-discovered.\n", + "\"\"\"\n", + "stream = \"video\"\n", + "\n", + "\n", + "\n", + "video = torch.classes.torchvision.Video(video_path, stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's get the metadata for our particular video:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'video': {'duration': [10.9109], 'fps': [29.97002997002997]},\n", + " 'audio': {'duration': [10.9], 'framerate': [48000.0]}}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "video.get_metadata()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we can see that video has two streams - a video and an audio stream. \n", + "\n", + "Let's read all the frames from the video stream." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of frames: 327\n", + "We can expect approx: 327.0\n", + "Tensor size: torch.Size([3, 256, 340])\n" + ] + } + ], + "source": [ + "# first we select the video stream \n", + "video.set_current_stream(\"video:0\")\n", + "\n", + "frames = [] # we are going to save the frames here.\n", + "frame, pts = video.next()\n", + "# note that next will return emptyframe at the end of the video stream\n", + "while frame.numel() != 0:\n", + " frames.append(frame)\n", + " frame, pts = video.next()\n", + " \n", + "print(\"Total number of frames: \", len(frames))\n", + "approx_nf = video.get_metadata()['video']['duration'][0] * video.get_metadata()['video']['fps'][0]\n", + "print(\"We can expect approx: \", approx_nf)\n", + "print(\"Tensor size: \", frames[0].size())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that selecting zero video stream is equivalent to selecting video stream automatically. I.e. `video:0` and `video` will end up with same results in this case. \n", + "\n", + "Let's try this for audio" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of frames: 511\n", + "We can expect approx: 327.0\n", + "Tensor size: torch.Size([1024, 1])\n" + ] + } + ], + "source": [ + "video.set_current_stream(\"audio\")\n", + "\n", + "frames = [] # we are going to save the frames here.\n", + "frame, pts = video.next()\n", + "# note that next will return emptyframe at the end of the video stream\n", + "while frame.numel() != 0:\n", + " frames.append(frame)\n", + " frame, pts = video.next()\n", + " \n", + "print(\"Total number of frames: \", len(frames))\n", + "approx_nf = video.get_metadata()['video']['duration'][0] * video.get_metadata()['video']['fps'][0]\n", + "print(\"We can expect approx: \", approx_nf)\n", + "print(\"Tensor size: \", frames[0].size())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "But what if we only want to read certain time segment of the video?\n", + "\n", + "That can be done easily using the combination of our seek function, and the fact that each call to next returns the presentation timestamp of the returned frame in seconds.\n", + "\n", + "For example, if we wanted to read video from second to fifth second:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of frames: 90\n", + "We can expect approx: 89.91008991008991\n", + "Tensor size: torch.Size([3, 256, 340])\n" + ] + } + ], + "source": [ + "video.set_current_stream(\"video\")\n", + "\n", + "frames = [] # we are going to save the frames here.\n", + "\n", + "# we seek into a second second of the video \n", + "# the following call to next returns the first following frame\n", + "video.seek(2) \n", + "frame, pts = video.next()\n", + "# note that we add exit condition\n", + "while pts < 5 and frame.numel() != 0:\n", + " frames.append(frame)\n", + " frame, pts = video.next()\n", + " \n", + "print(\"Total number of frames: \", len(frames))\n", + "approx_nf = (5-2) * video.get_metadata()['video']['fps'][0]\n", + "print(\"We can expect approx: \", approx_nf)\n", + "print(\"Tensor size: \", frames[0].size())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Building a sample `read_video` function\n", + "\n", + "We can utilize the methods above to build the read video function that follows the same API to the existing `read_video` function " + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "def example_read_video(video_object, start=0, end=None, read_video=True, read_audio=True):\n", + "\n", + " if end is None:\n", + " end = float(\"inf\")\n", + " if end < start:\n", + " raise ValueError(\n", + " \"end time should be larger than start time, got \"\n", + " \"start time={} and end time={}\".format(s, e)\n", + " )\n", + " \n", + " video_frames = torch.empty(0)\n", + " video_pts = []\n", + " if read_video:\n", + " video_object.set_current_stream(\"video\")\n", + " video_object.seek(start)\n", + " frames = []\n", + " t, pts = video_object.next()\n", + " while t.numel() > 0 and (pts >= start and pts <= end):\n", + " frames.append(t)\n", + " video_pts.append(pts)\n", + " t, pts = video_object.next()\n", + " if len(frames) > 0:\n", + " video_frames = torch.stack(frames, 0)\n", + "\n", + " audio_frames = torch.empty(0)\n", + " audio_pts = []\n", + " if read_audio:\n", + " video_object.set_current_stream(\"audio\")\n", + " video_object.seek(start)\n", + " frames = []\n", + " t, pts = video_object.next()\n", + " while t.numel() > 0 and (pts >= start and pts <= end):\n", + " frames.append(t)\n", + " audio_pts.append(pts)\n", + " t, pts = video_object.next()\n", + " if len(frames) > 0:\n", + " audio_frames = torch.cat(frames, 1)\n", + "\n", + " return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([327, 3, 256, 340])" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vf, af, info, meta = example_read_video(video)\n", + "# total number of frames should be 327\n", + "vf.size()" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1024, 511])" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# you can also get the sequence of audio frames as well\n", + "af.size()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Building an example randomly sampled dataset (can be applied to training dataest of kinetics400)\n", + "\n", + "Cool, so now we can use the same principle to make the sample dataset. We suggest trying out iterable dataset for this purpose. \n", + "\n", + "Here, we are going to build\n", + "\n", + "a. an example dataset that reads randomly selected 10 frames of video" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "# first, housekeeping and utilities\n", + "import os\n", + "import random\n", + "\n", + "import torch\n", + "from torchvision.datasets.folder import make_dataset\n", + "from torchvision import transforms as t\n", + "\n", + "def _find_classes(dir):\n", + " classes = [d.name for d in os.scandir(dir) if d.is_dir()]\n", + " classes.sort()\n", + " class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}\n", + " return classes, class_to_idx\n", + "\n", + "def get_samples(root, extensions=(\".mp4\", \".avi\")):\n", + " _, class_to_idx = _find_classes(root)\n", + " return make_dataset(root, class_to_idx, extensions=extensions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We are going to define the dataset and some basic arguments. We asume the structure of the FolderDataset, and add the following parameters:\n", + " \n", + "1. frame transform: with this API, we can chose to apply transforms on every frame of the video\n", + "2. videotransform: equally, we can also apply transform to a 4D tensor\n", + "3. length of the clip: do we want a single or multiple frames?\n", + "\n", + "Note that we actually add `epoch size` as using `IterableDataset` class allows us to naturally oversample clips or images from each video if needed. " + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "class RandomDataset(torch.utils.data.IterableDataset):\n", + " def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):\n", + " super(RandomDataset).__init__()\n", + " \n", + " self.samples = get_samples(root)\n", + " \n", + " # allow for temporal jittering\n", + " if epoch_size is None:\n", + " epoch_size = len(self.samples)\n", + " self.epoch_size = epoch_size\n", + " \n", + " self.clip_len = clip_len # length of a clip in frames\n", + " self.frame_transform = frame_transform # transform for every frame individually\n", + " self.video_transform = video_transform # transform on a video sequence\n", + "\n", + " def __iter__(self):\n", + " for i in range(self.epoch_size):\n", + " # get random sample\n", + " path, target = random.choice(self.samples)\n", + " # get video object\n", + " vid = torch.classes.torchvision.Video(path, \"video\")\n", + " metadata = vid.get_metadata()\n", + " video_frames = [] # video frame buffer \n", + " # seek and return frames\n", + " \n", + " max_seek = metadata[\"video\"]['duration'][0] - (self.clip_len / metadata[\"video\"]['fps'][0])\n", + " start = random.uniform(0., max_seek)\n", + " vid.seek(start)\n", + " while len(video_frames) < self.clip_len:\n", + " frame, current_pts = vid.next()\n", + " video_frames.append(self.frame_transform(frame))\n", + " # stack it into a tensor\n", + " video = torch.stack(video_frames, 0)\n", + " if self.video_transform:\n", + " video = self.video_transform(video)\n", + " output = {\n", + " 'path': path,\n", + " 'video': video,\n", + " 'target': target,\n", + " 'start': start,\n", + " 'end': current_pts}\n", + " yield output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Given a path of videos in a folder structure, i.e:\n", + "```\n", + "dataset:\n", + " -class 1:\n", + " file 0\n", + " file 1\n", + " ...\n", + " - class 2:\n", + " file 0\n", + " file 1\n", + " ...\n", + " - ...\n", + "```\n", + "We can generate a dataloader and test the dataset. \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [], + "source": [ + "from torchvision import transforms as t\n", + "transforms = [t.ToPILImage(), t.Resize((112, 112), interpolation=2), t.ToTensor()]\n", + "frame_transform = t.Compose(transforms)\n", + "\n", + "ds = RandomDataset(\"/home/bjuncek/work/video_reader_benchmark/dataset_files\", epoch_size=None, frame_transform=frame_transform)" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "loader = DataLoader(ds, batch_size=12)\n", + "d = {\"video\":[], 'start':[], 'end':[]}\n", + "for b in loader:\n", + " for i in range(len(b['path'])):\n", + " d['video'].append(b['path'][i])\n", + " d['start'].append(b['start'][i].item())\n", + " d['end'].append(b['end'][i].item())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From e9cb83f871c6cec7b764891534257120e9554ee1 Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Wed, 7 Oct 2020 13:54:23 -0500 Subject: [PATCH 3/9] Update location --- examples/python/README.md | 6 +++++- .../python/VideoAPI.ipynb | 0 2 files changed, 5 insertions(+), 1 deletion(-) rename references/videoAPI/VideoAPI_Reference.ipynb => examples/python/VideoAPI.ipynb (100%) diff --git a/examples/python/README.md b/examples/python/README.md index 31a308102fc..7b06acc20ce 100644 --- a/examples/python/README.md +++ b/examples/python/README.md @@ -1,7 +1,9 @@ # Python examples - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb) -[Examples of Tensor Images transformations](https://github.com/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb) +- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb) +[Example of VideoAPI](https://github.com/pytorch/vision/blob/master/examples/python/VideoAPI.ipynb) + Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric and presented multiple limitations due to that. Now, since v0.8.0, transforms implementations are Tensor and PIL compatible and we can achieve the following new @@ -11,3 +13,5 @@ features: - support for GPU acceleration - batched transformation such as for videos - read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats) + +Furhermore, previously we used to provide a very high-level API for video decoding which left little control to the user. We're now expanding that API (and replacing it in the future) with a lower-level API that allows the user a frame-based access to a video. \ No newline at end of file diff --git a/references/videoAPI/VideoAPI_Reference.ipynb b/examples/python/VideoAPI.ipynb similarity index 100% rename from references/videoAPI/VideoAPI_Reference.ipynb rename to examples/python/VideoAPI.ipynb From 3f2b6658d8e3b3bca0d012ed87a39a8dcafe6d5e Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Wed, 7 Oct 2020 14:11:15 -0500 Subject: [PATCH 4/9] link fix --- examples/python/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/python/README.md b/examples/python/README.md index 7b06acc20ce..3d0f14b957b 100644 --- a/examples/python/README.md +++ b/examples/python/README.md @@ -1,7 +1,8 @@ # Python examples - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb) -- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb) +[Examples of Tensor Images transformations](https://github.com/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb) +- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/VideoAPI.ipynb) [Example of VideoAPI](https://github.com/pytorch/vision/blob/master/examples/python/VideoAPI.ipynb) From 7484f1ec20474cd79093e90481cf7d1e3e2a906e Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Wed, 7 Oct 2020 14:27:36 -0500 Subject: [PATCH 5/9] Add autodownload for colab --- examples/python/VideoAPI.ipynb | 212 +++++++++++++++++++++++++++++++-- 1 file changed, 200 insertions(+), 12 deletions(-) diff --git a/examples/python/VideoAPI.ipynb b/examples/python/VideoAPI.ipynb index 8a00b8459d4..84648df3032 100644 --- a/examples/python/VideoAPI.ipynb +++ b/examples/python/VideoAPI.ipynb @@ -14,6 +14,46 @@ "3. Building an example dataset (can be applied to e.g. kinetics400)" ] }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('1.7.0a0+f5c95d5', '0.8.0a0+6eff0a4')" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch, torchvision\n", + "torch.__version__, torchvision.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using downloaded and verified file: ./WUzgd7C1pWA.mp4\n" + ] + } + ], + "source": [ + "# download the sample video\n", + "from torchvision.datasets.utils import download_url\n", + "download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true\", \".\", \"WUzgd7C1pWA.mp4\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -25,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -39,7 +79,7 @@ " - audio: MPEG AAC audio (mp4a)\n", " - sample rate: 48K Hz\n", "\"\"\"\n", - "video_path = \"../../test/assets/videos/WUzgd7C1pWA.mp4\"\n", + "video_path = \"./WUzgd7C1pWA.mp4\"\n", "\n", "\"\"\"\n", "streams are defined in a similar fashion as torch devices. We encode them as strings in a form\n", @@ -63,17 +103,17 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'video': {'duration': [10.9109], 'fps': [29.97002997002997]},\n", - " 'audio': {'duration': [10.9], 'framerate': [48000.0]}}" + "{'video': {'duration': [], 'fps': []},\n", + " 'audio': {'duration': [], 'framerate': []}}" ] }, - "execution_count": 3, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -323,11 +363,110 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# make sample dataest\n", + "import os\n", + "os.makedirs(\"./dataset\", exist_ok=True)\n", + "os.makedirs(\"./dataset/1\", exist_ok=True)\n", + "os.makedirs(\"./dataset/2\", exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true to ./dataset/1/WUzgd7C1pWA.mp4\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100.4%" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi?raw=true to ./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "102.5%" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/SOX5yA1l24A.mp4?raw=true to ./dataset/2/SOX5yA1l24A.mp4\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100.9%" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/v_SoccerJuggling_g23_c01.avi?raw=true to ./dataset/2/v_SoccerJuggling_g23_c01.avi\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "101.5%" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/v_SoccerJuggling_g24_c01.avi?raw=true to ./dataset/2/v_SoccerJuggling_g24_c01.avi\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "101.3%" + ] + } + ], + "source": [ + "# download the videos \n", + "from torchvision.datasets.utils import download_url\n", + "download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true\", \"./dataset/1\", \"WUzgd7C1pWA.mp4\")\n", + "download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi?raw=true\", \"./dataset/1\", \"RATRACE_wave_f_nm_np1_fr_goo_37.avi\")\n", + "download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/SOX5yA1l24A.mp4?raw=true\", \"./dataset/2\", \"SOX5yA1l24A.mp4\")\n", + "download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/v_SoccerJuggling_g23_c01.avi?raw=true\", \"./dataset/2\", \"v_SoccerJuggling_g23_c01.avi\")\n", + "download_url(\"https://github.com/pytorch/vision/blob/master/test/assets/videos/v_SoccerJuggling_g24_c01.avi?raw=true\", \"./dataset/2\", \"v_SoccerJuggling_g24_c01.avi\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "# first, housekeeping and utilities\n", + "# housekeeping and utilities\n", "import os\n", "import random\n", "\n", @@ -361,7 +500,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -432,7 +571,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -440,12 +579,12 @@ "transforms = [t.ToPILImage(), t.Resize((112, 112), interpolation=2), t.ToTensor()]\n", "frame_transform = t.Compose(transforms)\n", "\n", - "ds = RandomDataset(\"/home/bjuncek/work/video_reader_benchmark/dataset_files\", epoch_size=None, frame_transform=frame_transform)" + "ds = RandomDataset(\"./dataset\", epoch_size=None, frame_transform=frame_transform)" ] }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -458,6 +597,55 @@ " d['start'].append(b['start'][i].item())\n", " d['end'].append(b['end'][i].item())" ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'video': ['./dataset/1/WUzgd7C1pWA.mp4',\n", + " './dataset/1/WUzgd7C1pWA.mp4',\n", + " './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi',\n", + " './dataset/2/SOX5yA1l24A.mp4',\n", + " './dataset/1/WUzgd7C1pWA.mp4'],\n", + " 'start': [9.389068196639585,\n", + " 0.49568543984299757,\n", + " 1.8230755088943975,\n", + " 9.825717940944148,\n", + " 8.683333240145584],\n", + " 'end': [9.9099, 1.001, 2.333333, 10.343667, 9.2092]}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "## Cleanup\n", + "import os, shutil\n", + "os.remove(\"./WUzgd7C1pWA.mp4\")\n", + "shutil.rmtree(\"./dataset\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From 4190da439a1782230587d1797f5aba534ffccbd2 Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Thu, 8 Oct 2020 06:27:45 -0500 Subject: [PATCH 6/9] rename and address fmassa's comments --- .../{VideoAPI.ipynb => video_api.ipynb} | 71 +++++++++---------- 1 file changed, 35 insertions(+), 36 deletions(-) rename examples/python/{VideoAPI.ipynb => video_api.ipynb} (93%) diff --git a/examples/python/VideoAPI.ipynb b/examples/python/video_api.ipynb similarity index 93% rename from examples/python/VideoAPI.ipynb rename to examples/python/video_api.ipynb index 84648df3032..41639e8dabd 100644 --- a/examples/python/VideoAPI.ipynb +++ b/examples/python/video_api.ipynb @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -25,7 +25,7 @@ "('1.7.0a0+f5c95d5', '0.8.0a0+6eff0a4')" ] }, - "execution_count": 1, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -65,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -103,17 +103,17 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'video': {'duration': [], 'fps': []},\n", - " 'audio': {'duration': [], 'framerate': []}}" + "{'video': {'duration': [10.9109], 'fps': [29.97002997002997]},\n", + " 'audio': {'duration': [10.9], 'framerate': [48000.0]}}" ] }, - "execution_count": 5, + "execution_count": 39, "metadata": {}, "output_type": "execute_result" } @@ -133,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 40, "metadata": {}, "outputs": [ { @@ -148,6 +148,7 @@ ], "source": [ "# first we select the video stream \n", + "metadata = video.get_metadata()\n", "video.set_current_stream(\"video:0\")\n", "\n", "frames = [] # we are going to save the frames here.\n", @@ -158,7 +159,7 @@ " frame, pts = video.next()\n", " \n", "print(\"Total number of frames: \", len(frames))\n", - "approx_nf = video.get_metadata()['video']['duration'][0] * video.get_metadata()['video']['fps'][0]\n", + "approx_nf = metadata['video']['duration'][0] * metadata['video']['fps'][0]\n", "print(\"We can expect approx: \", approx_nf)\n", "print(\"Tensor size: \", frames[0].size())" ] @@ -174,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 41, "metadata": {}, "outputs": [ { @@ -182,25 +183,26 @@ "output_type": "stream", "text": [ "Total number of frames: 511\n", - "We can expect approx: 327.0\n", - "Tensor size: torch.Size([1024, 1])\n" + "Approx total number of datapoints we can expect: 523200.0\n", + "Read data size: 523264\n" ] } ], "source": [ + "metadata = video.get_metadata()\n", "video.set_current_stream(\"audio\")\n", "\n", "frames = [] # we are going to save the frames here.\n", "frame, pts = video.next()\n", - "# note that next will return emptyframe at the end of the video stream\n", + "# note that next will return emptyframe at the end of the audio stream\n", "while frame.numel() != 0:\n", " frames.append(frame)\n", " frame, pts = video.next()\n", " \n", "print(\"Total number of frames: \", len(frames))\n", - "approx_nf = video.get_metadata()['video']['duration'][0] * video.get_metadata()['video']['fps'][0]\n", - "print(\"We can expect approx: \", approx_nf)\n", - "print(\"Tensor size: \", frames[0].size())" + "approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]\n", + "print(\"Approx total number of datapoints we can expect: \", approx_nf)\n", + "print(\"Read data size: \", frames[0].size(0) * len(frames))" ] }, { @@ -216,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 42, "metadata": {}, "outputs": [ { @@ -260,7 +262,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 43, "metadata": {}, "outputs": [], "source": [ @@ -300,45 +302,42 @@ " audio_pts.append(pts)\n", " t, pts = video_object.next()\n", " if len(frames) > 0:\n", - " audio_frames = torch.cat(frames, 1)\n", + " audio_frames = torch.cat(frames, 0)\n", "\n", " return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()" ] }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 44, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "torch.Size([327, 3, 256, 340])" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([327, 3, 256, 340]) torch.Size([523264, 1])\n" + ] } ], "source": [ "vf, af, info, meta = example_read_video(video)\n", "# total number of frames should be 327\n", - "vf.size()" + "print(vf.size(), af.size())" ] }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([1024, 511])" + "torch.Size([523264, 1])" ] }, - "execution_count": 44, + "execution_count": 45, "metadata": {}, "output_type": "execute_result" } @@ -363,7 +362,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -376,7 +375,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 47, "metadata": {}, "outputs": [ { @@ -462,7 +461,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ From 0d251142166fb519829a99399c21b8c6a5f5e7ec Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Thu, 8 Oct 2020 06:34:55 -0500 Subject: [PATCH 7/9] nitpicks --- examples/python/video_api.ipynb | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/python/video_api.ipynb b/examples/python/video_api.ipynb index 41639e8dabd..041e40ddfb8 100644 --- a/examples/python/video_api.ipynb +++ b/examples/python/video_api.ipynb @@ -461,7 +461,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ @@ -499,7 +499,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ @@ -570,12 +570,12 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "from torchvision import transforms as t\n", - "transforms = [t.ToPILImage(), t.Resize((112, 112), interpolation=2), t.ToTensor()]\n", + "transforms = [t.Resize((112, 112))]\n", "frame_transform = t.Compose(transforms)\n", "\n", "ds = RandomDataset(\"./dataset\", epoch_size=None, frame_transform=frame_transform)" @@ -583,7 +583,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ @@ -599,26 +599,26 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'video': ['./dataset/1/WUzgd7C1pWA.mp4',\n", + "{'video': ['./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi',\n", " './dataset/1/WUzgd7C1pWA.mp4',\n", " './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi',\n", " './dataset/2/SOX5yA1l24A.mp4',\n", - " './dataset/1/WUzgd7C1pWA.mp4'],\n", - " 'start': [9.389068196639585,\n", - " 0.49568543984299757,\n", - " 1.8230755088943975,\n", - " 9.825717940944148,\n", - " 8.683333240145584],\n", - " 'end': [9.9099, 1.001, 2.333333, 10.343667, 9.2092]}" + " './dataset/2/v_SoccerJuggling_g23_c01.avi'],\n", + " 'start': [0.029482554081669773,\n", + " 3.439334232470971,\n", + " 1.1823159301599728,\n", + " 4.470027811314425,\n", + " 3.3126303902318432],\n", + " 'end': [0.5666669999999999, 3.970633, 1.7, 4.971633, 3.837167]}" ] }, - "execution_count": 12, + "execution_count": 53, "metadata": {}, "output_type": "execute_result" } @@ -629,7 +629,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 54, "metadata": {}, "outputs": [], "source": [ From 50e1d0656e95b61fe9cb366ab19972fce57b9c15 Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Thu, 8 Oct 2020 06:35:22 -0500 Subject: [PATCH 8/9] nitpicks --- examples/python/video_api.ipynb | 7 ------- 1 file changed, 7 deletions(-) diff --git a/examples/python/video_api.ipynb b/examples/python/video_api.ipynb index 041e40ddfb8..76de5587ff9 100644 --- a/examples/python/video_api.ipynb +++ b/examples/python/video_api.ipynb @@ -638,13 +638,6 @@ "os.remove(\"./WUzgd7C1pWA.mp4\")\n", "shutil.rmtree(\"./dataset\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { From 0a39f10632ca4d524883bb98ca7fcd62e57d97ce Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 8 Oct 2020 15:30:57 +0200 Subject: [PATCH 9/9] Apply suggestions from code review --- examples/python/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/python/README.md b/examples/python/README.md index 3d0f14b957b..9cd02bcb326 100644 --- a/examples/python/README.md +++ b/examples/python/README.md @@ -2,8 +2,8 @@ - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb) [Examples of Tensor Images transformations](https://github.com/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb) -- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/VideoAPI.ipynb) -[Example of VideoAPI](https://github.com/pytorch/vision/blob/master/examples/python/VideoAPI.ipynb) +- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/video_api.ipynb) +[Example of VideoAPI](https://github.com/pytorch/vision/blob/master/examples/python/video_api.ipynb) Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric and presented multiple limitations due to @@ -15,4 +15,4 @@ features: - batched transformation such as for videos - read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats) -Furhermore, previously we used to provide a very high-level API for video decoding which left little control to the user. We're now expanding that API (and replacing it in the future) with a lower-level API that allows the user a frame-based access to a video. \ No newline at end of file +Furthermore, previously we used to provide a very high-level API for video decoding which left little control to the user. We're now expanding that API (and replacing it in the future) with a lower-level API that allows the user a frame-based access to a video.