
Commit 3e00c95

smessmer authored and Rob Kunkle committed
Implement c10 ops needed for benchmark (pytorch#9360)
Summary:
Pull Request resolved: pytorch#9360

This implements a first set of c10 operators, namely the ones needed for the multithread predictor benchmark. All implementations are CPU-only and experimental. They're not meant to be used in production. They can be used, however, to test calling simple c10 MLPs from Caffe2 or PyTorch when working on these integration paths.

Reviewed By: dzhulgakov

Differential Revision: D8811698

fbshipit-source-id: 826789c38b2bfdb125a5c0d03c5aebf627785482
1 parent 454dfc3 commit 3e00c95


56 files changed: +2943 −24 lines (only a subset of files is shown below)

caffe2/core/dispatch/DispatchTable.h

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ class ThreadsafeOperatorTable_ final {
     });
     if (!res) {
       std::ostringstream msg;
+      using ::operator<<;
       msg << "Tried to register conflicting kernels to the dispatcher: " << key;
       throw std::logic_error(msg.str());
     }
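
The only functional change in this file is the added using ::operator<<; inside the error path. A plausible reading (not stated in the commit message) is that the stream operator used to print the dispatch key is declared at global scope, where neither argument-dependent lookup on the key type nor ordinary lookup from inside the dispatcher's namespace will find it once a closer operator<< overload exists. A minimal, self-contained sketch of that situation, with hypothetical names (keylib::Key, throw_conflict), not taken from the commit:

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

namespace keylib {
struct Key { std::string repr; };  // hypothetical stand-in for a dispatch key
}

// Printer for Key declared at global scope, so ADL (which only searches
// namespace keylib) cannot find it.
inline std::ostream& operator<<(std::ostream& out, const keylib::Key& key) {
  return out << key.repr;
}

namespace dispatch_sketch {
// An unrelated operator<< in this namespace hides the global one during
// ordinary unqualified lookup.
struct Unrelated {};
inline std::ostream& operator<<(std::ostream& out, const Unrelated&) { return out; }

inline void throw_conflict(const keylib::Key& key) {
  std::ostringstream msg;
  using ::operator<<;  // re-expose the hidden global overload, as in the diff
  msg << "Tried to register conflicting kernels to the dispatcher: " << key;
  throw std::logic_error(msg.str());
}
} // namespace dispatch_sketch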

caffe2/core/dispatch/KernelRegistration.h

Lines changed: 5 additions & 5 deletions
@@ -2,7 +2,7 @@
 
 #include "caffe2/core/dispatch/OpSchema.h"
 #include "caffe2/core/dispatch/Dispatcher.h"
-#include "caffe2/utils/Optional.h"
+#include <ATen/core/optional.h>
 
 /**
  * To register your own kernel for an operator, do in one (!) cpp file:
@@ -89,13 +89,13 @@ class KernelRegistrationBuilder final {
   static constexpr uint64_t KERNEL_PRESENT = 0x01 << 0;
   static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 1;
 
-  optional<typename Schema::signature::func_type*> kernel_;
-  optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_;
+  at::optional<typename Schema::signature::func_type*> kernel_;
+  at::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_;
 
  public:
-  constexpr KernelRegistrationBuilder(): KernelRegistrationBuilder(nullopt, nullopt) {}
+  constexpr KernelRegistrationBuilder(): KernelRegistrationBuilder(at::nullopt, at::nullopt) {}
 
-  constexpr KernelRegistrationBuilder(optional<typename Schema::signature::func_type*> kernel, optional<typename Schema::dispatch::dispatch_key_type> dispatch_key)
+  constexpr KernelRegistrationBuilder(at::optional<typename Schema::signature::func_type*> kernel, at::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key)
   : kernel_(std::move(kernel)), dispatch_key_(std::move(dispatch_key)) {}
 
   /**
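
The edits in this file are mechanical: caffe2's own Optional.h is replaced by ATen's <ATen/core/optional.h>, so the builder's kernel and dispatch-key slots become at::optional values that default to at::nullopt. For readers unfamiliar with it, at::optional follows the std::optional-style interface; a small illustrative sketch with hypothetical stand-in types (KernelFn, DispatchKey, RegistrationSlots), not taken from the commit:

#include <ATen/core/optional.h>

// Hypothetical stand-ins for Schema::signature::func_type* and the dispatch key.
using KernelFn = void (*)();
using DispatchKey = int;

struct RegistrationSlots {
  // Both slots start out empty, mirroring the builder's default constructor.
  at::optional<KernelFn> kernel_ = at::nullopt;
  at::optional<DispatchKey> dispatch_key_ = at::nullopt;
};

inline bool ready(const RegistrationSlots& slots) {
  // has_value()/value() follow the usual std::optional-style interface.
  return slots.kernel_.has_value() && slots.dispatch_key_.has_value();
}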

caffe2/core/operator_c10wrapper.h

Lines changed: 288 additions & 10 deletions
@@ -2,9 +2,29 @@
 
 #include "caffe2/core/dispatch/Dispatcher.h"
 #include "caffe2/core/operator.h"
+#include <ATen/core/ArrayRef.h>
+#include "caffe2/utils/Metaprogramming.h"
 
 namespace caffe2 {
 
+namespace details {
+template <size_t...>
+struct true_t : std::true_type {};
+template <class State>
+inline std::shared_ptr<State> init_state() {
+  return std::make_shared<State>();
+}
+template <>
+inline std::shared_ptr<void> init_state<void>() {
+  return std::shared_ptr<void>();
+}
+template <class T>
+using is_output_arg = std::is_same<Tensor*, T>;
+template <class ParameterDef>
+using extract_type_t =
+    c10::guts::result_of_t<decltype (&ParameterDef::parse)(ArgumentHelper)>;
+} // namespace details
+
 /**
  * To make a c10 operator "C10Add" callable from caffe2 as "C2MyAddOpName", just
  * write
@@ -16,26 +36,251 @@ namespace caffe2 {
  * TODO: Figure out a better way to handle output parameters
  */
 
-template <class OpSchemaDef, class Context>
+template <
+    class OpSchemaDef,
+    class Context,
+    class State,
+    bool use_array_input,
+    class ParameterDefTuple>
 class C10OperatorWrapper final : public Operator<Context> {
   using Schema = c10::OpSchema<OpSchemaDef>;
 
  public:
-  C10OperatorWrapper(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws) {}
+  static_assert(
+      c10::guts::is_instantiation_of<std::tuple, ParameterDefTuple>::value,
+      "");
+  using ParameterTuple =
+      c10::guts::typelist::to_tuple_t<c10::guts::typelist::map_t<
+          details::extract_type_t,
+          c10::guts::typelist::from_tuple_t<ParameterDefTuple>>>;
 
   USE_OPERATOR_CONTEXT_FUNCTIONS;
 
+  static constexpr bool op_has_context_argument = std::is_same<
+      BaseContext*,
+      c10::guts::typelist::last_t<
+          typename Schema::signature::parameter_types>>::value;
+  static constexpr bool op_has_state_argument =
+      !std::is_same<void, State>::value;
+
+  C10OperatorWrapper(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<Context>(operator_def, ws),
+        state_(details::init_state<State>()),
+        parameters_(parse_parameters_(
+            operator_def,
+            c10::guts::make_index_sequence<num_parameters()>())) {}
+
+  static constexpr size_t num_inputs() {
+    return Schema::signature::num_args - num_outputs() - num_parameters() -
+        (op_has_context_argument ? 1 : 0) - (op_has_state_argument ? 1 : 0);
+  }
+
+  static constexpr size_t num_parameters() {
+    return std::tuple_size<ParameterDefTuple>::value;
+  }
+
+  static constexpr size_t num_outputs() {
+    return c10::guts::typelist::count_if<
+        details::is_output_arg,
+        typename Schema::signature::parameter_types>::value;
+  }
+
   bool RunOnDevice() override {
     RunOnDevice_(
-        c10::guts::make_index_sequence<Schema::signature::num_args - 1>());
+        c10::guts::make_index_sequence<num_inputs()>(),
+        c10::guts::make_index_sequence<num_outputs()>(),
+        c10::guts::make_index_sequence<num_parameters()>());
     return true;
   }
 
  private:
-  template <size_t... InputIndex>
-  void RunOnDevice_(c10::guts::index_sequence<InputIndex...>) {
-    c10::Dispatcher<OpSchemaDef>::call(Input(InputIndex)..., Output(0));
+  template <size_t... ParameterIndex>
+  ParameterTuple parse_parameters_(
+      const OperatorDef& operator_def,
+      c10::guts::index_sequence<ParameterIndex...>) {
+    return ParameterTuple{Parameter<ParameterIndex>(operator_def)...};
+  }
+
+  template <size_t Index>
+  details::extract_type_t<
+      typename std::tuple_element<Index, ParameterDefTuple>::type>
+  Parameter(const OperatorDef& operator_def) {
+    using Parameter =
+        typename std::tuple_element<Index, ParameterDefTuple>::type;
+    return Parameter::parse(ArgumentHelper(operator_def));
+  }
+
+  template <
+      size_t... InputIndex,
+      size_t... OutputIndex,
+      size_t... ParameterIndex>
+  c10::guts::enable_if_t<
+      details::true_t<InputIndex...>::value && op_has_context_argument &&
+          op_has_state_argument && !use_array_input,
+      void>
+  RunOnDevice_(
+      c10::guts::index_sequence<InputIndex...>,
+      c10::guts::index_sequence<OutputIndex...>,
+      c10::guts::index_sequence<ParameterIndex...>) {
+    c10::Dispatcher<OpSchemaDef>::call(
+        Input(InputIndex)...,
+        Output(OutputIndex)...,
+        std::get<ParameterIndex>(parameters_)...,
+        state_.get(),
+        static_cast<BaseContext*>(&context_));
+  }
+
+  template <
+      size_t... InputIndex,
+      size_t... OutputIndex,
+      size_t... ParameterIndex>
+  c10::guts::enable_if_t<
+      details::true_t<InputIndex...>::value && op_has_context_argument &&
+          !op_has_state_argument && !use_array_input,
+      void>
+  RunOnDevice_(
+      c10::guts::index_sequence<InputIndex...>,
+      c10::guts::index_sequence<OutputIndex...>,
+      c10::guts::index_sequence<ParameterIndex...>) {
+    c10::Dispatcher<OpSchemaDef>::call(
+        Input(InputIndex)...,
+        Output(OutputIndex)...,
+        std::get<ParameterIndex>(parameters_)...,
+        static_cast<BaseContext*>(&context_));
+  }
+
+  template <
+      size_t... InputIndex,
+      size_t... OutputIndex,
+      size_t... ParameterIndex>
+  c10::guts::enable_if_t<
+      details::true_t<InputIndex...>::value && !op_has_context_argument &&
+          op_has_state_argument && !use_array_input,
+      void>
+  RunOnDevice_(
+      c10::guts::index_sequence<InputIndex...>,
+      c10::guts::index_sequence<OutputIndex...>,
+      c10::guts::index_sequence<ParameterIndex...>) {
+    c10::Dispatcher<OpSchemaDef>::call(
+        Input(InputIndex)...,
+        Output(OutputIndex)...,
+        std::get<ParameterIndex>(parameters_)...,
+        state_.get());
+  }
+
+  template <
+      size_t... InputIndex,
+      size_t... OutputIndex,
+      size_t... ParameterIndex>
+  c10::guts::enable_if_t<
+      details::true_t<InputIndex...>::value && !op_has_context_argument &&
+          !op_has_state_argument && !use_array_input,
+      void>
+  RunOnDevice_(
+      c10::guts::index_sequence<InputIndex...>,
+      c10::guts::index_sequence<OutputIndex...>,
+      c10::guts::index_sequence<ParameterIndex...>) {
+    c10::Dispatcher<OpSchemaDef>::call(
+        Input(InputIndex)...,
+        Output(OutputIndex)...,
+        std::get<ParameterIndex>(parameters_)...);
+  }
+
+  template <
+      size_t... InputIndex,
+      size_t... OutputIndex,
+      size_t... ParameterIndex>
+  c10::guts::enable_if_t<
+      details::true_t<InputIndex...>::value && op_has_context_argument &&
+          op_has_state_argument && use_array_input,
+      void>
+  RunOnDevice_(
+      c10::guts::index_sequence<InputIndex...>,
+      c10::guts::index_sequence<OutputIndex...>,
+      c10::guts::index_sequence<ParameterIndex...>) {
+    c10::Dispatcher<OpSchemaDef>::call(
+        at::ArrayRef<const Tensor*>(array_inputs_()),
+        Output(OutputIndex)...,
+        std::get<ParameterIndex>(parameters_)...,
+        state_.get(),
+        static_cast<BaseContext*>(&context_));
+  }
+
+  template <
+      size_t... InputIndex,
+      size_t... OutputIndex,
+      size_t... ParameterIndex>
+  c10::guts::enable_if_t<
+      details::true_t<InputIndex...>::value && op_has_context_argument &&
+          !op_has_state_argument && use_array_input,
+      void>
+  RunOnDevice_(
+      c10::guts::index_sequence<InputIndex...>,
+      c10::guts::index_sequence<OutputIndex...>,
+      c10::guts::index_sequence<ParameterIndex...>) {
+    c10::Dispatcher<OpSchemaDef>::call(
+        at::ArrayRef<const Tensor*>(array_inputs_()),
+        Output(OutputIndex)...,
+        std::get<ParameterIndex>(parameters_)...,
+        static_cast<BaseContext*>(&context_));
+  }
+
+  template <
+      size_t... InputIndex,
+      size_t... OutputIndex,
+      size_t... ParameterIndex>
+  c10::guts::enable_if_t<
+      details::true_t<InputIndex...>::value && !op_has_context_argument &&
+          op_has_state_argument && use_array_input,
+      void>
+  RunOnDevice_(
+      c10::guts::index_sequence<InputIndex...>,
+      c10::guts::index_sequence<OutputIndex...>,
+      c10::guts::index_sequence<ParameterIndex...>) {
+    c10::Dispatcher<OpSchemaDef>::call(
+        at::ArrayRef<const Tensor*>(array_inputs_()),
+        Output(OutputIndex)...,
+        std::get<ParameterIndex>(parameters_)...,
+        state_.get());
+  }
+
+  template <
+      size_t... InputIndex,
+      size_t... OutputIndex,
+      size_t... ParameterIndex>
+  c10::guts::enable_if_t<
+      details::true_t<InputIndex...>::value && !op_has_context_argument &&
+          !op_has_state_argument && use_array_input,
+      void>
+  RunOnDevice_(
+      c10::guts::index_sequence<InputIndex...>,
+      c10::guts::index_sequence<OutputIndex...>,
+      c10::guts::index_sequence<ParameterIndex...>) {
+    c10::Dispatcher<OpSchemaDef>::call(
+        at::ArrayRef<const Tensor*>(array_inputs_()),
+        Output(OutputIndex)...,
+        std::get<ParameterIndex>(parameters_)...);
+  }
+
+  std::vector<const Tensor*> array_inputs_() {
+    std::vector<const Tensor*> result;
+    result.reserve(InputSize());
+    for (size_t i = 0; i < InputSize(); ++i) {
+      result.push_back(&Input(i));
+    }
+    return result;
+  }
+
+  std::shared_ptr<State> state_;
+
+  ParameterTuple parameters_;
+};
+
+template <class ParameterDef>
+struct ParameterHelper final {
+  static typename ParameterDef::type parse(const ArgumentHelper& helper) {
+    return helper.GetSingleArgument<typename ParameterDef::type>(
+        ParameterDef::name(), ParameterDef::default_value());
   }
 };
 
@@ -47,8 +292,41 @@ CAFFE_DECLARE_REGISTRY(
 
 // TODO Currently we only register the CPU variant. This is going to be fixed
 // once the tensor detemplatization lands.
-#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(OpSchemaDef, Name) \
-  CAFFE_REGISTER_CLASS( \
-      C10OperatorRegistry, Name, C10OperatorWrapper<OpSchemaDef, CPUContext>)
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(OpSchemaDef, State, Name) \
+  CAFFE_REGISTER_CLASS( \
+      C10OperatorRegistry, \
+      Name, \
+      C10OperatorWrapper<OpSchemaDef, CPUContext, State, false, std::tuple<>>)
+
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( \
+    OpSchemaDef, State, Name, ...) \
+  CAFFE_REGISTER_CLASS( \
+      C10OperatorRegistry, \
+      Name, \
+      C10OperatorWrapper< \
+          OpSchemaDef, \
+          CPUContext, \
+          State, \
+          false, \
+          std::tuple<__VA_ARGS__>>)
+
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT( \
+    OpSchemaDef, State, Name) \
+  CAFFE_REGISTER_CLASS( \
+      C10OperatorRegistry, \
+      Name, \
+      C10OperatorWrapper<OpSchemaDef, CPUContext, State, true, std::tuple<>>)
+
+#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( \
+    OpSchemaDef, State, Name, ...) \
+  CAFFE_REGISTER_CLASS( \
+      C10OperatorRegistry, \
+      Name, \
+      C10OperatorWrapper< \
+          OpSchemaDef, \
+          CPUContext, \
+          State, \
+          true, \
+          std::tuple<__VA_ARGS__>>)
 
 } // namespace caffe2
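
Taken together, the wrapper makes a c10 operator usable from caffe2 with a single registration line: the class deduces input, output, and parameter counts from the c10 schema, parses caffe2 arguments once in the constructor via the ParameterDef tuple, and the eight enable_if'd RunOnDevice_ overloads cover every combination of context argument, state argument, and array input. An illustrative usage sketch follows; C10Add and C2MyAddOpName come from the header's own doc comment, while C10BatchNorm, C2MyBatchNormOpName, and EpsilonParameter are hypothetical names invented here, and the exact shape of a real OpSchemaDef is not shown in this commit:

// Illustrative registration file (sketch only, not part of the commit).
#include "caffe2/core/operator_c10wrapper.h"

namespace caffe2 {

// Stateless op with no extra caffe2 arguments: State = void means no state
// pointer is passed to the kernel, and the wrapper forwards Input(i)/Output(i)
// straight to c10::Dispatcher<C10Add>::call.
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(C10Add, void, C2MyAddOpName)

// Op with a caffe2 argument: a parameter definition says how to pull the value
// out of the OperatorDef, and ParameterHelper supplies the boilerplate parse().
struct EpsilonParameter { // hypothetical parameter definition
  using type = float;
  static constexpr const char* name() { return "epsilon"; }
  static constexpr float default_value() { return 1e-5f; }
};

REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
    C10BatchNorm,          // hypothetical OpSchemaDef
    void,                  // no per-instance state
    C2MyBatchNormOpName,   // hypothetical caffe2-visible operator name
    ParameterHelper<EpsilonParameter>)

} // namespace caffe2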

caffe2/operators/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ file(GLOB tmp *.cc)
 file(GLOB tmp_cudnn *_cudnn.cc)
 exclude(tmp "${tmp}" ${tmp_cudnn})
 set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp})
+file(GLOB_RECURSE tmp c10/*.cc)
+set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp})
 # exclude test files and gpu files
 file(GLOB tmp *_test.cc)
 exclude(Caffe2_CPU_SRCS "${Caffe2_CPU_SRCS}" ${tmp})
