Skip to content

Commit 1f67a49

Browse files
houseroadRob Kunkle
authored and
Rob Kunkle
committed
Add support for ArgMax and ArgMin in C2 onnx backend and frontend (pytorch#9050)
Summary: Pass the end to end test cases in onnx/onnx#1049 Closes pytorch#9050 Reviewed By: hlu1 Differential Revision: D8703757 Pulled By: houseroad fbshipit-source-id: 63308202e349dfc02d532e87f49495ba1aab085b
1 parent 4fcfae9 commit 1f67a49

File tree

4 files changed

+40
-0
lines changed

4 files changed

+40
-0
lines changed

caffe2/onnx/backend.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,8 @@ Caffe2Backend::get_special_operators() const {
335335
const static std::
336336
unordered_map<std::string, Caffe2Backend::SpecialOpConverter>
337337
kSpecialOperators = {
338+
{"ArgMax", &Caffe2Backend::CreateArgMaxMin},
339+
{"ArgMin", &Caffe2Backend::CreateArgMaxMin},
338340
{"Cast", &Caffe2Backend::CreateCast},
339341
{"Constant", &Caffe2Backend::CreateConstant},
340342
{"Conv", &Caffe2Backend::CreateConvPoolOpBase},
@@ -363,6 +365,17 @@ Caffe2Backend::get_special_operators() const {
363365
// Special Operator Converters
364366
//============================
365367

368+
Caffe2Ops Caffe2Backend::CreateArgMaxMin(
369+
OnnxNode* onnx_node,
370+
int opset_version) {
371+
auto& attributes = onnx_node->attributes;
372+
if (!attributes.HasAttribute("axis")) {
373+
auto* attr = attributes.AddRewrittenAttribute("axis");
374+
attr->set_i(0);
375+
}
376+
return CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version);
377+
}
378+
366379
Caffe2Ops Caffe2Backend::CreateCast(OnnxNode* onnx_node, int opset_version) {
367380
auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version);
368381

caffe2/onnx/backend.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ class Caffe2Backend {
160160

161161
Caffe2Ops CommonOnnxNodeToCaffe2Ops(OnnxNode* onnx_node, int opset_version);
162162

163+
Caffe2Ops CreateArgMaxMin(OnnxNode* onnx_node, int opset_version);
164+
163165
Caffe2Ops CreateCast(OnnxNode* onnx_node, int opset_version);
164166

165167
Caffe2Ops CreateConstant(OnnxNode* onnx_node, int opset_version);

caffe2/onnx/onnx_exporter.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
222222
OnnxExporter::get_special_operators() const {
223223
const static std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>
224224
kSpecialOperators = {
225+
{"ArgMax", &OnnxExporter::CreateArgMaxMinOpNodes},
226+
{"ArgMin", &OnnxExporter::CreateArgMaxMinOpNodes},
225227
{"Add", &OnnxExporter::CreateBinaryElementwiseOpNodes},
226228
{"Sub", &OnnxExporter::CreateBinaryElementwiseOpNodes},
227229
{"Mul", &OnnxExporter::CreateBinaryElementwiseOpNodes},
@@ -351,6 +353,25 @@ ConvertedResult OnnxExporter::CommonCaffe2OpToOnnxNodes(
351353
return result;
352354
}
353355

356+
ConvertedResult OnnxExporter::CreateArgMaxMinOpNodes(
357+
const caffe2::OperatorDef& def,
358+
const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
359+
auto result = CommonCaffe2OpToOnnxNodes(def);
360+
auto& nodes = result.first;
361+
362+
CAFFE_ENFORCE_EQ(nodes.size(), 1);
363+
auto& node = nodes.back();
364+
365+
if (!ArgumentHelper::HasArgument(def, "axis")) {
366+
const auto& x = def.input(0);
367+
const auto& x_shape = shapes.at(x);
368+
node.add_attribute()->CopyFrom(
369+
MakeAttribute("axis", x_shape.dims().size() - 1));
370+
}
371+
372+
return result;
373+
}
374+
354375
ConvertedResult OnnxExporter::CreateBinaryElementwiseOpNodes(
355376
const caffe2::OperatorDef& def,
356377
const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {

caffe2/onnx/onnx_exporter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ class OnnxExporter {
5252
private:
5353
ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);
5454

55+
ConvertedResult CreateArgMaxMinOpNodes(
56+
const caffe2::OperatorDef& def,
57+
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
58+
5559
ConvertedResult CreateBinaryElementwiseOpNodes(
5660
const caffe2::OperatorDef& def,
5761
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);

0 commit comments

Comments
 (0)