Skip to content

Return RGB frames as output of GPU decoder #5191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 19, 2022
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ def get_extensions():
"z",
"pthread",
"dl",
"nppicc",
],
extra_compile_args=extra_compile_args,
)
Expand Down
6 changes: 3 additions & 3 deletions test/test_video_gpu_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def test_frame_reading(self):
decoder = VideoReader(full_path, device="cuda:0")
with av.open(full_path) as container:
for av_frame in container.decode(container.streams.video[0]):
av_frames = torch.tensor(av_frame.to_ndarray().flatten())
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
vision_frames = next(decoder)["data"]
mean_delta = torch.mean(torch.abs(av_frames.float() - decoder._reformat(vision_frames).float()))
assert mean_delta < 0.1
mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
assert mean_delta < 0.75


if __name__ == "__main__":
Expand Down
47 changes: 17 additions & 30 deletions torchvision/csrc/io/decoder/gpu/decoder.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "decoder.h"
#include <c10/util/Logging.h>
#include <nppi_color_conversion.h>
#include <cmath>
#include <cstring>
#include <unordered_map>
Expand Down Expand Up @@ -138,38 +139,24 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* disp_info) {
}

auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA);
torch::Tensor decoded_frame = torch::empty({get_frame_size()}, options);
torch::Tensor decoded_frame = torch::empty({get_height(), width, 3}, options);
uint8_t* frame_ptr = decoded_frame.data_ptr<uint8_t>();
const uint8_t* const source_arr[] = {
(const uint8_t* const)source_frame,
(const uint8_t* const)(source_frame + source_pitch * ((surface_height + 1) & ~1))};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like us to double-check this `(surface_height + 1) & ~1` condition. Can you create some videos with odd dimensions to validate what is actually needed?

Copy link
Contributor Author

@prabhat00155 prabhat00155 Jan 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

surface_height is different from luma_height, which is directly related to the video dimensions.
I can revisit this when doing the code refactoring.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is about the `(+ 1) & ~1` part of the expression. I'm not sure it's actually necessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The luma height is aligned to 2, so the chroma offset should not be odd (the chroma plane's base address can't be at an odd memory location) — hence the alignment.


auto err = nppiNV12ToRGB_709CSC_8u_P2C3R(
source_arr,
source_pitch,
frame_ptr,
width * 3,
{(int)decoded_frame.size(1), (int)decoded_frame.size(0)});

TORCH_CHECK(
err == NPP_NO_ERROR,
"Failed to convert from NV12 to RGB. Error code:",
err);

// Copy luma plane
CUDA_MEMCPY2D m = {0};
m.srcMemoryType = CU_MEMORYTYPE_DEVICE;
m.srcDevice = source_frame;
m.srcPitch = source_pitch;
m.dstMemoryType = CU_MEMORYTYPE_DEVICE;
m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr);
m.dstPitch = get_width() * bytes_per_pixel;
m.WidthInBytes = get_width() * bytes_per_pixel;
m.Height = luma_height;
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);

// Copy chroma plane
// NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning
// height
m.srcDevice =
(CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1));
m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height);
m.Height = chroma_height;
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);

if (num_chroma_planes == 2) {
m.srcDevice =
(CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1) * 2);
m.dstDevice =
(CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height * 2);
m.Height = chroma_height;
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);
}
check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__, __FILE__);
decoded_frames.push(decoded_frame);
check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__);
Expand Down
42 changes: 1 addition & 41 deletions torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,48 +38,8 @@ torch::Tensor GPUDecoder::decode() {
return frame;
}

/* Convert a tensor with data in NV12 format to a tensor with data in YUV420
 * format in-place.
 *
 * NV12 layout: full-resolution Y plane followed by a single interleaved
 * chroma plane (U0 V0 U1 V1 ...). YUV420 (I420) layout: Y plane, then a
 * planar U plane, then a planar V plane, each chroma plane half-resolution
 * in both dimensions. Only the chroma region of the tensor is rewritten;
 * the Y plane is left untouched. Returns the same tensor it was given.
 */
