
Commit 0f7c652

[batch-rule] householder_product
Original code written by @kshitij12345 in #322; this PR is just a rebase onto main.
1 parent 0087934 commit 0f7c652
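
With this rule in place, vmap can batch torch.linalg.householder_product in a single call instead of falling back to a per-sample loop. A minimal sketch of the intended behavior (shapes chosen here purely for illustration):

import torch
from functorch import vmap

# Batch of 3 problems: input is (B, m, n), tau is (B, k) with k <= n.
A = torch.randn(3, 5, 5)
tau = torch.randn(3, 5)

# One vmapped call over the whole batch...
out = vmap(torch.linalg.householder_product)(A, tau)

# ...should match applying the op sample by sample.
expected = torch.stack([torch.linalg.householder_product(a, t)
                        for a, t in zip(A, tau)])
assert torch.allclose(out, expected)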

5 files changed: 35 additions, 1 deletion

codegen/gen_vmap_plumbing.py

Lines changed: 1 addition & 0 deletions
@@ -274,6 +274,7 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> str:
         'gcd',
         'igamma',
         'igammac',
+        'linalg_householder_product',
         'logaddexp',
         'logaddexp2',
         'lcm',

functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   OP_DECOMPOSE2(vsplit, int);
   OP_DECOMPOSE2(vsplit, array);
   OP_DECOMPOSE(vstack);
+  OP_DECOMPOSE(orgqr);
   OP_DECOMPOSE2(unflatten, int);
   OP_DECOMPOSE(_convolution_double_backward);
   OP_DECOMPOSE(conv_transpose1d);
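
torch.orgqr is the older LAPACK-style alias of torch.linalg.householder_product, so it is registered as a decomposition rather than given its own rule: vmap(torch.orgqr) re-dispatches through the alias and lands in the new batch rule. A quick consistency check one would expect to hold (illustrative shapes):

import torch
from functorch import vmap

A, tau = torch.randn(2, 4, 4), torch.randn(2, 4)
# The alias and the canonical op should agree under vmap.
assert torch.allclose(vmap(torch.orgqr)(A, tau),
                      vmap(torch.linalg.householder_product)(A, tau))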

functorch/csrc/BatchRulesLinearAlgebra.cpp

Lines changed: 15 additions & 0 deletions
@@ -162,6 +162,20 @@ void _linalg_check_errors_batch_rule(const Tensor& info, optional<int64_t> info_
   at::_linalg_check_errors(info_, api_name, false);
 }

+std::tuple<Tensor, c10::optional<int64_t>>
+householder_product_batch_rule(const Tensor &input, c10::optional<int64_t> input_bdim,
+                               const Tensor &tau, c10::optional<int64_t> tau_bdim)
+{
+  auto input_ = moveBatchDimToFront(input, input_bdim);
+  auto tau_ = moveBatchDimToFront(tau, tau_bdim);
+
+  auto batch_size = get_bdim_size2(input, input_bdim, tau, tau_bdim);
+
+  input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);
+  tau_ = ensure_has_bdim(tau_, tau_bdim.has_value(), batch_size);
+  return std::make_tuple(at::linalg_householder_product(input_, tau_), 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT(bmm, bmm_batch_rule);
   m.impl("addmv", addmv_decomp);

@@ -172,6 +186,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT(mv, mv_batch_rule);
   VMAP_SUPPORT(mm, mm_batch_rule);
   m.impl("linear", linear_decomp);
+  VMAP_SUPPORT(linalg_householder_product, householder_product_batch_rule);

   VMAP_SUPPORT(_linalg_check_errors, _linalg_check_errors_batch_rule);
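
The rule follows the usual functorch pattern: move any existing batch dimension to the front of each operand, take the batch size from whichever operand is batched, give the other operand a broadcast batch dimension, and make one call to the op, which already accepts leading batch dims; the returned batch dim is 0. A Python analogue of the same logic (the helper below only mirrors moveBatchDimToFront / get_bdim_size2 / ensure_has_bdim for illustration; it is not a public API):

import torch

def move_bdim_to_front(t, bdim):
    # No-op when the tensor carries no batch dim at this level.
    return t if bdim is None else t.movedim(bdim, 0)

def householder_product_batch_rule(input, input_bdim, tau, tau_bdim):
    input_ = move_bdim_to_front(input, input_bdim)
    tau_ = move_bdim_to_front(tau, tau_bdim)

    # The batch size comes from whichever operand actually has a batch dim.
    batch_size = input_.shape[0] if input_bdim is not None else tau_.shape[0]

    # Broadcast a batch dim onto the operand that lacks one.
    if input_bdim is None:
        input_ = input_.expand(batch_size, *input_.shape)
    if tau_bdim is None:
        tau_ = tau_.expand(batch_size, *tau_.shape)

    # linalg_householder_product natively handles leading batch dims,
    # so a single call covers the whole batch; the output bdim is 0.
    return torch.linalg.householder_product(input_, tau_), 0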
functorch/csrc/VmapGeneratedPlumbing.h

Lines changed: 18 additions & 0 deletions
@@ -5431,6 +5431,24 @@ at::Tensor logdet_generated_plumbing(const at::Tensor & self) {
   return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
 }
 template <typename batch_rule_t, batch_rule_t batch_rule>
+at::Tensor linalg_householder_product_generated_plumbing(const at::Tensor & input, const at::Tensor & tau) {
+  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
+  if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(tau, cur_level)) {
+    return at::_ops::linalg_householder_product::call(input, tau);
+  }
+  Tensor input_value;
+  optional<int64_t> input_bdim;
+  std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level);
+  Tensor tau_value;
+  optional<int64_t> tau_bdim;
+  std::tie(tau_value, tau_bdim) = unwrapTensorAtLevel(tau, cur_level);
+  auto results = batch_rule(input_value, input_bdim, tau_value, tau_bdim);
+  return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
+}
+template <typename batch_rule_t, batch_rule_t batch_rule>
 at::Tensor linalg_pinv_generated_plumbing(const at::Tensor & self, double rcond, bool hermitian) {
   c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
   auto maybe_layer = maybeCurrentDynamicLayer();
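
The generated plumbing is level bookkeeping: it fast-paths to the plain op when neither argument is batched at the current vmap level, otherwise unwraps each argument into a (value, bdim) pair, runs the batch rule, and rewraps the result at that level. Because each layer unwraps only its own batch dim, nesting should compose; a small sketch, assuming the rule above is registered:

import torch
from functorch import vmap

A = torch.randn(2, 3, 4, 4)   # two batch dims, handled by two vmap levels
tau = torch.randn(2, 3, 4)
out = vmap(vmap(torch.linalg.householder_product))(A, tau)
print(out.shape)  # torch.Size([2, 3, 4, 4])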

test/test_vmap.py

Lines changed: 0 additions & 1 deletion
@@ -3199,7 +3199,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('linalg.cholesky'),
         xfail('linalg.eigvals'),
         xfail('linalg.eigvalsh'),
-        xfail('linalg.householder_product'),
         xfail('linalg.inv'),
         xfail('linalg.lstsq'),
         xfail('linalg.lstsq', 'grad_oriented'),
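
With the batch rule registered, linalg.householder_product passes test_vmap_exhaustive, so its xfail entry is dropped from the expected-failure list.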
