From 5cf5d8f404b18ff67543762ed8e92cb0f359f885 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 11 May 2025 21:58:35 +0200 Subject: [PATCH] BUG: torch: meshgrid defaults to indexing="xy" As of version 2.6, torch defaults to indexing='ij', and is planning to transition to 'xy' at some point. When it does, we'll be able to drop our wrapper. --- array_api_compat/torch/_aliases.py | 10 ++++++++-- tests/test_torch.py | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 335008e4..de5d1a5d 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,7 +2,7 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union, Literal import torch @@ -828,6 +828,12 @@ def sign(x: Array, /) -> Array: return out +def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]: + # enforce the default of 'xy' + # TODO: is the return type a list or a tuple + return list(torch.meshgrid(*arrays, indexing='xy')) + + __all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', @@ -844,6 +850,6 @@ def sign(x: Array, /) -> Array: 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat'] + 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid'] _all_ignore = ['torch', 'get_xp'] diff --git a/tests/test_torch.py b/tests/test_torch.py index e8340f31..7adb4ab3 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -100,3 +100,20 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b): assert dtype_1 == dtype_2 finally: torch.set_default_dtype(prev_default) + + +def test_meshgrid(): + """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'.""" + + x, y = xp.asarray([1, 2]), xp.asarray([4]) + + X, Y = xp.meshgrid(x, y) + + # output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different + X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]]) + + assert X.shape == X_xy.shape + assert xp.all(X == X_xy) + + assert Y.shape == Y_xy.shape + assert xp.all(Y == Y_xy)