torch::Tensor GPUDecoder::nv12_to_yuv420(torch::Tensor frameTensor) {
int width = decoder.get_width(), height = decoder.get_height();
// NOTE(review): pitch is set equal to width and never changed, so the
// pitch != width branch below is currently dead code — presumably kept for
// a future padded-pitch case; confirm before removing.
int pitch = width;
uint8_t* frame = frameTensor.data_ptr<uint8_t>();
// Scratch buffer holding the de-interleaved V samples (one byte per 2x2
// pixel block). NOTE(review): raw new[]/delete[] — a std::vector or
// std::unique_ptr would be exception-safe; this is only leak-free as long
// as nothing between new and delete can throw.
uint8_t* ptr = new uint8_t[((width + 1) / 2) * ((height + 1) / 2)];

// sizes of source surface plane
int sizePlaneY = pitch * height;
int sizePlaneU = ((pitch + 1) / 2) * ((height + 1) / 2);
int sizePlaneV = sizePlaneU;

// The interleaved UV data begins right after the Y plane. After conversion,
// u is the start of the planar U plane and v the start of the planar V plane.
uint8_t* uv = frame + sizePlaneY;
uint8_t* u = uv;
uint8_t* v = uv + sizePlaneU;

// split chroma from interleave to planar. U samples are compacted in place:
// the write index y*((pitch+1)/2)+x never exceeds the read index
// y*pitch+2x, so no unread source byte is clobbered. V samples are staged
// in the scratch buffer because their destination overlaps unread UV data.
for (int y = 0; y < (height + 1) / 2; y++) {
for (int x = 0; x < (width + 1) / 2; x++) {
u[y * ((pitch + 1) / 2) + x] = uv[y * pitch + x * 2];
ptr[y * ((width + 1) / 2) + x] = uv[y * pitch + x * 2 + 1];
}
}
// Move the staged V samples into their final planar position.
if (pitch == width) {
// Tightly packed rows: a single contiguous copy suffices.
memcpy(v, ptr, sizePlaneV * sizeof(uint8_t));
} else {
// Padded rows: copy row by row so destination rows land on pitch stride.
for (int i = 0; i < (height + 1) / 2; i++) {
memcpy(
v + ((pitch + 1) / 2) * i,
ptr + ((width + 1) / 2) * i,
((width + 1) / 2) * sizeof(uint8_t));
}
}
delete[] ptr;
return frameTensor;
}

// Register the GPUDecoder custom class and its methods with TorchScript.
// The rendered diff had merged the removed "reformat" binding with both the
// old and new "next" lines, producing invalid C++ (a stray semicolon ended
// the chain early and a second ".def" started a statement). The merged PR
// drops the reformat/nv12_to_yuv420 binding entirely — the decoder now
// returns RGB frames directly — leaving only the "next" method.
TORCH_LIBRARY(torchvision, m) {
  m.class_<GPUDecoder>("GPUDecoder")
      .def(torch::init<std::string, int64_t>())
      .def("next", &GPUDecoder::decode);
}
1 change: 0 additions & 1 deletion torchvision/csrc/io/decoder/gpu/gpu_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ class GPUDecoder : public torch::CustomClassHolder {
GPUDecoder(std::string, int64_t);
~GPUDecoder();
torch::Tensor decode();
torch::Tensor nv12_to_yuv420(torch::Tensor);

private:
Demuxer demuxer;
Expand Down
10 changes: 0 additions & 10 deletions torchvision/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,6 @@ def set_current_stream(self, stream: str) -> bool:
print("GPU decoding only works with video stream.")
return self._c.set_current_stream(stream)

def _reformat(self, tensor, output_format: str = "yuv420"):
supported_formats = [
"yuv420",
]
if output_format not in supported_formats:
raise RuntimeError(f"{output_format} not supported, please use one of {', '.join(supported_formats)}")
if not isinstance(tensor, torch.Tensor):
raise RuntimeError("Expected tensor as input parameter!")
return self._c.reformat(tensor.cpu())


__all__ = [
"write_video",
Expand Down