diff --git a/functorch/csrc/BatchRulesBinaryOps.cpp b/functorch/csrc/BatchRulesBinaryOps.cpp index 5de8b5285..3f4894944 100644 --- a/functorch/csrc/BatchRulesBinaryOps.cpp +++ b/functorch/csrc/BatchRulesBinaryOps.cpp @@ -155,6 +155,14 @@ std::tuple> masked_select_batch_rule( return std::make_tuple(result, 0); } +Tensor addr_decomposition( + const Tensor& self, const Tensor& vec1, const Tensor& vec2, + const Scalar& beta, const Scalar& alpha) { + + auto outer = alpha * vec1.unsqueeze(-1) * vec2.unsqueeze(-2); + return self * beta + outer; +} + TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { #define BINARY_POINTWISE2(op, overload) \ VMAP_SUPPORT(#op"."#overload, BINARY_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload))); @@ -193,6 +201,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { BINARY_SCALAR_2(add, Tensor, Scalar); POINTWISE_BOXED(addcdiv); POINTWISE_BOXED(addcmul); + m.impl("addr", addr_decomposition); BINARY_POINTWISE(atan2); BINARY_SCALAR_2(bitwise_and, Tensor, Scalar); BINARY_POINTWISE2(bitwise_or, Tensor); diff --git a/test/test_ops.py b/test/test_ops.py index 29cfdd9ad..54b74c093 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -570,7 +570,6 @@ def test_vmapjvp(self, device, dtype, op): @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({ xfail('view_as_complex'), xfail('__getitem__'), - xfail('addr'), xfail('cdist'), xfail('cholesky'), xfail('clamp'), diff --git a/test/test_vmap.py b/test/test_vmap.py index db97de351..2603e69d9 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3099,7 +3099,6 @@ def test_vmap_exhaustive(self, device, dtype, op): @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) @skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({ - xfail('addr'), xfail('cdist'), xfail('complex'), xfail('copysign'),