diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index d544e47367..97acc59003 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -127,8 +127,10 @@ def expand_dims(X, axis): of size one at the position specified by axis. Args: - x (usm_ndarray): input array - axis (int): axis position (zero-based). If `x` has rank + x (usm_ndarray): + input array + axis (Union[int, Tuple[int]]): + axis position in the expanded axes (zero-based). If `x` has rank (i.e, number of dimensions) `N`, a valid `axis` must reside in the closed-interval `[-N-1, N]`. If provided a negative `axis`, the `axis` position at which to insert a singleton