diff --git a/functorch/csrc/BatchRulesDecompositions.cpp b/functorch/csrc/BatchRulesDecompositions.cpp index 5bb52ede3..26cb7f950 100644 --- a/functorch/csrc/BatchRulesDecompositions.cpp +++ b/functorch/csrc/BatchRulesDecompositions.cpp @@ -307,6 +307,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { OP_DECOMPOSE(frobenius_norm); OP_DECOMPOSE(type_as); OP_DECOMPOSE(linalg_diagonal); + OP_DECOMPOSE(orgqr); DECOMPOSE_FUNCTIONAL(diag_embed); DECOMPOSE_FUNCTIONAL(block_diag); diff --git a/functorch/csrc/BatchRulesLinearAlgebra.cpp b/functorch/csrc/BatchRulesLinearAlgebra.cpp index 7720b8259..495740f11 100644 --- a/functorch/csrc/BatchRulesLinearAlgebra.cpp +++ b/functorch/csrc/BatchRulesLinearAlgebra.cpp @@ -151,6 +151,20 @@ Tensor linear_decomp( return result; } +std::tuple> +householder_product_batch_rule(const Tensor &input, c10::optional input_bdim, + const Tensor &tau, c10::optional tau_bdim) +{ + auto input_ = moveBatchDimToFront(input, input_bdim); + auto tau_ = moveBatchDimToFront(tau, tau_bdim); + + auto batch_size = get_bdim_size2(input, input_bdim, tau, tau_bdim); + + input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size); + tau_ = ensure_has_bdim(tau_, tau_bdim.has_value(), batch_size); + return std::make_tuple(at::linalg_householder_product(input_, tau_), 0); +} + Tensor addmm_decomp(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) { // Decomposition that is probably not very fast... return at::add(self * beta, at::mm(mat1, mat2), alpha); @@ -165,6 +179,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { VMAP_SUPPORT(dot, dot_batch_rule); VMAP_SUPPORT(mv, mv_batch_rule); VMAP_SUPPORT(mm, mm_batch_rule); + VMAP_SUPPORT(linalg_householder_product, + householder_product_batch_rule); m.impl("linear", linear_decomp); VARIADIC_BDIMS_BOXED(cholesky_solve); diff --git a/test/test_vmap.py b/test/test_vmap.py index 967bcda70..afd42cc8b 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3158,7 +3158,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('linalg.cholesky'), xfail('linalg.eigvals'), xfail('linalg.eigvalsh'), - xfail('linalg.householder_product'), xfail('linalg.inv'), xfail('linalg.lstsq'), xfail('linalg.matrix_norm'),