2
2
3
3
#include " caffe2/core/dispatch/Dispatcher.h"
4
4
#include " caffe2/core/operator.h"
5
+ #include < ATen/core/ArrayRef.h>
6
+ #include " caffe2/utils/Metaprogramming.h"
5
7
6
8
namespace caffe2 {
7
9
10
+ namespace details {
11
+ template <size_t ...>
12
+ struct true_t : std::true_type {};
13
+ template <class State >
14
+ inline std::shared_ptr<State> init_state () {
15
+ return std::make_shared<State>();
16
+ }
17
+ template <>
18
+ inline std::shared_ptr<void > init_state<void >() {
19
+ return std::shared_ptr<void >();
20
+ }
21
+ template <class T >
22
+ using is_output_arg = std::is_same<Tensor*, T>;
23
+ template <class ParameterDef >
24
+ using extract_type_t =
25
+ c10::guts::result_of_t <decltype (&ParameterDef::parse)(ArgumentHelper)>;
26
+ } // namespace details
27
+
8
28
/* *
9
29
* To make a c10 operator "C10Add" callable from caffe2 as "C2MyAddOpName", just
10
30
* write
@@ -16,26 +36,251 @@ namespace caffe2 {
16
36
* TODO: Figure out a better way to handle output parameters
17
37
*/
18
38
19
- template <class OpSchemaDef , class Context >
39
+ template <
40
+ class OpSchemaDef ,
41
+ class Context ,
42
+ class State ,
43
+ bool use_array_input,
44
+ class ParameterDefTuple >
20
45
class C10OperatorWrapper final : public Operator<Context> {
21
46
using Schema = c10::OpSchema<OpSchemaDef>;
22
47
23
48
public:
24
- C10OperatorWrapper (const OperatorDef& operator_def, Workspace* ws)
25
- : Operator<Context>(operator_def, ws) {}
49
+ static_assert (
50
+ c10::guts::is_instantiation_of<std::tuple, ParameterDefTuple>::value,
51
+ " " );
52
+ using ParameterTuple =
53
+ c10::guts::typelist::to_tuple_t <c10::guts::typelist::map_t <
54
+ details::extract_type_t ,
55
+ c10::guts::typelist::from_tuple_t <ParameterDefTuple>>>;
26
56
27
57
USE_OPERATOR_CONTEXT_FUNCTIONS;
28
58
59
+ static constexpr bool op_has_context_argument = std::is_same<
60
+ BaseContext*,
61
+ c10::guts::typelist::last_t <
62
+ typename Schema::signature::parameter_types>>::value;
63
+ static constexpr bool op_has_state_argument =
64
+ !std::is_same<void , State>::value;
65
+
66
+ C10OperatorWrapper (const OperatorDef& operator_def, Workspace* ws)
67
+ : Operator<Context>(operator_def, ws),
68
+ state_ (details::init_state<State>()),
69
+ parameters_ (parse_parameters_(
70
+ operator_def,
71
+ c10::guts::make_index_sequence<num_parameters()>())) {}
72
+
73
+ static constexpr size_t num_inputs () {
74
+ return Schema::signature::num_args - num_outputs () - num_parameters () -
75
+ (op_has_context_argument ? 1 : 0 ) - (op_has_state_argument ? 1 : 0 );
76
+ }
77
+
78
+ static constexpr size_t num_parameters () {
79
+ return std::tuple_size<ParameterDefTuple>::value;
80
+ }
81
+
82
+ static constexpr size_t num_outputs () {
83
+ return c10::guts::typelist::count_if<
84
+ details::is_output_arg,
85
+ typename Schema::signature::parameter_types>::value;
86
+ }
87
+
29
88
bool RunOnDevice () override {
30
89
RunOnDevice_ (
31
- c10::guts::make_index_sequence<Schema::signature::num_args - 1 >());
90
+ c10::guts::make_index_sequence<num_inputs ()>(),
91
+ c10::guts::make_index_sequence<num_outputs ()>(),
92
+ c10::guts::make_index_sequence<num_parameters ()>());
32
93
return true ;
33
94
}
34
95
35
96
private:
36
- template <size_t ... InputIndex>
37
- void RunOnDevice_ (c10::guts::index_sequence<InputIndex...>) {
38
- c10::Dispatcher<OpSchemaDef>::call (Input (InputIndex)..., Output (0 ));
97
+ template <size_t ... ParameterIndex>
98
+ ParameterTuple parse_parameters_ (
99
+ const OperatorDef& operator_def,
100
+ c10::guts::index_sequence<ParameterIndex...>) {
101
+ return ParameterTuple{Parameter<ParameterIndex>(operator_def)...};
102
+ }
103
+
104
+ template <size_t Index>
105
+ details::extract_type_t <
106
+ typename std::tuple_element<Index, ParameterDefTuple>::type>
107
+ Parameter (const OperatorDef& operator_def) {
108
+ using Parameter =
109
+ typename std::tuple_element<Index, ParameterDefTuple>::type;
110
+ return Parameter::parse (ArgumentHelper (operator_def));
111
+ }
112
+
113
+ template <
114
+ size_t ... InputIndex,
115
+ size_t ... OutputIndex,
116
+ size_t ... ParameterIndex>
117
+ c10::guts::enable_if_t <
118
+ details::true_t <InputIndex...>::value && op_has_context_argument &&
119
+ op_has_state_argument && !use_array_input,
120
+ void >
121
+ RunOnDevice_ (
122
+ c10::guts::index_sequence<InputIndex...>,
123
+ c10::guts::index_sequence<OutputIndex...>,
124
+ c10::guts::index_sequence<ParameterIndex...>) {
125
+ c10::Dispatcher<OpSchemaDef>::call (
126
+ Input (InputIndex)...,
127
+ Output (OutputIndex)...,
128
+ std::get<ParameterIndex>(parameters_)...,
129
+ state_.get (),
130
+ static_cast <BaseContext*>(&context_));
131
+ }
132
+
133
+ template <
134
+ size_t ... InputIndex,
135
+ size_t ... OutputIndex,
136
+ size_t ... ParameterIndex>
137
+ c10::guts::enable_if_t <
138
+ details::true_t <InputIndex...>::value && op_has_context_argument &&
139
+ !op_has_state_argument && !use_array_input,
140
+ void >
141
+ RunOnDevice_ (
142
+ c10::guts::index_sequence<InputIndex...>,
143
+ c10::guts::index_sequence<OutputIndex...>,
144
+ c10::guts::index_sequence<ParameterIndex...>) {
145
+ c10::Dispatcher<OpSchemaDef>::call (
146
+ Input (InputIndex)...,
147
+ Output (OutputIndex)...,
148
+ std::get<ParameterIndex>(parameters_)...,
149
+ static_cast <BaseContext*>(&context_));
150
+ }
151
+
152
+ template <
153
+ size_t ... InputIndex,
154
+ size_t ... OutputIndex,
155
+ size_t ... ParameterIndex>
156
+ c10::guts::enable_if_t <
157
+ details::true_t <InputIndex...>::value && !op_has_context_argument &&
158
+ op_has_state_argument && !use_array_input,
159
+ void >
160
+ RunOnDevice_ (
161
+ c10::guts::index_sequence<InputIndex...>,
162
+ c10::guts::index_sequence<OutputIndex...>,
163
+ c10::guts::index_sequence<ParameterIndex...>) {
164
+ c10::Dispatcher<OpSchemaDef>::call (
165
+ Input (InputIndex)...,
166
+ Output (OutputIndex)...,
167
+ std::get<ParameterIndex>(parameters_)...,
168
+ state_.get ());
169
+ }
170
+
171
+ template <
172
+ size_t ... InputIndex,
173
+ size_t ... OutputIndex,
174
+ size_t ... ParameterIndex>
175
+ c10::guts::enable_if_t <
176
+ details::true_t <InputIndex...>::value && !op_has_context_argument &&
177
+ !op_has_state_argument && !use_array_input,
178
+ void >
179
+ RunOnDevice_ (
180
+ c10::guts::index_sequence<InputIndex...>,
181
+ c10::guts::index_sequence<OutputIndex...>,
182
+ c10::guts::index_sequence<ParameterIndex...>) {
183
+ c10::Dispatcher<OpSchemaDef>::call (
184
+ Input (InputIndex)...,
185
+ Output (OutputIndex)...,
186
+ std::get<ParameterIndex>(parameters_)...);
187
+ }
188
+
189
+ template <
190
+ size_t ... InputIndex,
191
+ size_t ... OutputIndex,
192
+ size_t ... ParameterIndex>
193
+ c10::guts::enable_if_t <
194
+ details::true_t <InputIndex...>::value && op_has_context_argument &&
195
+ op_has_state_argument && use_array_input,
196
+ void >
197
+ RunOnDevice_ (
198
+ c10::guts::index_sequence<InputIndex...>,
199
+ c10::guts::index_sequence<OutputIndex...>,
200
+ c10::guts::index_sequence<ParameterIndex...>) {
201
+ c10::Dispatcher<OpSchemaDef>::call (
202
+ at::ArrayRef<const Tensor*>(array_inputs_ ()),
203
+ Output (OutputIndex)...,
204
+ std::get<ParameterIndex>(parameters_)...,
205
+ state_.get (),
206
+ static_cast <BaseContext*>(&context_));
207
+ }
208
+
209
+ template <
210
+ size_t ... InputIndex,
211
+ size_t ... OutputIndex,
212
+ size_t ... ParameterIndex>
213
+ c10::guts::enable_if_t <
214
+ details::true_t <InputIndex...>::value && op_has_context_argument &&
215
+ !op_has_state_argument && use_array_input,
216
+ void >
217
+ RunOnDevice_ (
218
+ c10::guts::index_sequence<InputIndex...>,
219
+ c10::guts::index_sequence<OutputIndex...>,
220
+ c10::guts::index_sequence<ParameterIndex...>) {
221
+ c10::Dispatcher<OpSchemaDef>::call (
222
+ at::ArrayRef<const Tensor*>(array_inputs_ ()),
223
+ Output (OutputIndex)...,
224
+ std::get<ParameterIndex>(parameters_)...,
225
+ static_cast <BaseContext*>(&context_));
226
+ }
227
+
228
+ template <
229
+ size_t ... InputIndex,
230
+ size_t ... OutputIndex,
231
+ size_t ... ParameterIndex>
232
+ c10::guts::enable_if_t <
233
+ details::true_t <InputIndex...>::value && !op_has_context_argument &&
234
+ op_has_state_argument && use_array_input,
235
+ void >
236
+ RunOnDevice_ (
237
+ c10::guts::index_sequence<InputIndex...>,
238
+ c10::guts::index_sequence<OutputIndex...>,
239
+ c10::guts::index_sequence<ParameterIndex...>) {
240
+ c10::Dispatcher<OpSchemaDef>::call (
241
+ at::ArrayRef<const Tensor*>(array_inputs_ ()),
242
+ Output (OutputIndex)...,
243
+ std::get<ParameterIndex>(parameters_)...,
244
+ state_.get ());
245
+ }
246
+
247
+ template <
248
+ size_t ... InputIndex,
249
+ size_t ... OutputIndex,
250
+ size_t ... ParameterIndex>
251
+ c10::guts::enable_if_t <
252
+ details::true_t <InputIndex...>::value && !op_has_context_argument &&
253
+ !op_has_state_argument && use_array_input,
254
+ void >
255
+ RunOnDevice_ (
256
+ c10::guts::index_sequence<InputIndex...>,
257
+ c10::guts::index_sequence<OutputIndex...>,
258
+ c10::guts::index_sequence<ParameterIndex...>) {
259
+ c10::Dispatcher<OpSchemaDef>::call (
260
+ at::ArrayRef<const Tensor*>(array_inputs_ ()),
261
+ Output (OutputIndex)...,
262
+ std::get<ParameterIndex>(parameters_)...);
263
+ }
264
+
265
+ std::vector<const Tensor*> array_inputs_ () {
266
+ std::vector<const Tensor*> result;
267
+ result.reserve (InputSize ());
268
+ for (size_t i = 0 ; i < InputSize (); ++i) {
269
+ result.push_back (&Input (i));
270
+ }
271
+ return result;
272
+ }
273
+
274
+ std::shared_ptr<State> state_;
275
+
276
+ ParameterTuple parameters_;
277
+ };
278
+
279
+ template <class ParameterDef >
280
+ struct ParameterHelper final {
281
+ static typename ParameterDef::type parse (const ArgumentHelper& helper) {
282
+ return helper.GetSingleArgument <typename ParameterDef::type>(
283
+ ParameterDef::name (), ParameterDef::default_value ());
39
284
}
40
285
};
41
286
@@ -47,8 +292,41 @@ CAFFE_DECLARE_REGISTRY(
47
292
48
293
// TODO Currently we only register the CPU variant. This is going to be fixed
49
294
// once the tensor detemplatization lands.
50
// Registers `Name` in C10OperatorRegistry as a CPUContext C10OperatorWrapper
// for OpSchemaDef with the given State type, one argument per input and no
// parameters.
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(OpSchemaDef, State, Name) \
  CAFFE_REGISTER_CLASS(                                                     \
      C10OperatorRegistry,                                                  \
      Name,                                                                 \
      C10OperatorWrapper<                                                   \
          OpSchemaDef,                                                      \
          CPUContext,                                                       \
          State,                                                            \
          /*use_array_input=*/false,                                        \
          /*ParameterDefTuple=*/std::tuple<>>)
300
+
301
// Same registration as the plain per-input form, but the trailing variadic
// arguments are the parameter definition types, collected into a std::tuple.
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( \
    OpSchemaDef, State, Name, ...)                                 \
  CAFFE_REGISTER_CLASS(                                            \
      C10OperatorRegistry,                                         \
      Name,                                                        \
      C10OperatorWrapper<                                          \
          OpSchemaDef,                                             \
          CPUContext,                                              \
          State,                                                   \
          /*use_array_input=*/false,                               \
          /*ParameterDefTuple=*/std::tuple<__VA_ARGS__>>)
312
+
313
// Registers `Name` in C10OperatorRegistry as a CPUContext C10OperatorWrapper
// for OpSchemaDef whose inputs are passed as a single array argument
// (use_array_input = true), with no parameters.
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT( \
    OpSchemaDef, State, Name)                                       \
  CAFFE_REGISTER_CLASS(                                             \
      C10OperatorRegistry,                                          \
      Name,                                                         \
      C10OperatorWrapper<                                           \
          OpSchemaDef,                                              \
          CPUContext,                                               \
          State,                                                    \
          /*use_array_input=*/true,                                 \
          /*ParameterDefTuple=*/std::tuple<>>)
319
+
320
// Array-input registration (use_array_input = true) whose trailing variadic
// arguments are the parameter definition types, collected into a std::tuple.
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_ARRAY_INPUT_AND_PARAMETERS( \
    OpSchemaDef, State, Name, ...)                                                 \
  CAFFE_REGISTER_CLASS(                                                            \
      C10OperatorRegistry,                                                         \
      Name,                                                                        \
      C10OperatorWrapper<                                                          \
          OpSchemaDef,                                                             \
          CPUContext,                                                              \
          State,                                                                   \
          /*use_array_input=*/true,                                                \
          /*ParameterDefTuple=*/std::tuple<__VA_ARGS__>>)
53
331
54
332
} // namespace caffe2
0 commit comments