Skip to content

Add decode_nv12 to allow convert nv12 to rgb #874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
workspace(name = "org_tensorflow_io")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")
load("//third_party/toolchains/tf:tf_configure.bzl", "tf_configure")

tf_configure(name = "local_config_tf")
Expand Down Expand Up @@ -110,6 +111,13 @@ http_archive(
],
)

# libyuv: fetched from the Chromium git repository at a pinned commit, so the
# build does not change when upstream moves. The checkout is built with the
# BUILD file kept in this repository at //third_party:libyuv.BUILD and is
# referenced from other packages as "@libyuv".
new_git_repository(
name = "libyuv",
build_file = "//third_party:libyuv.BUILD",
commit = "7f00d67d7c279f13b73d3be9c2d85873a7e2fbaf",
remote = "https://chromium.googlesource.com/libyuv/libyuv",
)

http_archive(
name = "libgeotiff",
build_file = "//third_party:libgeotiff.BUILD",
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_io/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,12 @@ cc_library(
"kernels/image_font_kernels.cc",
"kernels/image_hdr_kernels.cc",
"kernels/image_jpeg_kernels.cc",
"kernels/image_nv12_kernels.cc",
"kernels/image_openexr_kernels.cc",
"kernels/image_pnm_kernels.cc",
"kernels/image_tiff_kernels.cc",
"kernels/image_webp_kernels.cc",
"kernels/image_yuy2_kernels.cc",
"ops/image_ops.cc",
],
copts = tf_io_copts(),
Expand All @@ -180,6 +182,7 @@ cc_library(
"@libgeotiff",
"@libtiff",
"@libwebp",
"@libyuv",
"@openexr",
"@stb",
],
Expand Down
70 changes: 70 additions & 0 deletions tensorflow_io/core/kernels/image_nv12_kernels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow_io/core/kernels/io_stream.h"

#include "libyuv/convert_argb.h"

namespace tensorflow {
namespace io {
namespace {

class DecodeNV12Op : public OpKernel {
public:
explicit DecodeNV12Op(OpKernelConstruction* context) : OpKernel(context) {
env_ = context->env();
}

void Compute(OpKernelContext* context) override {
const Tensor* input_tensor;
OP_REQUIRES_OK(context, context->input("input", &input_tensor));

const Tensor* size_tensor;
OP_REQUIRES_OK(context, context->input("size", &size_tensor));

const tstring& input = input_tensor->scalar<tstring>()();

int64 channels = 3;
int64 height = size_tensor->flat<int32>()(0);
int64 width = size_tensor->flat<int32>()(1);

Tensor* image_tensor = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(
0, TensorShape({height, width, channels}), &image_tensor));
uint8* rgb = image_tensor->flat<uint8>().data();

uint8* y = (uint8*)&input[0];
uint8* uv = (uint8*)&input[width * height];
uint32 y_stride = width;
uint32 uv_stride = width;
uint32 rgb_stride = width * 3;
int status = libyuv::NV12ToRAW(y, y_stride, uv, uv_stride, rgb, rgb_stride,
width, height);
OP_REQUIRES(
context, (status == 0),
errors::InvalidArgument("unable to convert nv12 to rgb: ", status));
}

private:
mutex mu_;
Env* env_ GUARDED_BY(mu_);
};
REGISTER_KERNEL_BUILDER(Name("IO>DecodeNV12").Device(DEVICE_CPU), DecodeNV12Op);

} // namespace
} // namespace io
} // namespace tensorflow
79 changes: 79 additions & 0 deletions tensorflow_io/core/kernels/image_yuy2_kernels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow_io/core/kernels/io_stream.h"

#include "libyuv/convert_argb.h"
#include "libyuv/convert_from_argb.h"

