Skip to content

Cache HW device context #3178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchaudio/csrc/ffmpeg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ set(
sources
ffmpeg.cpp
filter_graph.cpp
hw_context.cpp
stream_reader/buffer/chunked_buffer.cpp
stream_reader/buffer/unchunked_buffer.cpp
stream_reader/conversion.cpp
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/ffmpeg/ffmpeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ void AutoBufferUnref::operator()(AVBufferRef* p) {
av_buffer_unref(&p);
}

AVBufferRefPtr::AVBufferRefPtr()
: Wrapper<AVBufferRef, AutoBufferUnref>(nullptr) {}
AVBufferRefPtr::AVBufferRefPtr(AVBufferRef* p)
: Wrapper<AVBufferRef, AutoBufferUnref>(p) {}

void AVBufferRefPtr::reset(AVBufferRef* p) {
TORCH_CHECK(
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/csrc/ffmpeg/ffmpeg.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ struct AutoBufferUnref {
};

struct AVBufferRefPtr : public Wrapper<AVBufferRef, AutoBufferUnref> {
AVBufferRefPtr();
AVBufferRefPtr(AVBufferRef* p = nullptr);
void reset(AVBufferRef* p);
};

Expand Down
40 changes: 40 additions & 0 deletions torchaudio/csrc/ffmpeg/hw_context.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include <torchaudio/csrc/ffmpeg/hw_context.h>

namespace torchaudio::io {
namespace {

static std::mutex MUTEX;
static std::map<int, AVBufferRefPtr> CUDA_CONTEXT_CACHE;

} // namespace

AVBufferRef* get_cuda_context(int index) {
std::lock_guard<std::mutex> lock(MUTEX);
if (index == -1) {
index = 0;
}
if (CUDA_CONTEXT_CACHE.count(index) == 0) {
AVBufferRef* p = nullptr;
int ret = av_hwdevice_ctx_create(
&p, AV_HWDEVICE_TYPE_CUDA, std::to_string(index).c_str(), nullptr, 0);
TORCH_CHECK(
ret >= 0,
"Failed to create CUDA device context on device ",
index,
"(",
av_err2string(ret),
")");
assert(p);
CUDA_CONTEXT_CACHE.emplace(index, p);
return p;
}
AVBufferRefPtr& buffer = CUDA_CONTEXT_CACHE.at(index);
return buffer;
}

void clear_cuda_context_cache() {
std::lock_guard<std::mutex> lock(MUTEX);
CUDA_CONTEXT_CACHE.clear();
}

} // namespace torchaudio::io
11 changes: 11 additions & 0 deletions torchaudio/csrc/ffmpeg/hw_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

#include <torchaudio/csrc/ffmpeg/ffmpeg.h>

namespace torchaudio::io {

AVBufferRef* get_cuda_context(int index);

void clear_cuda_context_cache();

} // namespace torchaudio::io
2 changes: 2 additions & 0 deletions torchaudio/csrc/ffmpeg/pybind/pybind.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/hw_context.h>
#include <torchaudio/csrc/ffmpeg/pybind/fileobj.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
Expand Down Expand Up @@ -30,6 +31,7 @@ struct StreamWriterFileObj : private FileObj, public StreamWriter {
};

PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
m.def("clear_cuda_context_cache", &clear_cuda_context_cache);
py::class_<Chunk>(m, "Chunk", py::module_local())
.def_readwrite("frames", &Chunk::frames)
.def_readwrite("pts", &Chunk::pts);
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/csrc/ffmpeg/stream_reader/stream_processor.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <torchaudio/csrc/ffmpeg/hw_context.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_processor.h>
#include <stdexcept>

Expand Down Expand Up @@ -99,6 +100,7 @@ void configure_codec_context(
// 2. Set pCodecContext->get_format call back function which
// will retrieve the HW pixel format from opaque pointer.
codec_ctx->get_format = get_hw_format;
codec_ctx->hw_device_ctx = av_buffer_ref(get_cuda_context(device.index()));
#endif
}
}
Expand Down
6 changes: 6 additions & 0 deletions torchaudio/utils/ffmpeg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,9 @@ def get_build_config() -> str:
--prefix=/Users/runner/miniforge3 --cc=arm64-apple-darwin20.0.0-clang --enable-gpl --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-neon --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-libvpx --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1646229390493/_build_env/bin/pkg-config --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1646229390493/_build_env/bin/x86_64-apple-darwin13.4.0-clang # noqa
"""
return torch.ops.torchaudio.ffmpeg_get_build_config()


@torchaudio._extension.fail_if_no_ffmpeg
def clear_cuda_context_cache():
"""Clear the CUDA context used by CUDA Hardware accelerated video decoding"""
torchaudio.lib._torchaudio_ffmpeg.clear_cuda_context_cache()