From 4d8872cde40aa0703703b4fc148c8e78f6e52561 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 1 Mar 2020 12:07:13 -0800 Subject: [PATCH 1/5] Add Video Capture Support for macOS through AVFoundation/Swift This PR is part of the effort in resolving 814. In 814, the feature request is to add video capture support for Linux, likely through Video4Linux. Due to some limitations Video4Linux will need a compatible USB camera first. This PR, instead tries to resolve the featue requrest on macOS first. On macOS the built-in camera could be accessed through AVFoundation's Swift API. This PR uses Swift to access AVCaptureSession/etc, and exported to C function (`cdecl`) so that it could be used in C++ kernel in tensorflow-io. Since macOS's raw video capture format is NV12 (kCVPixelFormatType_420YpCbCr8BiPlanarVideoRange) additional work is needed to convert NV12 into RGB format, so that a whole pipeline could be built up to allow using video capture for tf.keras' inference. This PR does not resolve the NV12 => RGB yet. Will address in separate PRs. Also, since video capture is technically a continuous stream and is not repeatable, it is not possible to train based on video capture with multiple epochs. Finally, the following is a sample usage which takes video capture and saves as nv12 raw file. The NV12 raw file could be checked by using ffmpeg to convert to JPEG to validate. Note: the following is a validation YUV image could be converted to JPEG with: ``` ffmpeg -s 1280x720 -pix_fmt nv12 -i frame_{i}.yuv frame_{i}.jpg ``` Usage: ``` dataset = tfio.experimental.IODataset.stream().from_video_capture( "device").take(5) i = 0 for frame in dataset: print("Frame {}: shape({}) dtype({}) length({})".format( i, frame.shape, frame.dtype, tf.strings.length(frame))) tf.io.write_file("frame_{}.yuv".format(i), frame) i += 1 ``` Signed-off-by: Yong Tang --- tensorflow_io/core/BUILD | 21 +++ tensorflow_io/core/kernels/video_kernels.cc | 152 +++++++++++++++++ tensorflow_io/core/ops/video_ops.cc | 45 +++++ .../python/experimental/io_dataset_ops.py | 16 ++ .../python/experimental/video_dataset_ops.py | 54 ++++++ tensorflow_io/core/swift/video.swift | 161 ++++++++++++++++++ tests/test_io_dataset_eager.py | 48 ++++++ tools/build/swift/BUILD | 12 ++ 8 files changed, 509 insertions(+) create mode 100644 tensorflow_io/core/kernels/video_kernels.cc create mode 100644 tensorflow_io/core/ops/video_ops.cc create mode 100644 tensorflow_io/core/python/experimental/video_dataset_ops.py create mode 100644 tensorflow_io/core/swift/video.swift diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index c6728182c..48287718b 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -187,6 +187,7 @@ cc_library( exports_files([ "swift/audio.swift", + "swift/video.swift", ]) cc_library( @@ -219,6 +220,25 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "video_ops", + srcs = [ + "kernels/video_kernels.cc", + "ops/video_ops.cc", + ], + copts = tf_io_copts(), + linkstatic = True, + deps = [ + "//tensorflow_io/core:dataset_ops", + ] + select({ + "@bazel_tools//src/conditions:darwin": [ + "//tools/build/swift:video_swift", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + cc_library( name = "ffmpeg_3.4_ops", srcs = [ @@ -542,6 +562,7 @@ cc_binary( "//tensorflow_io/core:serialization_ops", "//tensorflow_io/core:sql_ops", "//tensorflow_io/core:text_ops", + "//tensorflow_io/core:video_ops", "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", ] + select({ diff --git a/tensorflow_io/core/kernels/video_kernels.cc b/tensorflow_io/core/kernels/video_kernels.cc new file mode 100644 index 000000000..8cb44fdd3 --- /dev/null +++ b/tensorflow_io/core/kernels/video_kernels.cc @@ -0,0 +1,152 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_op_kernel.h" + +extern "C" { +#if defined(__APPLE__) +void* VideoCaptureInitFunction(int64_t* bytes, int64_t* width, int64_t* height); +void VideoCaptureNextFunction(void* context, void* data, int64_t size); +void VideoCaptureFiniFunction(void* context); +#else +void* VideoCaptureInitFunction(int64_t* bytes, int64_t* width, + int64_t* height) { + return NULL; +} +void VideoCaptureNextFunction(void* context, void* data, int64_t size) {} +void VideoCaptureFiniFunction(void* context) {} +#endif +} +namespace tensorflow { +namespace data { +namespace { + +class VideoCaptureReadableResource : public ResourceBase { + public: + VideoCaptureReadableResource(Env* env) + : env_(env), context_(nullptr, [](void* p) { + if (p != nullptr) { + VideoCaptureFiniFunction(p); + } + }) {} + ~VideoCaptureReadableResource() {} + + Status Init(const string& input) { + mutex_lock l(mu_); + + int64_t bytes, width, height; + context_.reset(VideoCaptureInitFunction(&bytes, &width, &height)); + if (context_.get() == nullptr) { + return errors::InvalidArgument("unable to open device ", input); + } + bytes_ = static_cast(bytes); + width_ = static_cast(width); + height_ = static_cast(height); + return Status::OK(); + } + Status Read( + std::function + allocate_func) { + mutex_lock l(mu_); + + Tensor* value_tensor; + TF_RETURN_IF_ERROR(allocate_func(TensorShape({1}), &value_tensor)); + + string buffer; + buffer.resize(bytes_); + VideoCaptureNextFunction(context_.get(), (void*)&buffer[0], + static_cast(bytes_)); + value_tensor->flat()(0) = buffer; + + return Status::OK(); + } + string DebugString() const override { + mutex_lock l(mu_); + return "VideoCaptureReadableResource"; + } + + protected: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); + + std::unique_ptr context_; + int64 bytes_; + int64 width_; + int64 height_; +}; + +class VideoCaptureReadableInitOp + : public ResourceOpKernel { + public: + explicit VideoCaptureReadableInitOp(OpKernelConstruction* context) + : ResourceOpKernel(context) { + env_ = context->env(); + } + + private: + void Compute(OpKernelContext* context) override { + ResourceOpKernel::Compute(context); + + const Tensor* input_tensor; + OP_REQUIRES_OK(context, context->input("input", &input_tensor)); + const string& input = input_tensor->scalar()(); + + OP_REQUIRES_OK(context, resource_->Init(input)); + } + Status CreateResource(VideoCaptureReadableResource** resource) + EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + *resource = new VideoCaptureReadableResource(env_); + return Status::OK(); + } + + private: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; + +class VideoCaptureReadableReadOp : public OpKernel { + public: + explicit VideoCaptureReadableReadOp(OpKernelConstruction* context) + : OpKernel(context) { + env_ = context->env(); + } + + void Compute(OpKernelContext* context) override { + VideoCaptureReadableResource* resource; + OP_REQUIRES_OK(context, + GetResourceFromContext(context, "input", &resource)); + core::ScopedUnref unref(resource); + + OP_REQUIRES_OK( + context, resource->Read([&](const TensorShape& shape, + Tensor** value_tensor) -> Status { + TF_RETURN_IF_ERROR(context->allocate_output(0, shape, value_tensor)); + return Status::OK(); + })); + } + + private: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; +REGISTER_KERNEL_BUILDER(Name("IO>VideoCaptureReadableInit").Device(DEVICE_CPU), + VideoCaptureReadableInitOp); +REGISTER_KERNEL_BUILDER(Name("IO>VideoCaptureReadableRead").Device(DEVICE_CPU), + VideoCaptureReadableReadOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/core/ops/video_ops.cc b/tensorflow_io/core/ops/video_ops.cc new file mode 100644 index 000000000..3e5539956 --- /dev/null +++ b/tensorflow_io/core/ops/video_ops.cc @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { +namespace io { +namespace { + +REGISTER_OP("IO>VideoCaptureReadableInit") + .Input("input: string") + .Output("resource: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("IO>VideoCaptureReadableRead") + .Input("input: resource") + .Input("index: int64") + .Output("value: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); + +} // namespace +} // namespace io +} // namespace tensorflow diff --git a/tensorflow_io/core/python/experimental/io_dataset_ops.py b/tensorflow_io/core/python/experimental/io_dataset_ops.py index a694509f6..5786e363e 100644 --- a/tensorflow_io/core/python/experimental/io_dataset_ops.py +++ b/tensorflow_io/core/python/experimental/io_dataset_ops.py @@ -28,6 +28,7 @@ from tensorflow_io.core.python.experimental import file_dataset_ops from tensorflow_io.core.python.experimental import numpy_dataset_ops from tensorflow_io.core.python.experimental import sql_dataset_ops +from tensorflow_io.core.python.experimental import video_dataset_ops class IODataset(io_dataset.IODataset): """IODataset""" @@ -269,6 +270,21 @@ def to_file(cls, class StreamIODataset(tf.data.Dataset): """StreamIODataset""" + @classmethod + def from_video_capture(cls, device, **kwargs): + """Creates an `StreamIODataset` from video capture device. + + Args: + device: A string, the name of the device. + name: A name prefix for the IODataset (optional). + + Returns: + A `IODataset`. + """ + with tf.name_scope(kwargs.get("name", "IOFromVideoCapture")): + return video_dataset_ops.VideoCaptureIODataset( + device, internal=True) + @classmethod def from_prometheus_scrape(cls, metric, diff --git a/tensorflow_io/core/python/experimental/video_dataset_ops.py b/tensorflow_io/core/python/experimental/video_dataset_ops.py new file mode 100644 index 000000000..403c3d55d --- /dev/null +++ b/tensorflow_io/core/python/experimental/video_dataset_ops.py @@ -0,0 +1,54 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""VideoCaptureDataset""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow_io.core.python.ops import core_ops + +class VideoCaptureIODataset(tf.data.Dataset): + """VideoCaptureIODataset""" + + def __init__(self, + device, + internal=True): + """VideoCaptureIODataset""" + with tf.name_scope("VideoCaptureIODataset"): + assert internal + + resource = core_ops.io_video_capture_readable_init(device) + + self._resource = resource + + dataset = tf.data.experimental.Counter() + dataset = dataset.map( + lambda i: core_ops.io_video_capture_readable_read(self._resource, i)) + dataset = dataset.apply( + tf.data.experimental.take_while( + lambda v: tf.greater(tf.shape(v)[0], 0))) + dataset = dataset.unbatch() + + self._dataset = dataset + super(VideoCaptureIODataset, self).__init__( + self._dataset._variant_tensor) # pylint: disable=protected-access + + def _inputs(self): + return [] + + @property + def element_spec(self): + return self._dataset.element_spec diff --git a/tensorflow_io/core/swift/video.swift b/tensorflow_io/core/swift/video.swift new file mode 100644 index 000000000..bc87c9ce4 --- /dev/null +++ b/tensorflow_io/core/swift/video.swift @@ -0,0 +1,161 @@ +import AVFoundation + +class VideoDataOutputSampleBufferDelegate : NSObject, AVCaptureVideoDataOutputSampleBufferDelegate { + + var bytes: Int64 + var width: Int64 + var height: Int64 + var copied: Int64 + var buffer: UnsafeMutableRawPointer? + var semaphore_in: DispatchSemaphore + var semaphore_out: DispatchSemaphore + + init(semaphore_in: DispatchSemaphore, semaphore_out: DispatchSemaphore) { + self.bytes = 0 + self.width = 0 + self.height = 0 + self.copied = 0 + self.buffer = nil + self.semaphore_in = semaphore_in + self.semaphore_out = semaphore_out + super.init() + } + + deinit { + // TODO: This is not invoked, memory leak? + print("VideoDataOutputSampleBufferDelegate.deinit") + } + + func captureOutput(_ output: AVCaptureOutput, didDrop sampleBuffer: CMSampleBuffer, from connection: AVCaptureConnection) { + + print("frame dropped: \(sampleBuffer)") + } + + func captureOutput(_ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, from connection: AVCaptureConnection) { + + semaphore_in.wait() + + defer { semaphore_out.signal() } + + if sampleBuffer.numSamples != 1 { + print("number of samples \(sampleBuffer.numSamples) is not supported") + return + } + + let pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer) + + let pixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer!) + let planeCount = CVPixelBufferGetPlaneCount(pixelBuffer!) + + if pixelFormat != kCVPixelFormatType_420YpCbCr8BiPlanarVideoRange || planeCount != 2 { + print("PixelFormat \(pixelFormat) or PlaneCount \(planeCount) is not supported") + return + } + + let bytes = Int64(CVPixelBufferGetBytesPerRowOfPlane(pixelBuffer!, 0) * CVPixelBufferGetHeightOfPlane(pixelBuffer!, 0) + CVPixelBufferGetBytesPerRowOfPlane(pixelBuffer!, 1) * CVPixelBufferGetHeightOfPlane(pixelBuffer!, 1)) + let width = Int64(CVPixelBufferGetWidth(pixelBuffer!)) + let height = Int64(CVPixelBufferGetHeight(pixelBuffer!)) + + if (self.bytes == 0 || self.bytes == 0 || self.height == 0) { + self.bytes = bytes + self.width = width + self.height = height + } else if (self.bytes != bytes || self.width != width || self.height != height) { + print("Bytes \(bytes) vs. \(self.bytes), Width \(width) vs. \(self.width), Height \(height) vs. \(self.height)") + return + } + if (self.buffer != nil) { + CVPixelBufferLockBaseAddress(pixelBuffer!, CVPixelBufferLockFlags(rawValue: 0)) + + let baseAddress0 = CVPixelBufferGetBaseAddressOfPlane(pixelBuffer!, 0) + let bytesPerRow0 = CVPixelBufferGetBytesPerRowOfPlane(pixelBuffer!, 0) + let heightOfPlane0 = CVPixelBufferGetHeightOfPlane(pixelBuffer!, 0) + self.buffer!.copyMemory(from: baseAddress0!, byteCount: bytesPerRow0 * heightOfPlane0) + + let baseAddress1 = CVPixelBufferGetBaseAddressOfPlane(pixelBuffer!, 1) + let bytesPerRow1 = CVPixelBufferGetBytesPerRowOfPlane(pixelBuffer!, 1) + let heightOfPlane1 = CVPixelBufferGetHeightOfPlane(pixelBuffer!, 1) + + self.buffer!.advanced(by: bytesPerRow0 * Int(height)).copyMemory(from: baseAddress1!, byteCount: bytesPerRow1 * heightOfPlane1) + + CVPixelBufferUnlockBaseAddress(pixelBuffer!, CVPixelBufferLockFlags(rawValue: 0)) + + self.copied = Int64(bytesPerRow0 * heightOfPlane0 + bytesPerRow1 * heightOfPlane1) + } + } +} + +typealias VideoContext = (session: AVCaptureSession, semaphore_in: DispatchSemaphore, semaphore_out: DispatchSemaphore, delegate: VideoDataOutputSampleBufferDelegate) + +@_silgen_name("VideoCaptureInitFunction") +func VideoCaptureInitFunction(bytes: UnsafeMutablePointer, width: UnsafeMutablePointer, height: UnsafeMutablePointer) -> UnsafeMutablePointer? { + + let session = AVCaptureSession() + let semaphore_in = DispatchSemaphore(value: 0) + let semaphore_out = DispatchSemaphore(value: 0) + let sampleBufferDelegate = VideoDataOutputSampleBufferDelegate(semaphore_in: semaphore_in, semaphore_out: semaphore_out) + + do { + let device = AVCaptureDevice.default(for: .video) + let deviceInput = try AVCaptureDeviceInput(device: device!) + + session.addInput(deviceInput) + } catch { + return nil + } + + let queue = DispatchQueue(label: "VideoDataOutput", attributes: []) + let output = AVCaptureVideoDataOutput() + output.videoSettings = [:] + output.alwaysDiscardsLateVideoFrames = true + output.setSampleBufferDelegate(sampleBufferDelegate, queue: queue) + + session.addOutput(output) + session.commitConfiguration() + session.startRunning() + + // Obtain the first frame to get the information + semaphore_in.signal() + semaphore_out.wait() + + if (sampleBufferDelegate.bytes == 0 || sampleBufferDelegate.width == 0 || sampleBufferDelegate.height == 0) { + return nil + } + bytes.pointee = sampleBufferDelegate.bytes + width.pointee = sampleBufferDelegate.width + height.pointee = sampleBufferDelegate.height + + let context = UnsafeMutablePointer.allocate(capacity: 1) + context.initialize(to: (session: session, semaphore_in: semaphore_in, semaphore_out: semaphore_out, delegate: sampleBufferDelegate)) + + return context +} + +@_silgen_name("VideoCaptureNextFunction") +func VideoCaptureNextFunction(context: UnsafeMutablePointer, data: UnsafeMutableRawPointer, size: Int64) -> Void { + if context != nil { + if (size < context.pointee.delegate.bytes) { + print("not enough buffer to copy: \(size) vs. \(context.pointee.delegate.bytes)") + return + } + context.pointee.delegate.buffer = data + context.pointee.delegate.copied = 0 + context.pointee.semaphore_in.signal() + context.pointee.semaphore_out.wait() + context.pointee.delegate.buffer = nil + if context.pointee.delegate.copied != context.pointee.delegate.bytes { + print("not enough buffer copied: \(context.pointee.delegate.copied) vs. \(context.pointee.delegate.bytes)") + } + context.pointee.delegate.copied = 0 + return + } +} + +@_silgen_name("VideoCaptureFiniFunction") +func VideoCaptureFiniFunction(context: UnsafeMutablePointer) -> Void { + if context != nil { + context.pointee.session.stopRunning() + context.deinitialize(count: 1) + context.deallocate() + } +} diff --git a/tests/test_io_dataset_eager.py b/tests/test_io_dataset_eager.py index 6b6cc6fc6..094b347f2 100644 --- a/tests/test_io_dataset_eager.py +++ b/tests/test_io_dataset_eager.py @@ -805,6 +805,36 @@ def func(q): return args, func, expected +# video capture stream never repeat so +# we only test basic operation only. +@pytest.fixture(name="video_capture") +def fixture_video_capture(): + """fixture_video_capture + # Note: the following is a validation + # YUV image could be converted to JPEG with: + # ffmpeg -s 1280x720 -pix_fmt nv12 -i frame_{i}.yuv frame_{i}.jpg + dataset = tfio.experimental.IODataset.stream().from_video_capture( + "device").take(5) + i = 0 + for frame in dataset: + print("Frame {}: shape({}) dtype({}) length({})".format( + i, frame.shape, frame.dtype, tf.strings.length(frame))) + tf.io.write_file("frame_{}.yuv".format(i), frame) + i += 1 + """ + + args = "device" + def func(q): + dataset = tfio.experimental.IODataset.stream().from_video_capture( + q) + dataset = dataset.map(tf.strings.length) + dataset = dataset.take(10) + return dataset + # 1382400 = (1280 + 1280 / 2) * 720 + expected = [1382400 for _ in range(10)] + + return args, func, expected + # This test make sure dataset works in tf.keras inference. # The requirement for tf.keras inference is the support of `iter()`: # entries = [e for e in dataset] @@ -868,6 +898,14 @@ def func(q): reason="TODO PostgreSQL not tested on macOS/Windows"), ], ), + pytest.param( + "video_capture", + marks=[ + pytest.mark.skipif( + os.environ.get("TEST_VIDEO_CAPTURE", "") != "true", + reason="Video capture not enabled"), + ], + ), ], ids=[ "mnist", @@ -892,6 +930,7 @@ def func(q): "kafka[avro]", "kafka[stream]", "sql", + "capture[video]", ], ) def test_io_dataset_basic(fixture_lookup, io_dataset_fixture): @@ -966,6 +1005,14 @@ def test_io_dataset_basic(fixture_lookup, io_dataset_fixture): reason="TODO PostgreSQL not tested on macOS/Windows"), ], ), + pytest.param( + "video_capture", + marks=[ + pytest.mark.skipif( + os.environ.get("TEST_VIDEO_CAPTURE", "") != "true", + reason="Video capture not enabled"), + ], + ), ], ids=[ "mnist", @@ -988,6 +1035,7 @@ def test_io_dataset_basic(fixture_lookup, io_dataset_fixture): "kafka[avro]", "kafka[stream]", "sql", + "capture[video]", ], ) def test_io_dataset_basic_operation(fixture_lookup, io_dataset_fixture): diff --git a/tools/build/swift/BUILD b/tools/build/swift/BUILD index c65ee403a..31b7d9427 100644 --- a/tools/build/swift/BUILD +++ b/tools/build/swift/BUILD @@ -13,3 +13,15 @@ swift_library( module_name = "audio", alwayslink = True, ) + +swift_library( + name = "video_swift", + srcs = [ + "//tensorflow_io/core:swift/video.swift", + ], + linkopts = [ + "-L/usr/lib/swift", + ], + module_name = "video", + alwayslink = True, +) From b3ffc714f8105f7d9917295cadbfa69ae5acd647 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 2 Mar 2020 15:29:25 -0800 Subject: [PATCH 2/5] Add Video4Linux V2 support on Linux Signed-off-by: Yong Tang --- tensorflow_io/core/kernels/video_kernels.cc | 37 +++- tensorflow_io/core/kernels/video_kernels.h | 183 ++++++++++++++++++++ tests/test_io_dataset_eager.py | 9 +- 3 files changed, 223 insertions(+), 6 deletions(-) create mode 100644 tensorflow_io/core/kernels/video_kernels.h diff --git a/tensorflow_io/core/kernels/video_kernels.cc b/tensorflow_io/core/kernels/video_kernels.cc index 8cb44fdd3..e7a2bafb9 100644 --- a/tensorflow_io/core/kernels/video_kernels.cc +++ b/tensorflow_io/core/kernels/video_kernels.cc @@ -13,21 +13,52 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow_io/core/kernels/video_kernels.h" extern "C" { #if defined(__APPLE__) void* VideoCaptureInitFunction(int64_t* bytes, int64_t* width, int64_t* height); void VideoCaptureNextFunction(void* context, void* data, int64_t size); void VideoCaptureFiniFunction(void* context); -#else +#elif defined(_MSV_VER) void* VideoCaptureInitFunction(int64_t* bytes, int64_t* width, int64_t* height) { return NULL; } void VideoCaptureNextFunction(void* context, void* data, int64_t size) {} void VideoCaptureFiniFunction(void* context) {} +#else +void* VideoCaptureInitFunction(int64_t* bytes, int64_t* width, + int64_t* height) { + tensorflow::data::VideoCaptureContext* p = + new tensorflow::data::VideoCaptureContext(); + if (p != nullptr) { + tensorflow::Status status = p->Init("/dev/video0", bytes, width, height); + if (status.ok()) { + return p; + } + LOG(ERROR) << "unable to initialize video capture: " << status; + delete p; + } + return NULL; +} +void VideoCaptureNextFunction(void* context, void* data, int64_t size) { + tensorflow::data::VideoCaptureContext* p = + static_cast(context); + if (p != nullptr) { + tensorflow::Status status = p->Read(data, size); + if (!status.ok()) { + LOG(ERROR) << "unable to read video capture: " << status; + } + } +} +void VideoCaptureFiniFunction(void* context) { + tensorflow::data::VideoCaptureContext* p = + static_cast(context); + if (p != nullptr) { + delete p; + } +} #endif } namespace tensorflow { diff --git a/tensorflow_io/core/kernels/video_kernels.h b/tensorflow_io/core/kernels/video_kernels.h new file mode 100644 index 000000000..a3e02a3b4 --- /dev/null +++ b/tensorflow_io/core/kernels/video_kernels.h @@ -0,0 +1,183 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_op_kernel.h" + +#if defined(__linux__) + +#include +#include +#include +#include + +#include + +static int xioctl(int fh, int request, void* arg) { + int r; + + do { + r = ioctl(fh, request, arg); + } while (-1 == r && EINTR == errno); + + return r; +} +namespace tensorflow { +namespace data { + +class VideoCaptureContext { + public: + VideoCaptureContext() + : context_(nullptr, + [](void* p) { + if (p != nullptr) { + free(p); + } + }), + fd_scope_(nullptr, [](int* p) { + if (p != nullptr) { + close(*p); + } + }) {} + ~VideoCaptureContext() {} + + Status Init(const string& device, int64_t* bytes, int64_t* width, + int64_t* height) { + device_ = device; + + const char* devname = device.c_str(); + struct stat st; + if (-1 == stat(devname, &st)) { + return errors::InvalidArgument("cannot identify '", devname, "': ", errno, + ", ", strerror(errno)); + } + + if (!S_ISCHR(st.st_mode)) { + return errors::InvalidArgument(devname, " is no device"); + } + + fd_ = open(devname, O_RDWR /* required */ | O_NONBLOCK, 0); + if (-1 == fd_) { + return errors::InvalidArgument("cannot open '", devname, "': ", errno, + ", ", strerror(errno)); + } + fd_scope_.reset(&fd_); + + struct v4l2_capability cap; + if (-1 == xioctl(fd_, VIDIOC_QUERYCAP, &cap)) { + if (EINVAL == errno) { + return errors::InvalidArgument(devname, " is no V4L2 device"); + } else { + return errors::InvalidArgument("cannot VIDIOC_QUERYCAP '", devname, + "': ", errno, ", ", strerror(errno)); + } + } + + if (!(cap.capabilities & V4L2_CAP_VIDEO_CAPTURE)) { + return errors::InvalidArgument(devname, " is no video capture device"); + } + + if (!(cap.capabilities & V4L2_CAP_READWRITE)) { + return errors::InvalidArgument(devname, " does not support read i/o"); + } + + struct v4l2_format fmt; + memset(&(fmt), 0, sizeof(fmt)); + fmt.type = V4L2_BUF_TYPE_VIDEO_CAPTURE; + if (-1 == xioctl(fd_, VIDIOC_G_FMT, &fmt)) { + return errors::InvalidArgument("cannot VIDIOC_G_FMT '", devname, + "': ", errno, ", ", strerror(errno)); + } + + /* Buggy driver paranoia. */ + { + unsigned int min; + min = fmt.fmt.pix.width * 2; + if (fmt.fmt.pix.bytesperline < min) { + fmt.fmt.pix.bytesperline = min; + } + min = fmt.fmt.pix.bytesperline * fmt.fmt.pix.height; + if (fmt.fmt.pix.sizeimage < min) { + fmt.fmt.pix.sizeimage = min; + } + } + + if (fmt.fmt.pix.pixelformat != V4L2_PIX_FMT_YUYV) { + return errors::InvalidArgument( + "only V4L2_PIX_FMT_YUYV is supported, received ", + fmt.fmt.pix.pixelformat); + } + + *bytes = fmt.fmt.pix.sizeimage; + *width = fmt.fmt.pix.width; + *height = fmt.fmt.pix.height; + + return Status::OK(); + } + Status Read(void* data, size_t size) { + do { + fd_set fds; + struct timeval tv; + int r; + + FD_ZERO(&fds); + FD_SET(fd_, &fds); + + /* Timeout. */ + tv.tv_sec = 2; + tv.tv_usec = 0; + r = select(fd_ + 1, &fds, NULL, NULL, &tv); + + if (-1 == r) { + if (EINTR == errno) { + continue; + } + return errors::InvalidArgument("cannot select: ", errno, ", ", + strerror(errno)); + } + if (0 == r) { + return errors::InvalidArgument("select timeout"); + } + + if (-1 == read(fd_, data, size)) { + if (EAGAIN == errno) { + /* EAGAIN - continue select loop. */ + continue; + } + if (EIO == errno) { + /* Could ignore EIO, see spec. */ + /* fall through */ + } + return errors::InvalidArgument("cannot read: ", errno, ", ", + strerror(errno)); + } + // Data Obtained, break + break; + } while (true); + return Status::OK(); + } + + protected: + mutable mutex mu_; + + std::unique_ptr context_; + std::unique_ptr fd_scope_; + string device_; + int fd_; +}; + +} // namespace data +} // namespace tensorflow +#endif diff --git a/tests/test_io_dataset_eager.py b/tests/test_io_dataset_eager.py index 094b347f2..28a285e65 100644 --- a/tests/test_io_dataset_eager.py +++ b/tests/test_io_dataset_eager.py @@ -812,7 +812,8 @@ def fixture_video_capture(): """fixture_video_capture # Note: the following is a validation # YUV image could be converted to JPEG with: - # ffmpeg -s 1280x720 -pix_fmt nv12 -i frame_{i}.yuv frame_{i}.jpg + # macOS: ffmpeg -s 1280x720 -pix_fmt nv12 -i frame_{i}.yuv frame_{i}.jpg + # Linux: ffmpeg -s 320x240 -pix_fmt yuyv422 -i frame_{i}.yuv frame_{i}.jpg dataset = tfio.experimental.IODataset.stream().from_video_capture( "device").take(5) i = 0 @@ -830,8 +831,10 @@ def func(q): dataset = dataset.map(tf.strings.length) dataset = dataset.take(10) return dataset - # 1382400 = (1280 + 1280 / 2) * 720 - expected = [1382400 for _ in range(10)] + # macOS (NV12): 1382400 = (1280 + 1280 / 2) * 720 + # Linux (YUYV): 153600 = 320 * 240 * 2 + value = 1382400 if sys.platform == "darwin" else 153600 + expected = [value for _ in range(10)] return args, func, expected From c86a08e772e0114cfbae318de0b3444270e1e7f9 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 2 Mar 2020 15:46:02 -0800 Subject: [PATCH 3/5] Update to use device name in API calls Signed-off-by: Yong Tang --- tensorflow_io/core/BUILD | 1 + tensorflow_io/core/kernels/video_kernels.cc | 16 +++++++++------- tensorflow_io/core/swift/video.swift | 4 +++- tests/test_io_dataset_eager.py | 2 +- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 48287718b..0306a8c06 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -224,6 +224,7 @@ cc_library( name = "video_ops", srcs = [ "kernels/video_kernels.cc", + "kernels/video_kernels.h", "ops/video_ops.cc", ], copts = tf_io_copts(), diff --git a/tensorflow_io/core/kernels/video_kernels.cc b/tensorflow_io/core/kernels/video_kernels.cc index e7a2bafb9..dc007bf52 100644 --- a/tensorflow_io/core/kernels/video_kernels.cc +++ b/tensorflow_io/core/kernels/video_kernels.cc @@ -17,23 +17,24 @@ limitations under the License. extern "C" { #if defined(__APPLE__) -void* VideoCaptureInitFunction(int64_t* bytes, int64_t* width, int64_t* height); +void* VideoCaptureInitFunction(const char* device, int64_t* bytes, + int64_t* width, int64_t* height); void VideoCaptureNextFunction(void* context, void* data, int64_t size); void VideoCaptureFiniFunction(void* context); #elif defined(_MSV_VER) -void* VideoCaptureInitFunction(int64_t* bytes, int64_t* width, - int64_t* height) { +void* VideoCaptureInitFunction(const char* device, int64_t* bytes, + int64_t* width, int64_t* height) { return NULL; } void VideoCaptureNextFunction(void* context, void* data, int64_t size) {} void VideoCaptureFiniFunction(void* context) {} #else -void* VideoCaptureInitFunction(int64_t* bytes, int64_t* width, - int64_t* height) { +void* VideoCaptureInitFunction(const char* device, int64_t* bytes, + int64_t* width, int64_t* height) { tensorflow::data::VideoCaptureContext* p = new tensorflow::data::VideoCaptureContext(); if (p != nullptr) { - tensorflow::Status status = p->Init("/dev/video0", bytes, width, height); + tensorflow::Status status = p->Init(device, bytes, width, height); if (status.ok()) { return p; } @@ -79,7 +80,8 @@ class VideoCaptureReadableResource : public ResourceBase { mutex_lock l(mu_); int64_t bytes, width, height; - context_.reset(VideoCaptureInitFunction(&bytes, &width, &height)); + context_.reset( + VideoCaptureInitFunction(input.c_str(), &bytes, &width, &height)); if (context_.get() == nullptr) { return errors::InvalidArgument("unable to open device ", input); } diff --git a/tensorflow_io/core/swift/video.swift b/tensorflow_io/core/swift/video.swift index bc87c9ce4..100cf1b48 100644 --- a/tensorflow_io/core/swift/video.swift +++ b/tensorflow_io/core/swift/video.swift @@ -88,7 +88,9 @@ class VideoDataOutputSampleBufferDelegate : NSObject, AVCaptureVideoDataOutputSa typealias VideoContext = (session: AVCaptureSession, semaphore_in: DispatchSemaphore, semaphore_out: DispatchSemaphore, delegate: VideoDataOutputSampleBufferDelegate) @_silgen_name("VideoCaptureInitFunction") -func VideoCaptureInitFunction(bytes: UnsafeMutablePointer, width: UnsafeMutablePointer, height: UnsafeMutablePointer) -> UnsafeMutablePointer? { +func VideoCaptureInitFunction(devname: UnsafePointer, bytes: UnsafeMutablePointer, width: UnsafeMutablePointer, height: UnsafeMutablePointer) -> UnsafeMutablePointer? { + + let deviceName = String(cString: devname) let session = AVCaptureSession() let semaphore_in = DispatchSemaphore(value: 0) diff --git a/tests/test_io_dataset_eager.py b/tests/test_io_dataset_eager.py index 28a285e65..5cfcdfabf 100644 --- a/tests/test_io_dataset_eager.py +++ b/tests/test_io_dataset_eager.py @@ -815,7 +815,7 @@ def fixture_video_capture(): # macOS: ffmpeg -s 1280x720 -pix_fmt nv12 -i frame_{i}.yuv frame_{i}.jpg # Linux: ffmpeg -s 320x240 -pix_fmt yuyv422 -i frame_{i}.yuv frame_{i}.jpg dataset = tfio.experimental.IODataset.stream().from_video_capture( - "device").take(5) + "/dev/video0").take(5) i = 0 for frame in dataset: print("Frame {}: shape({}) dtype({}) length({})".format( From 51b0e5a9749e13232af20dc42fe26d0ab07a4cc4 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 2 Mar 2020 16:47:41 -0800 Subject: [PATCH 4/5] Fix typo in Windows Signed-off-by: Yong Tang --- tensorflow_io/core/kernels/video_kernels.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_io/core/kernels/video_kernels.cc b/tensorflow_io/core/kernels/video_kernels.cc index dc007bf52..725461768 100644 --- a/tensorflow_io/core/kernels/video_kernels.cc +++ b/tensorflow_io/core/kernels/video_kernels.cc @@ -21,7 +21,7 @@ void* VideoCaptureInitFunction(const char* device, int64_t* bytes, int64_t* width, int64_t* height); void VideoCaptureNextFunction(void* context, void* data, int64_t size); void VideoCaptureFiniFunction(void* context); -#elif defined(_MSV_VER) +#elif defined(_MSC_VER) void* VideoCaptureInitFunction(const char* device, int64_t* bytes, int64_t* width, int64_t* height) { return NULL; From eff6922906cf26e8f0ed6bb877ee2a6e3a520398 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 3 Mar 2020 05:46:19 +0000 Subject: [PATCH 5/5] Fix test typo Signed-off-by: Yong Tang --- tests/test_io_dataset_eager.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_io_dataset_eager.py b/tests/test_io_dataset_eager.py index 5cfcdfabf..3ab3e78ac 100644 --- a/tests/test_io_dataset_eager.py +++ b/tests/test_io_dataset_eager.py @@ -810,6 +810,11 @@ def func(q): @pytest.fixture(name="video_capture") def fixture_video_capture(): """fixture_video_capture + # Note: on Linux v4l2loopback is used, and the following is needed: + # gst-launch-1.0 videotestsrc ! v4l2sink device=/dev/video0 + # otherwise fmt will not work with + # $ v4l2-ctl -d /dev/video0 -V + # VIDIOC_G_FMT: failed: Invalid argument # Note: the following is a validation # YUV image could be converted to JPEG with: # macOS: ffmpeg -s 1280x720 -pix_fmt nv12 -i frame_{i}.yuv frame_{i}.jpg @@ -824,7 +829,7 @@ def fixture_video_capture(): i += 1 """ - args = "device" + args = "/dev/video0" def func(q): dataset = tfio.experimental.IODataset.stream().from_video_capture( q)