namespace tensorflow {
namespace io {
namespace {

class DecodeYUY2Op : public OpKernel {
public:
explicit DecodeYUY2Op(OpKernelConstruction* context) : OpKernel(context) {
env_ = context->env();
}

void Compute(OpKernelContext* context) override {
const Tensor* input_tensor;
OP_REQUIRES_OK(context, context->input("input", &input_tensor));

const Tensor* size_tensor;
OP_REQUIRES_OK(context, context->input("size", &size_tensor));

const tstring& input = input_tensor->scalar<tstring>()();

int64 channels = 3;
int64 height = size_tensor->flat<int32>()(0);
int64 width = size_tensor->flat<int32>()(1);

Tensor* image_tensor = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(
0, TensorShape({height, width, channels}), &image_tensor));

string buffer;
buffer.resize(width * height * 4);
uint8* argb = (uint8*)&buffer[0];
uint8* yuy2 = (uint8*)&input[0];
uint32 yuy2_stride = width * 2;
uint32 argb_stride = width * 4;
int status =
libyuv::YUY2ToARGB(yuy2, yuy2_stride, argb, argb_stride, width, height);
OP_REQUIRES(
context, (status == 0),
errors::InvalidArgument("unable to convert yuy2 to argb: ", status));

uint8* rgb = image_tensor->flat<uint8>().data();
uint32 rgb_stride = width * 3;
status =
libyuv::ARGBToRAW(argb, argb_stride, rgb, rgb_stride, width, height);
OP_REQUIRES(
context, (status == 0),
errors::InvalidArgument("unable to convert argb to rgb: ", status));
}

private:
mutex mu_;
Env* env_ GUARDED_BY(mu_);
};
REGISTER_KERNEL_BUILDER(Name("IO>DecodeYUY2").Device(DEVICE_CPU), DecodeYUY2Op);

} // namespace
} // namespace io
} // namespace tensorflow
22 changes: 22 additions & 0 deletions tensorflow_io/core/ops/image_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,28 @@ REGISTER_OP("IO>DecodeDICOMData")
loads a dicom file and returns the specified tags values as string.
)doc");

// IO>DecodeNV12: decodes a scalar string holding an NV12 frame into a uint8
// image. `size` carries {height, width}; the output is [height, width, 3],
// with both spatial dimensions left unknown during shape inference.
REGISTER_OP("IO>DecodeNV12")
    .Input("input: string")
    .Input("size: int32")
    .Output("image: uint8")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // `input` must be a scalar string.
      shape_inference::ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      // `size` must be a vector of two elements; catching a malformed value
      // here surfaces the error when the graph is built.
      shape_inference::ShapeHandle size;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &size));
      shape_inference::DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused_dim));
      c->set_output(0, c->MakeShape({c->UnknownDim(), c->UnknownDim(), 3}));
      return Status::OK();
    });

// IO>DecodeYUY2: decodes a scalar string holding a YUY2 frame into a uint8
// image. `size` carries {height, width}; the output is [height, width, 3],
// with both spatial dimensions left unknown during shape inference.
REGISTER_OP("IO>DecodeYUY2")
    .Input("input: string")
    .Input("size: int32")
    .Output("image: uint8")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // `input` must be a scalar string.
      shape_inference::ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      // `size` must be a vector of two elements; catching a malformed value
      // here surfaces the error when the graph is built.
      shape_inference::ShapeHandle size;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &size));
      shape_inference::DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused_dim));
      c->set_output(0, c->MakeShape({c->UnknownDim(), c->UnknownDim(), 3}));
      return Status::OK();
    });

} // namespace
} // namespace io
} // namespace tensorflow
2 changes: 2 additions & 0 deletions tensorflow_io/core/python/api/experimental/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@
decode_exr,
decode_pnm,
decode_hdr,
decode_nv12,
decode_yuy2,
)
34 changes: 33 additions & 1 deletion tensorflow_io/core/python/experimental/image_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def decode_pnm(contents, dtype=tf.uint8, name=None):

def decode_hdr(contents, name=None):
"""
Decode a HDR-encoded image to a uint8 tensor.
Decode a HDR-encoded image to a tf.float tensor.

Args:
contents: A `Tensor` of type `string`. 0-D. The HDR-encoded image.
Expand All @@ -147,3 +147,35 @@ def decode_hdr(contents, name=None):
A `Tensor` of type `float` and shape of `[height, width, 3]` (RGB).
"""
return core_ops.io_decode_hdr(contents, name=name)


