Skip to content

Commit 0775c86

Browse files
committed
Add Video Capture Support for macOS through AVFoundation/Swift
This PR is part of the effort in resolving 814. In 814, the feature request is to add video capture support for Linux, likely through Video4Linux. Due to some limitations Video4Linux will need a compatible USB camera first. This PR, instead tries to resolve the featue requrest on macOS first. On macOS the built-in camera could be accessed through AVFoundation's Swift API. This PR uses Swift to access AVCaptureSession/etc, and exported to C function (`cdecl`) so that it could be used in C++ kernel in tensorflow-io. Since macOS's raw video capture format is NV12 (kCVPixelFormatType_420YpCbCr8BiPlanarVideoRange) additional work is needed to convert NV12 into RGB format, so that a whole pipeline could be built up to allow using video capture for tf.keras' inference. This PR does not resolve the NV12 => RGB yet. Will address in separate PRs. Also, since video capture is technically a continuous stream and is not repeatable, it is not possible to train based on video capture with multiple epochs. Finally, the following is a sample usage which takes video capture and saves as nv12 raw file. The NV12 raw file could be checked by using ffmpeg to convert to JPEG to validate. Note: the following is a validation YUV image could be converted to JPEG with: ``` ffmpeg -s 1280x720 -pix_fmt nv12 -i frame_{i}.yuv frame_{i}.jpg ``` Usage: ``` dataset = tfio.experimental.IODataset.stream().from_video_capture( "device").take(5) i = 0 for frame in dataset: print("Frame {}: shape({}) dtype({}) length({})".format( i, frame.shape, frame.dtype, tf.strings.length(frame))) tf.io.write_file("frame_{}.yuv".format(i), frame) i += 1 ``` Signed-off-by: Yong Tang <[email protected]>
1 parent f71ef1d commit 0775c86

File tree

8 files changed

+509
-0
lines changed

8 files changed

+509
-0
lines changed

tensorflow_io/core/BUILD

