Skip to content

Commit 5880a66

Browse files
kshitij12345 authored and pytorchmergebot committed
[composite compliance] matrix_exp (pytorch#81225)
Ref: pytorch#69991 Pull Request resolved: pytorch#81225 Approved by: https://github.com/zou3519
1 parent 83c6113 commit 5880a66

File tree

4 files changed

+17
-10
lines changed

4 files changed

+17
-10
lines changed

functorch/test/test_ops.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,6 @@ def vjp_of_vjp(*args_and_cotangents):
556556
xfail('eig'), # calls aten::item
557557
xfail('linalg.eig'), # Uses aten::allclose
558558
xfail('linalg.householder_product'), # needs select_scatter
559-
xfail('matrix_exp'), # would benefit from narrow_scatter
560559
xfail('nanquantile'), # checks q via a .item() call
561560
xfail('nn.functional.gaussian_nll_loss'), # checks var for if any value < 0
562561
xfail('prod'), # calls nonzero
@@ -635,7 +634,6 @@ def test_vmapvjp(self, device, dtype, op):
635634
xfail('as_strided'),
636635
xfail('nn.functional.gaussian_nll_loss'),
637636
xfail('scatter'),
638-
xfail('matrix_exp'),
639637
xfail('nanquantile'),
640638
xfail('view_as_complex'),
641639
xfail('prod'),
@@ -713,6 +711,7 @@ def test_vmapjvpall(self, device, dtype, op):
713711
xfail('linalg.eig'),
714712
xfail('complex'),
715713
xfail('linalg.pinv', 'hermitian'),
714+
xfail('matrix_exp'),
716715
xfail('pinverse'),
717716
skip('_masked.mean'), # ???
718717
xfail('linalg.cholesky_ex'),
@@ -1281,7 +1280,6 @@ def fn(input, weight, bias):
12811280
@skipOps('TestOperators', 'test_vmap_autograd_grad', {
12821281
# call inplace functions
12831282
xfail('linalg.householder_product'), # inplace
1284-
xfail('matrix_exp'), # inplace
12851283
xfail('take'), # inplace
12861284
12871285
xfail('linalg.eig'), # all close?

functorch/test/test_pythonkey.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,6 @@ class TestEagerFusionOpInfo(TestCase):
322322
xfail('diag_embed'),
323323
xfail('linalg.householder_product'),
324324
xfail('logit'),
325-
xfail('matrix_exp'),
326325
xfail('trapezoid'),
327326
xfail('trapz'),
328327
xfail('corrcoef'),

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3979,10 +3979,22 @@ Tensor differential_analytic_matrix_function(
39793979
meta_grad_sizes[A.dim() - 1] *= 2;
39803980

39813981
auto n = A.size(-1);
3982-
auto meta_grad = at::zeros(meta_grad_sizes, grad.options());
3983-
meta_grad.narrow(-2, 0, n).narrow(-1, 0, n).copy_(A);
3984-
meta_grad.narrow(-2, n, n).narrow(-1, n, n).copy_(A);
3985-
meta_grad.narrow(-2, 0, n).narrow(-1, n, n).copy_(grad);
3982+
Tensor meta_grad;
3983+
// For Composite Compliance, we can't copy a Subclass into a Regular Tensor,
3984+
// so we use out-of-place ops with equivalent output.
3985+
// NOTE: We can't use `new_zeros` directly as both `A` and `grad` can
3986+
// be Tensor Subclass and we don't want to make assumption about which
3987+
// one to choose for creating output buffer.
3988+
// eg. if both are BatchedTensor at different level.
3989+
if (areAnyTensorSubclassLike({A, grad})) {
3990+
meta_grad = at::cat(
3991+
{at::cat({A, grad}, -1), at::cat({at::zeros_like(A), A}, -1)}, -2);
3992+
} else {
3993+
meta_grad = at::zeros(meta_grad_sizes, grad.options());
3994+
meta_grad.narrow(-2, 0, n).narrow(-1, 0, n).copy_(A);
3995+
meta_grad.narrow(-2, n, n).narrow(-1, n, n).copy_(A);
3996+
meta_grad.narrow(-2, 0, n).narrow(-1, n, n).copy_(grad);
3997+
}
39863998

39873999
return matrix_function(meta_grad).narrow(-2, 0, n).narrow(-1, n, n);
39884000
}

torch/testing/_internal/common_methods_invocations.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13198,8 +13198,6 @@ def error_inputs_mean(op_info, device, **kwargs):
1319813198
# https://github.com/pytorch/pytorch/issues/66357
1319913199
check_batched_forward_grad=False,
1320013200
skips=(
13201-
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
13202-
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
1320313201
# times out
1320413202
DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
1320513203
),

0 commit comments

Comments (0)