def decode_nv12(contents, size, name=None):
    """
    Convert raw NV12 image bytes into a uint8 RGB tensor.

    Args:
      contents: A `Tensor` of type `string`. 0-D. The NV12-encoded image.
      size: A 1-D int32 Tensor of 2 elements: height, width. The size
        for the images.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` of type `uint8` and shape of `[height, width, 3]` (RGB).
    """
    image = core_ops.io_decode_nv12(contents, size=size, name=name)
    return image


def decode_yuy2(contents, size, name=None):
    """
    Convert raw YUY2 image bytes into a uint8 RGB tensor.

    Args:
      contents: A `Tensor` of type `string`. 0-D. The YUY2-encoded image.
      size: A 1-D int32 Tensor of 2 elements: height, width. The size
        for the images.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` of type `uint8` and shape of `[height, width, 3]` (RGB).
    """
    image = core_ops.io_decode_yuy2(contents, size=size, name=name)
    return image
1 change: 1 addition & 0 deletions tests/test_image/Jelly-Beans.nv12

Large diffs are not rendered by default.

Binary file added tests/test_image/Jelly-Beans.nv12.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/test_image/Jelly-Beans.tiff
Binary file not shown.
1 change: 1 addition & 0 deletions tests/test_image/Jelly-Beans.yuy2

Large diffs are not rendered by default.

Binary file added tests/test_image/Jelly-Beans.yuy2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
34 changes: 34 additions & 0 deletions tests/test_image_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,5 +318,39 @@ def test_decode_tiff_geotiff():
assert np.all(png_image.numpy() == image.numpy())


def test_decode_nv12():
    """Test case for decode_nv12"""
    # Both the raw NV12 frame and its reference PNG live in tests/test_image.
    image_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_image")
    nv12_path = os.path.join(image_dir, "Jelly-Beans.nv12")
    reference_path = os.path.join(image_dir, "Jelly-Beans.nv12.png")

    expected = tf.image.decode_png(tf.io.read_file(reference_path))

    raw = tf.io.read_file(nv12_path)
    decoded = tfio.experimental.image.decode_nv12(raw, size=[256, 256])

    # The decoded frame must match the reference image exactly.
    assert decoded.dtype == tf.uint8
    assert decoded.shape == [256, 256, 3]
    assert np.all(decoded == expected)


def test_decode_yuy2():
    """Test case for decode_yuy2"""
    # Both the raw YUY2 frame and its reference PNG live in tests/test_image.
    image_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_image")
    yuy2_path = os.path.join(image_dir, "Jelly-Beans.yuy2")
    reference_path = os.path.join(image_dir, "Jelly-Beans.yuy2.png")

    expected = tf.image.decode_png(tf.io.read_file(reference_path))

    raw = tf.io.read_file(yuy2_path)
    decoded = tfio.experimental.image.decode_yuy2(raw, size=[256, 256])

    # The decoded frame must match the reference image exactly.
    assert decoded.dtype == tf.uint8
    assert decoded.shape == [256, 256, 3]
    assert np.all(decoded == expected)


# Run the tests through `test.main()` when this file is executed as a script.
if __name__ == "__main__":
    test.main()
24 changes: 24 additions & 0 deletions third_party/libyuv.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Description:
# libyuv library from Chromium

licenses(["notice"])

exports_files(["LICENSE"])

# Builds the parts of libyuv listed below as a single library.
cc_library(
name = "libyuv",
# All public headers plus every row_* and scale_* source are picked up by
# glob; the remaining sources are listed one by one.
srcs = glob([
"include/libyuv/*.h",
"source/row_*.cc",
"source/scale_*.cc",
]) + [
"source/convert_argb.cc",
"source/convert_from_argb.cc",
"source/cpu_id.cc",
"source/planar_functions.cc",
],
# Puts "include" on the include path so that headers are reachable as
# "libyuv/<name>.h".
includes = [
"include",
],
visibility = ["//visibility:public"],
)