Skip to content

Commit 83c5a5a

Browse files
authored
Enable aten::linalg_inv and aten::linalg_inv_ex (#1940)
This PR introduces the `inverse`, `linalg_inv` and `linalg_inv_ex` operators for XPU. Added skip entries for error and warning tests related to inverse operations for all dtypes, due to a temporary regression in oneMKL.
1 parent a6790a9 commit 83c5a5a

File tree

3 files changed

+42
-26
lines changed

3 files changed

+42
-26
lines changed

src/ATen/native/xpu/XPUFallback.template

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
201201
"_linalg_eigh.eigenvalues",
202202
"linalg_householder_product",
203203
"linalg_householder_product.out",
204-
"linalg_inv_ex.inverse",
205204
"linalg_ldl_factor_ex.out",
206205
"linalg_ldl_solve.out",
207206
"linalg_lstsq.out",

test/xpu/skip_list_common.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -551,8 +551,6 @@
551551
"test_neg_view_linalg_eigh_xpu_float64",
552552
"test_neg_view_linalg_eigvalsh_xpu_float64",
553553
"test_neg_view_linalg_householder_product_xpu_float64",
554-
"test_neg_view_linalg_inv_ex_xpu_float64",
555-
"test_neg_view_linalg_inv_xpu_float64",
556554
"test_neg_view_linalg_ldl_factor_ex_xpu_float64",
557555
"test_neg_view_linalg_ldl_factor_xpu_float64",
558556
"test_neg_view_linalg_ldl_solve_xpu_float64",
@@ -1279,10 +1277,8 @@
12791277
"test_invariance_error_spectral_decompositions_xpu_complex128",
12801278
"test_inverse_many_batches_xpu_complex128",
12811279
"test_inverse_many_batches_xpu_complex64",
1282-
"test_inverse_many_batches_xpu_float64",
12831280
"test_inverse_xpu_complex128",
12841281
"test_inverse_xpu_complex64",
1285-
"test_inverse_xpu_float64",
12861282
"test_ldl_factor_xpu_complex128",
12871283
"test_ldl_factor_xpu_complex64",
12881284
"test_ldl_factor_xpu_float64",
@@ -1471,6 +1467,27 @@
14711467
"test_scaled_gemm_offline_tunableop_xpu_float8_e5m2fnuz",
14721468
# case need to port for xpu
14731469
"test_gemm_bias_offline_tunableop_xpu_bfloat16",
1470+
# Exception is temporarily unavailable due to regression in oneMKL
1471+
"test_inv_errors_and_warnings_xpu_complex128",
1472+
"test_inv_errors_and_warnings_xpu_complex64",
1473+
"test_inv_errors_and_warnings_xpu_float32",
1474+
"test_inv_errors_and_warnings_xpu_float64",
1475+
"test_inverse_errors_large_xpu_complex128",
1476+
"test_inverse_errors_large_xpu_complex64",
1477+
"test_inverse_errors_large_xpu_float32",
1478+
"test_inverse_errors_large_xpu_float64",
1479+
"test_inverse_errors_xpu_complex128",
1480+
"test_inverse_errors_xpu_complex64",
1481+
"test_inverse_errors_xpu_float32",
1482+
"test_inverse_errors_xpu_float64",
1483+
"test_inv_ex_singular_xpu_complex128",
1484+
"test_inv_ex_singular_xpu_complex64",
1485+
"test_inv_ex_singular_xpu_float32",
1486+
"test_inv_ex_singular_xpu_float64",
1487+
"test_tensorinv_singular_input_xpu_complex128",
1488+
"test_tensorinv_singular_input_xpu_complex64",
1489+
"test_tensorinv_singular_input_xpu_float32",
1490+
"test_tensorinv_singular_input_xpu_float64",
14741491
),
14751492
"test_ops_fwd_gradients_xpu.py": (
14761493
# All of the followings are oneDNN issues
@@ -1520,9 +1537,7 @@
15201537
"test_fn_fwgrad_bwgrad_linalg_householder_product_xpu_complex128",
15211538
"test_fn_fwgrad_bwgrad_linalg_householder_product_xpu_float64",
15221539
"test_fn_fwgrad_bwgrad_linalg_inv_ex_xpu_complex128",
1523-
"test_fn_fwgrad_bwgrad_linalg_inv_ex_xpu_float64",
15241540
"test_fn_fwgrad_bwgrad_linalg_inv_xpu_complex128",
1525-
"test_fn_fwgrad_bwgrad_linalg_inv_xpu_float64",
15261541
"test_fn_fwgrad_bwgrad_linalg_lstsq_grad_oriented_xpu_complex128",
15271542
"test_fn_fwgrad_bwgrad_linalg_lstsq_grad_oriented_xpu_float64",
15281543
"test_fn_fwgrad_bwgrad_linalg_lu_factor_ex_xpu_complex128",
@@ -1554,7 +1569,6 @@
15541569
"test_fn_fwgrad_bwgrad_linalg_svdvals_xpu_complex128",
15551570
"test_fn_fwgrad_bwgrad_linalg_svdvals_xpu_float64",
15561571
"test_fn_fwgrad_bwgrad_linalg_tensorinv_xpu_complex128",
1557-
"test_fn_fwgrad_bwgrad_linalg_tensorinv_xpu_float64",
15581572
"test_fn_fwgrad_bwgrad_linalg_tensorsolve_xpu_complex128",
15591573
"test_fn_fwgrad_bwgrad_linalg_tensorsolve_xpu_float64",
15601574
"test_fn_fwgrad_bwgrad_logdet_xpu_complex128",
@@ -1632,9 +1646,7 @@
16321646
"test_forward_mode_AD_linalg_householder_product_xpu_complex128",
16331647
"test_forward_mode_AD_linalg_householder_product_xpu_float64",
16341648
"test_forward_mode_AD_linalg_inv_ex_xpu_complex128",
1635-
"test_forward_mode_AD_linalg_inv_ex_xpu_float64",
16361649
"test_forward_mode_AD_linalg_inv_xpu_complex128",
1637-
"test_forward_mode_AD_linalg_inv_xpu_float64",
16381650
"test_forward_mode_AD_linalg_lstsq_grad_oriented_xpu_complex128",
16391651
"test_forward_mode_AD_linalg_lstsq_grad_oriented_xpu_float64",
16401652
"test_forward_mode_AD_linalg_lu_factor_ex_xpu_complex128",
@@ -1667,7 +1679,6 @@
16671679
"test_forward_mode_AD_linalg_svdvals_xpu_complex128",
16681680
"test_forward_mode_AD_linalg_svdvals_xpu_float64",
16691681
"test_forward_mode_AD_linalg_tensorinv_xpu_complex128",
1670-
"test_forward_mode_AD_linalg_tensorinv_xpu_float64",
16711682
"test_forward_mode_AD_linalg_tensorsolve_xpu_complex128",
16721683
"test_forward_mode_AD_linalg_tensorsolve_xpu_float64",
16731684
"test_forward_mode_AD_logdet_xpu_complex128",
@@ -1858,9 +1869,7 @@
18581869
"test_fn_grad_linalg_householder_product_xpu_complex128",
18591870
"test_fn_grad_linalg_householder_product_xpu_float64",
18601871
"test_fn_grad_linalg_inv_ex_xpu_complex128",
1861-
"test_fn_grad_linalg_inv_ex_xpu_float64",
18621872
"test_fn_grad_linalg_inv_xpu_complex128",
1863-
"test_fn_grad_linalg_inv_xpu_float64",
18641873
"test_fn_grad_linalg_lstsq_grad_oriented_xpu_complex128",
18651874
"test_fn_grad_linalg_lstsq_grad_oriented_xpu_float64",
18661875
"test_fn_grad_linalg_lu_factor_ex_xpu_complex128",
@@ -1893,7 +1902,6 @@
18931902
"test_fn_grad_linalg_svdvals_xpu_complex128",
18941903
"test_fn_grad_linalg_svdvals_xpu_float64",
18951904
"test_fn_grad_linalg_tensorinv_xpu_complex128",
1896-
"test_fn_grad_linalg_tensorinv_xpu_float64",
18971905
"test_fn_grad_linalg_tensorsolve_xpu_complex128",
18981906
"test_fn_grad_linalg_tensorsolve_xpu_float64",
18991907
"test_fn_grad_logdet_xpu_complex128",
@@ -1973,9 +1981,7 @@
19731981
"test_fn_gradgrad_linalg_householder_product_xpu_complex128",
19741982
"test_fn_gradgrad_linalg_householder_product_xpu_float64",
19751983
"test_fn_gradgrad_linalg_inv_ex_xpu_complex128",
1976-
"test_fn_gradgrad_linalg_inv_ex_xpu_float64",
19771984
"test_fn_gradgrad_linalg_inv_xpu_complex128",
1978-
"test_fn_gradgrad_linalg_inv_xpu_float64",
19791985
"test_fn_gradgrad_linalg_lstsq_grad_oriented_xpu_complex128",
19801986
"test_fn_gradgrad_linalg_lstsq_grad_oriented_xpu_float64",
19811987
"test_fn_gradgrad_linalg_lu_factor_ex_xpu_complex128",
@@ -2006,7 +2012,6 @@
20062012
"test_fn_gradgrad_linalg_svdvals_xpu_complex128",
20072013
"test_fn_gradgrad_linalg_svdvals_xpu_float64",
20082014
"test_fn_gradgrad_linalg_tensorinv_xpu_complex128",
2009-
"test_fn_gradgrad_linalg_tensorinv_xpu_float64",
20102015
"test_fn_gradgrad_linalg_tensorsolve_xpu_complex128",
20112016
"test_fn_gradgrad_linalg_tensorsolve_xpu_float64",
20122017
"test_fn_gradgrad_logdet_xpu_complex128",
@@ -2334,9 +2339,7 @@
23342339
"test_dispatch_meta_outplace_linalg_eigvalsh_xpu_complex",
23352340
"test_dispatch_meta_outplace_linalg_eigvalsh_xpu_float64",
23362341
"test_dispatch_meta_outplace_linalg_inv_ex_xpu_complex",
2337-
"test_dispatch_meta_outplace_linalg_inv_ex_xpu_float64",
23382342
"test_dispatch_meta_outplace_linalg_inv_xpu_complex",
2339-
"test_dispatch_meta_outplace_linalg_inv_xpu_float64",
23402343
"test_dispatch_meta_outplace_linalg_ldl_factor_ex_xpu_complex",
23412344
"test_dispatch_meta_outplace_linalg_ldl_factor_ex_xpu_float64",
23422345
"test_dispatch_meta_outplace_linalg_ldl_factor_xpu_complex",
@@ -2371,7 +2374,6 @@
23712374
"test_dispatch_meta_outplace_linalg_svd_xpu_complex",
23722375
"test_dispatch_meta_outplace_linalg_svd_xpu_float64",
23732376
"test_dispatch_meta_outplace_linalg_tensorinv_xpu_complex",
2374-
"test_dispatch_meta_outplace_linalg_tensorinv_xpu_float64",
23752377
"test_dispatch_meta_outplace_logdet_xpu_complex",
23762378
"test_dispatch_meta_outplace_logdet_xpu_float64",
23772379
"test_dispatch_meta_outplace_lu_solve_xpu_complex",
@@ -2454,9 +2456,7 @@
24542456
"test_dispatch_symbolic_meta_outplace_linalg_eigvalsh_xpu_complex",
24552457
"test_dispatch_symbolic_meta_outplace_linalg_eigvalsh_xpu_float64",
24562458
"test_dispatch_symbolic_meta_outplace_linalg_inv_ex_xpu_complex",
2457-
"test_dispatch_symbolic_meta_outplace_linalg_inv_ex_xpu_float64",
24582459
"test_dispatch_symbolic_meta_outplace_linalg_inv_xpu_complex",
2459-
"test_dispatch_symbolic_meta_outplace_linalg_inv_xpu_float64",
24602460
"test_dispatch_symbolic_meta_outplace_linalg_ldl_factor_ex_xpu_complex",
24612461
"test_dispatch_symbolic_meta_outplace_linalg_ldl_factor_ex_xpu_float64",
24622462
"test_dispatch_symbolic_meta_outplace_linalg_ldl_factor_xpu_complex",
@@ -2491,7 +2491,6 @@
24912491
"test_dispatch_symbolic_meta_outplace_linalg_svd_xpu_complex",
24922492
"test_dispatch_symbolic_meta_outplace_linalg_svd_xpu_float64",
24932493
"test_dispatch_symbolic_meta_outplace_linalg_tensorinv_xpu_complex",
2494-
"test_dispatch_symbolic_meta_outplace_linalg_tensorinv_xpu_float64",
24952494
"test_dispatch_symbolic_meta_outplace_logdet_xpu_complex",
24962495
"test_dispatch_symbolic_meta_outplace_logdet_xpu_float64",
24972496
"test_dispatch_symbolic_meta_outplace_lu_solve_xpu_complex",
@@ -2574,9 +2573,7 @@
25742573
"test_meta_outplace_linalg_eigvalsh_xpu_complex",
25752574
"test_meta_outplace_linalg_eigvalsh_xpu_float64",
25762575
"test_meta_outplace_linalg_inv_ex_xpu_complex",
2577-
"test_meta_outplace_linalg_inv_ex_xpu_float64",
25782576
"test_meta_outplace_linalg_inv_xpu_complex",
2579-
"test_meta_outplace_linalg_inv_xpu_float64",
25802577
"test_meta_outplace_linalg_ldl_factor_ex_xpu_complex",
25812578
"test_meta_outplace_linalg_ldl_factor_ex_xpu_float64",
25822579
"test_meta_outplace_linalg_ldl_factor_xpu_complex",
@@ -2611,7 +2608,6 @@
26112608
"test_meta_outplace_linalg_svd_xpu_complex",
26122609
"test_meta_outplace_linalg_svd_xpu_float64",
26132610
"test_meta_outplace_linalg_tensorinv_xpu_complex",
2614-
"test_meta_outplace_linalg_tensorinv_xpu_float64",
26152611
"test_meta_outplace_logdet_xpu_complex",
26162612
"test_meta_outplace_logdet_xpu_float64",
26172613
"test_meta_outplace_lu_solve_xpu_complex",

yaml/native/native_functions.yaml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9433,3 +9433,24 @@
94339433

94349434
- func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor
94359435
python_module: linalg
9436+
9437+
- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
9438+
python_module: linalg
9439+
structured_delegate: linalg_inv_ex.inverse
9440+
9441+
- func: linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info)
9442+
python_module: linalg
9443+
structured: True
9444+
dispatch:
9445+
XPU: linalg_inv_ex_out
9446+
9447+
- func: linalg_inv(Tensor A) -> Tensor
9448+
python_module: linalg
9449+
9450+
- func: linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)
9451+
python_module: linalg
9452+
9453+
- func: inverse(Tensor self) -> Tensor
9454+
variants: function, method
9455+
9456+
- func: inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

0 commit comments

Comments
 (0)