Skip to content

Commit 97ae431

Browse files
BowenBaopytorchmergebot
authored andcommitted
[ONNX] Add symbolic support for torch.nn.cosinesimilarity (#72128) (#73283)
Summary: Pull Request resolved: #73283 * Add support for torch.nn.cosine_similarity * Remove fallback logic * Fix onnx test failures * Fix opset version * Modify rtol * Add aten fallback mode * fix mypy * gate with caffe2 fallback Test Plan: Imported from OSS Reviewed By: jbschlosser Differential Revision: D34625650 Pulled By: malfet fbshipit-source-id: bf15d32b1d7055d0ca166d9941ba90b5c8e81cc2 (cherry picked from commit 7086031)
1 parent 95b1232 commit 97ae431

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7271,6 +7271,12 @@ def forward(self, x):
72717271
for x in [torch.randn(3, 4), torch.randn(3, 4).to(dtype=torch.bool)]:
72727272
self.run_test(EinsumModelTranspose(), input=(x,))
72737273

7274+
@skipIfUnsupportedMinOpsetVersion(9)
7275+
def test_cosine_similarity(self):
7276+
x = torch.randn(5, 3, 2)
7277+
y = torch.randn(5, 3, 2)
7278+
self.run_test(torch.nn.CosineSimilarity(dim=2), input=(x, y))
7279+
72747280
@skipIfUnsupportedMinOpsetVersion(12)
72757281
def test_crossentropyloss(self):
72767282
for ignore_index in [-100, 1]:

test/onnx/test_utility_funs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,7 @@ def test_onnx_fallthrough(self):
957957
# Test aten export of op with symbolic for aten
958958
x = torch.randn(100, 128)
959959
y = torch.randn(100, 128)
960-
model = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
960+
model = torch.nn.PairwiseDistance(p=2, eps=1e-6)
961961

962962
graph, _, __ = self._model_to_graph(model, (x, y),
963963
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
@@ -966,7 +966,8 @@ def test_onnx_fallthrough(self):
966966
iter = graph.nodes()
967967
self.assertEqual(next(iter).kind(), "onnx::Constant")
968968
self.assertEqual(next(iter).kind(), "onnx::Constant")
969-
self.assertEqual(next(iter).kind(), "aten::cosine_similarity")
969+
self.assertEqual(next(iter).kind(), "onnx::Constant")
970+
self.assertEqual(next(iter).kind(), "aten::pairwise_distance")
970971

971972
# prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
972973
@skipIfUnsupportedMaxOpsetVersion(10)

torch/onnx/symbolic_opset9.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,10 +1548,18 @@ def type_as(g, self, other):
15481548

15491549
@parse_args("v", "v", "i", "f")
15501550
def cosine_similarity(g, x1, x2, dim, eps):
1551-
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
1551+
# preserve legacy behavior for Caffe2
1552+
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and \
1553+
torch.onnx._CAFFE2_ATEN_FALLBACK:
15521554
return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)
1553-
else:
1554-
return sym_help._onnx_unsupported("cosine_similarity")
1555+
cross = sym_help._reducesum_helper(g, mul(g, x1, x2),
1556+
axes_i=[dim], keepdims_i=0)
1557+
x1_l2 = sym_help._reducesum_helper(g, mul(g, x1, x1),
1558+
axes_i=[dim], keepdims_i=0)
1559+
x2_l2 = sym_help._reducesum_helper(g, mul(g, x2, x2),
1560+
axes_i=[dim], keepdims_i=0)
1561+
div_tens = max(g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps])))
1562+
return div(g, cross, div_tens)
15551563

15561564

15571565
# ignore clone operators that are inserted by PyTorch autograd

0 commit comments

Comments
 (0)