+21
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ cc_library(
187187

188188
exports_files([
189189
"swift/audio.swift",
190+
"swift/video.swift",
190191
])
191192

192193
cc_library(
@@ -219,6 +220,25 @@ cc_library(
219220
alwayslink = 1,
220221
)
221222

223+
cc_library(
224+
name = "video_ops",
225+
srcs = [
226+
"kernels/video_kernels.cc",
227+
"ops/video_ops.cc",
228+
],
229+
copts = tf_io_copts(),
230+
linkstatic = True,
231+
deps = [
232+
"//tensorflow_io/core:dataset_ops",
233+
] + select({
234+
"@bazel_tools//src/conditions:darwin": [
235+
"//tools/build/swift:video_swift",
236+
],
237+
"//conditions:default": [],
238+
}),
239+
alwayslink = 1,
240+
)
241+
222242
cc_library(
223243
name = "ffmpeg_3.4_ops",
224244
srcs = [
@@ -542,6 +562,7 @@ cc_binary(
542562
"//tensorflow_io/core:serialization_ops",
543563
"//tensorflow_io/core:sql_ops",
544564
"//tensorflow_io/core:text_ops",
565+
"//tensorflow_io/core:video_ops",
545566
"@local_config_tf//:libtensorflow_framework",
546567
"@local_config_tf//:tf_header_lib",
547568
] + select({
+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/resource_mgr.h"
17+
#include "tensorflow/core/framework/resource_op_kernel.h"
18+
19+
extern "C" {
20+
#if defined(__APPLE__)
21+
void* VideoCaptureInitFunction(int64_t* bytes, int64_t* width, int64_t* height);
22+
void VideoCaptureNextFunction(void* context, void* data, int64_t size);
23+
void VideoCaptureFiniFunction(void* context);
24+
#else
25+
void* VideoCaptureInitFunction(int64_t* bytes, int64_t* width,
26+
int64_t* height) {
27+
return NULL;
28+
}
29+
void VideoCaptureNextFunction(void* context, void* data, int64_t size) {}
30+
void VideoCaptureFiniFunction(void* context) {}
31+
#endif
32+
}
33+
namespace tensorflow {
34+
namespace data {
35+
namespace {
36+
37+
class VideoCaptureReadableResource : public ResourceBase {
38+
public:
39+
VideoCaptureReadableResource(Env* env)
40+
: env_(env), context_(nullptr, [](void* p) {
41+
if (p != nullptr) {
42+
VideoCaptureFiniFunction(p);
43+
}
44+
}) {}
45+
~VideoCaptureReadableResource() {}
46+
47+
Status Init(const string& input) {
48+
mutex_lock l(mu_);
49+
50+
int64_t bytes, width, height;
51+
context_.reset(VideoCaptureInitFunction(&bytes, &width, &height));
52+
if (context_.get() == nullptr) {
53+
return errors::InvalidArgument("unable to open device ", input);
54+
}
55+
bytes_ = static_cast<int64>(bytes);
56+
width_ = static_cast<int64>(width);
57+
height_ = static_cast<int64>(height);
58+
return Status::OK();
59+
}
60+
Status Read(
61+
std::function<Status(const TensorShape& shape, Tensor** value_tensor)>
62+
allocate_func) {
63+
mutex_lock l(mu_);
64+
65+
Tensor* value_tensor;
66+
TF_RETURN_IF_ERROR(allocate_func(TensorShape({1}), &value_tensor));
67+
68+
string buffer;
69+
buffer.resize(bytes_);
70+
VideoCaptureNextFunction(context_.get(), (void*)&buffer[0],
71+
static_cast<int64_t>(bytes_));
72+
value_tensor->flat<string>()(0) = buffer;
73+
74+
return Status::OK();
75+
}
76+
string DebugString() const override {
77+
mutex_lock l(mu_);
78+
return "VideoCaptureReadableResource";
79+
}
80+
81+
protected:
82+
mutable mutex mu_;
83+
Env* env_ GUARDED_BY(mu_);
84+
85+
std::unique_ptr<void, void (*)(void*)> context_;
86+
int64 bytes_;
87+
int64 width_;
88+
int64 height_;
89+
};
90+
91+
class VideoCaptureReadableInitOp
92+
: public ResourceOpKernel<VideoCaptureReadableResource> {
93+
public:
94+
explicit VideoCaptureReadableInitOp(OpKernelConstruction* context)
95+
: ResourceOpKernel<VideoCaptureReadableResource>(context) {
96+
env_ = context->env();
97+
}
98+
99+
private:
100+
void Compute(OpKernelContext* context) override {
101+
ResourceOpKernel<VideoCaptureReadableResource>::Compute(context);
102+
103+
const Tensor* input_tensor;
104+
OP_REQUIRES_OK(context, context->input("input", &input_tensor));
105+
const string& input = input_tensor->scalar<string>()();
106+
107+
OP_REQUIRES_OK(context, resource_->Init(input));
108+
}
109+
Status CreateResource(VideoCaptureReadableResource** resource)
110+
EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
111+
*resource = new VideoCaptureReadableResource(env_);
112+
return Status::OK();
113+
}
114+
115+
private:
116+
mutable mutex mu_;
117+
Env* env_ GUARDED_BY(mu_);
118+
};
119+
120+
class VideoCaptureReadableReadOp : public OpKernel {
121+
public:
122+
explicit VideoCaptureReadableReadOp(OpKernelConstruction* context)
123+
: OpKernel(context) {
124+
env_ = context->env();
125+
}
126+
127+
void Compute(OpKernelContext* context) override {
128+
VideoCaptureReadableResource* resource;
129+
OP_REQUIRES_OK(context,
130+
GetResourceFromContext(context, "input", &resource));
131+
core::ScopedUnref unref(resource);
132+
133+
OP_REQUIRES_OK(
134+
context, resource->Read([&](const TensorShape& shape,
135+
Tensor** value_tensor) -> Status {
136+
TF_RETURN_IF_ERROR(context->allocate_output(0, shape, value_tensor));
137+
return Status::OK();
138+
}));
139+
}
140+
141+
private:
142+
mutable mutex mu_;
143+
Env* env_ GUARDED_BY(mu_);
144+
};
145+
REGISTER_KERNEL_BUILDER(Name("IO>VideoCaptureReadableInit").Device(DEVICE_CPU),
146+
VideoCaptureReadableInitOp);
147+
REGISTER_KERNEL_BUILDER(Name("IO>VideoCaptureReadableRead").Device(DEVICE_CPU),
148+
VideoCaptureReadableReadOp);
149+
150+
} // namespace
151+
} // namespace data
152+
} // namespace tensorflow

tensorflow_io/core/ops/video_ops.cc

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/common_shape_fns.h"
17+
#include "tensorflow/core/framework/op.h"
18+
#include "tensorflow/core/framework/shape_inference.h"
19+
20+
namespace tensorflow {
21+
namespace io {
22+
namespace {
23+
24+
REGISTER_OP("IO>VideoCaptureReadableInit")
25+
.Input("input: string")
26+
.Output("resource: resource")
27+
.Attr("container: string = ''")
28+
.Attr("shared_name: string = ''")
29+
.SetShapeFn([](shape_inference::InferenceContext* c) {
30+
c->set_output(0, c->Scalar());
31+
return Status::OK();
32+
});
33+
34+
REGISTER_OP("IO>VideoCaptureReadableRead")
35+
.Input("input: resource")
36+
.Input("index: int64")
37+
.Output("value: string")
38+
.SetShapeFn([](shape_inference::InferenceContext* c) {
39+
c->set_output(0, c->MakeShape({c->UnknownDim()}));
40+
return Status::OK();
41+
});
42+
43+
} // namespace
44+
} // namespace io
45+
} // namespace tensorflow

tensorflow_io/core/python/experimental/io_dataset_ops.py

+16
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tensorflow_io.core.python.experimental import file_dataset_ops
2929
from tensorflow_io.core.python.experimental import numpy_dataset_ops
3030
from tensorflow_io.core.python.experimental import sql_dataset_ops
31+
from tensorflow_io.core.python.experimental import video_dataset_ops
3132

3233
class IODataset(io_dataset.IODataset):
3334
"""IODataset"""
@@ -269,6 +270,21 @@ def to_file(cls,
269270
class StreamIODataset(tf.data.Dataset):
270271
"""StreamIODataset"""
271272

273+
@classmethod
274+
def from_video_capture(cls, device, **kwargs):
275+
"""Creates an `StreamIODataset` from video capture device.
276+
277+
Args:
278+
device: A string, the name of the device.
279+
name: A name prefix for the IODataset (optional).
280+
281+
Returns:
282+
A `IODataset`.
283+
"""
284+
with tf.name_scope(kwargs.get("name", "IOFromVideoCapture")):
285+
return video_dataset_ops.VideoCaptureIODataset(
286+
device, internal=True)
287+
272288
@classmethod
273289
def from_prometheus_scrape(cls,
274290
metric,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""VideoCaptureDataset"""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import tensorflow as tf
21+
from tensorflow_io.core.python.ops import core_ops
22+
23+
class VideoCaptureIODataset(tf.data.Dataset):
24+
"""VideoCaptureIODataset"""
25+
26+
def __init__(self,
27+
device,
28+
internal=True):
29+
"""VideoCaptureIODataset"""
30+
with tf.name_scope("VideoCaptureIODataset"):
31+
assert internal
32+
33+
resource = core_ops.io_video_capture_readable_init(device)
34+
35+
self._resource = resource
36+
37+
dataset = tf.data.experimental.Counter()
38+
dataset = dataset.map(
39+
lambda i: core_ops.io_video_capture_readable_read(self._resource, i))
40+
dataset = dataset.apply(
41+
tf.data.experimental.take_while(
42+
lambda v: tf.greater(tf.shape(v)[0], 0)))
43+
dataset = dataset.unbatch()
44+
45+
self._dataset = dataset
46+
super(VideoCaptureIODataset, self).__init__(
47+
self._dataset._variant_tensor) # pylint: disable=protected-access
48+
49+
def _inputs(self):
50+
return []
51+
52+
@property
53+
def element_spec(self):
54+
return self._dataset.element_spec

0 commit comments

Comments
 (0)