Skip to content

Commit f4fd193

Browse files
authored
Return RGB frames as output of GPU decoder (#5191)
* Return RGB frames as output of GPU decoder * Move clamp to the conversion function * Cleaned up a bit * Remove utility functions from test * Use data member width directly * Fix linter error
1 parent 038828e commit f4fd193

File tree

6 files changed

+22
-85
lines changed

6 files changed

+22
-85
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def get_extensions():
472472
"z",
473473
"pthread",
474474
"dl",
475+
"nppicc",
475476
],
476477
extra_compile_args=extra_compile_args,
477478
)

test/test_video_gpu_decoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ def test_frame_reading(self):
3131
decoder = VideoReader(full_path, device="cuda:0")
3232
with av.open(full_path) as container:
3333
for av_frame in container.decode(container.streams.video[0]):
34-
av_frames = torch.tensor(av_frame.to_ndarray().flatten())
34+
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
3535
vision_frames = next(decoder)["data"]
36-
mean_delta = torch.mean(torch.abs(av_frames.float() - decoder._reformat(vision_frames).float()))
37-
assert mean_delta < 0.1
36+
mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
37+
assert mean_delta < 0.75
3838

3939

4040
if __name__ == "__main__":

torchvision/csrc/io/decoder/gpu/decoder.cpp

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "decoder.h"
22
#include <c10/util/Logging.h>
3+
#include <nppi_color_conversion.h>
34
#include <cmath>
45
#include <cstring>
56
#include <unordered_map>
@@ -138,38 +139,24 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* disp_info) {
138139
}
139140

140141
auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA);
141-
torch::Tensor decoded_frame = torch::empty({get_frame_size()}, options);
142+
torch::Tensor decoded_frame = torch::empty({get_height(), width, 3}, options);
142143
uint8_t* frame_ptr = decoded_frame.data_ptr<uint8_t>();
144+
const uint8_t* const source_arr[] = {
145+
(const uint8_t* const)source_frame,
146+
(const uint8_t* const)(source_frame + source_pitch * ((surface_height + 1) & ~1))};
147+
148+
auto err = nppiNV12ToRGB_709CSC_8u_P2C3R(
149+
source_arr,
150+
source_pitch,
151+
frame_ptr,
152+
width * 3,
153+
{(int)decoded_frame.size(1), (int)decoded_frame.size(0)});
154+
155+
TORCH_CHECK(
156+
err == NPP_NO_ERROR,
157+
"Failed to convert from NV12 to RGB. Error code:",
158+
err);
143159

144-
// Copy luma plane
145-
CUDA_MEMCPY2D m = {0};
146-
m.srcMemoryType = CU_MEMORYTYPE_DEVICE;
147-
m.srcDevice = source_frame;
148-
m.srcPitch = source_pitch;
149-
m.dstMemoryType = CU_MEMORYTYPE_DEVICE;
150-
m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr);
151-
m.dstPitch = get_width() * bytes_per_pixel;
152-
m.WidthInBytes = get_width() * bytes_per_pixel;
153-
m.Height = luma_height;
154-
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);
155-
156-
// Copy chroma plane
157-
// NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning
158-
// height
159-
m.srcDevice =
160-
(CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1));
161-
m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height);
162-
m.Height = chroma_height;
163-
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);
164-
165-
if (num_chroma_planes == 2) {
166-
m.srcDevice =
167-
(CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1) * 2);
168-
m.dstDevice =
169-
(CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height * 2);
170-
m.Height = chroma_height;
171-
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);
172-
}
173160
check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__, __FILE__);
174161
decoded_frames.push(decoded_frame);
175162
check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__);

torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -38,48 +38,8 @@ torch::Tensor GPUDecoder::decode() {
3838
return frame;
3939
}
4040

41-
/* Convert a tensor with data in NV12 format to a tensor with data in YUV420
42-
* format in-place.
43-
*/
44-
torch::Tensor GPUDecoder::nv12_to_yuv420(torch::Tensor frameTensor) {
45-
int width = decoder.get_width(), height = decoder.get_height();
46-
int pitch = width;
47-
uint8_t* frame = frameTensor.data_ptr<uint8_t>();
48-
uint8_t* ptr = new uint8_t[((width + 1) / 2) * ((height + 1) / 2)];
49-
50-
// sizes of source surface plane
51-
int sizePlaneY = pitch * height;
52-
int sizePlaneU = ((pitch + 1) / 2) * ((height + 1) / 2);
53-
int sizePlaneV = sizePlaneU;
54-
55-
uint8_t* uv = frame + sizePlaneY;
56-
uint8_t* u = uv;
57-
uint8_t* v = uv + sizePlaneU;
58-
59-
// split chroma from interleave to planar
60-
for (int y = 0; y < (height + 1) / 2; y++) {
61-
for (int x = 0; x < (width + 1) / 2; x++) {
62-
u[y * ((pitch + 1) / 2) + x] = uv[y * pitch + x * 2];
63-
ptr[y * ((width + 1) / 2) + x] = uv[y * pitch + x * 2 + 1];
64-
}
65-
}
66-
if (pitch == width) {
67-
memcpy(v, ptr, sizePlaneV * sizeof(uint8_t));
68-
} else {
69-
for (int i = 0; i < (height + 1) / 2; i++) {
70-
memcpy(
71-
v + ((pitch + 1) / 2) * i,
72-
ptr + ((width + 1) / 2) * i,
73-
((width + 1) / 2) * sizeof(uint8_t));
74-
}
75-
}
76-
delete[] ptr;
77-
return frameTensor;
78-
}
79-
8041
TORCH_LIBRARY(torchvision, m) {
8142
m.class_<GPUDecoder>("GPUDecoder")
8243
.def(torch::init<std::string, int64_t>())
83-
.def("next", &GPUDecoder::decode)
84-
.def("reformat", &GPUDecoder::nv12_to_yuv420);
44+
.def("next", &GPUDecoder::decode);
8545
}

torchvision/csrc/io/decoder/gpu/gpu_decoder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ class GPUDecoder : public torch::CustomClassHolder {
88
GPUDecoder(std::string, int64_t);
99
~GPUDecoder();
1010
torch::Tensor decode();
11-
torch::Tensor nv12_to_yuv420(torch::Tensor);
1211

1312
private:
1413
Demuxer demuxer;

torchvision/io/__init__.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -210,16 +210,6 @@ def set_current_stream(self, stream: str) -> bool:
210210
print("GPU decoding only works with video stream.")
211211
return self._c.set_current_stream(stream)
212212

213-
def _reformat(self, tensor, output_format: str = "yuv420"):
214-
supported_formats = [
215-
"yuv420",
216-
]
217-
if output_format not in supported_formats:
218-
raise RuntimeError(f"{output_format} not supported, please use one of {', '.join(supported_formats)}")
219-
if not isinstance(tensor, torch.Tensor):
220-
raise RuntimeError("Expected tensor as input parameter!")
221-
return self._c.reformat(tensor.cpu())
222-
223213

224214
__all__ = [
225215
"write_video",

0 commit comments

Comments
 (0)