Skip to content

Commit 341d4da

Browse files
Merge pull request #1324 from IntelPython/enable-operators
Enable support for Python operators in usm_ndarray class
2 parents acd1a60 + ebd1faf commit 341d4da

File tree

15 files changed

+301
-264
lines changed

15 files changed

+301
-264
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ jobs:
486486
done
487487
488488
array-api-conformity:
489-
needs: test_linux
489+
needs: build_linux
490490
runs-on: ${{ matrix.runner }}
491491

492492
strategy:

dpctl/_sycl_queue.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ from ._sycl_event cimport SyclEvent
2929
from .program._program cimport SyclKernel
3030

3131

32-
cdef void default_async_error_handler(int) nogil except *
32+
cdef void default_async_error_handler(int) except * nogil
3333

3434
cdef public api class _SyclQueue [
3535
object Py_SyclQueueObject, type Py_SyclQueueType

dpctl/tensor/_copy_utils.py

Lines changed: 105 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
import builtins
1617
import operator
1718

1819
import numpy as np
@@ -289,6 +290,96 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
289290
_copy_same_shape(dst, src_same_shape)
290291

291292

293+
def _empty_like_orderK(X, dt, usm_type=None, dev=None):
294+
"""Returns empty array like `x`, using order='K'
295+
296+
For an array `x` that was obtained by permutation of a contiguous
297+
array the returned array will have the same shape and the same
298+
strides as `x`.
299+
"""
300+
if not isinstance(X, dpt.usm_ndarray):
301+
raise TypeError(f"Expected usm_ndarray, got {type(X)}")
302+
if usm_type is None:
303+
usm_type = X.usm_type
304+
if dev is None:
305+
dev = X.device
306+
fl = X.flags
307+
if fl["C"] or X.size <= 1:
308+
return dpt.empty_like(
309+
X, dtype=dt, usm_type=usm_type, device=dev, order="C"
310+
)
311+
elif fl["F"]:
312+
return dpt.empty_like(
313+
X, dtype=dt, usm_type=usm_type, device=dev, order="F"
314+
)
315+
st = list(X.strides)
316+
perm = sorted(
317+
range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True
318+
)
319+
inv_perm = sorted(range(X.ndim), key=lambda i: perm[i])
320+
st_sorted = [st[i] for i in perm]
321+
sh = X.shape
322+
sh_sorted = tuple(sh[i] for i in perm)
323+
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
324+
if min(st_sorted) < 0:
325+
sl = tuple(
326+
slice(None, None, -1)
327+
if st_sorted[i] < 0
328+
else slice(None, None, None)
329+
for i in range(X.ndim)
330+
)
331+
R = R[sl]
332+
return dpt.permute_dims(R, inv_perm)
333+
334+
335+
def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
336+
if not isinstance(X1, dpt.usm_ndarray):
337+
raise TypeError(f"Expected usm_ndarray, got {type(X1)}")
338+
if not isinstance(X2, dpt.usm_ndarray):
339+
raise TypeError(f"Expected usm_ndarray, got {type(X2)}")
340+
nd1 = X1.ndim
341+
nd2 = X2.ndim
342+
if nd1 > nd2 and X1.shape == res_shape:
343+
return _empty_like_orderK(X1, dt, usm_type, dev)
344+
elif nd1 < nd2 and X2.shape == res_shape:
345+
return _empty_like_orderK(X2, dt, usm_type, dev)
346+
fl1 = X1.flags
347+
fl2 = X2.flags
348+
if fl1["C"] or fl2["C"]:
349+
return dpt.empty(
350+
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
351+
)
352+
if fl1["F"] and fl2["F"]:
353+
return dpt.empty(
354+
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F"
355+
)
356+
st1 = list(X1.strides)
357+
st2 = list(X2.strides)
358+
max_ndim = max(nd1, nd2)
359+
st1 += [0] * (max_ndim - len(st1))
360+
st2 += [0] * (max_ndim - len(st2))
361+
perm = sorted(
362+
range(max_ndim),
363+
key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])),
364+
reverse=True,
365+
)
366+
inv_perm = sorted(range(max_ndim), key=lambda i: perm[i])
367+
st1_sorted = [st1[i] for i in perm]
368+
st2_sorted = [st2[i] for i in perm]
369+
sh = res_shape
370+
sh_sorted = tuple(sh[i] for i in perm)
371+
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
372+
if max(min(st1_sorted), min(st2_sorted)) < 0:
373+
sl = tuple(
374+
slice(None, None, -1)
375+
if (st1_sorted[i] < 0 and st2_sorted[i] < 0)
376+
else slice(None, None, None)
377+
for i in range(nd1)
378+
)
379+
R = R[sl]
380+
return dpt.permute_dims(R, inv_perm)
381+
382+
292383
def copy(usm_ary, order="K"):
293384
"""copy(ary, order="K")
294385
@@ -334,28 +425,15 @@ def copy(usm_ary, order="K"):
334425
"Unrecognized value of the order keyword. "
335426
"Recognized values are 'A', 'C', 'F', or 'K'"
336427
)
337-
c_contig = usm_ary.flags.c_contiguous
338-
f_contig = usm_ary.flags.f_contiguous
339-
R = dpt.usm_ndarray(
340-
usm_ary.shape,
341-
dtype=usm_ary.dtype,
342-
buffer=usm_ary.usm_type,
343-
order=copy_order,
344-
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
345-
)
346-
if order == "K" and (not c_contig and not f_contig):
347-
original_strides = usm_ary.strides
348-
ind = sorted(
349-
range(usm_ary.ndim),
350-
key=lambda i: abs(original_strides[i]),
351-
reverse=True,
352-
)
353-
new_strides = tuple(R.strides[ind[i]] for i in ind)
428+
if order == "K":
429+
R = _empty_like_orderK(usm_ary, usm_ary.dtype)
430+
else:
354431
R = dpt.usm_ndarray(
355432
usm_ary.shape,
356433
dtype=usm_ary.dtype,
357-
buffer=R.usm_data,
358-
strides=new_strides,
434+
buffer=usm_ary.usm_type,
435+
order=copy_order,
436+
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
359437
)
360438
_copy_same_shape(R, usm_ary)
361439
return R
@@ -432,26 +510,15 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
432510
"Unrecognized value of the order keyword. "
433511
"Recognized values are 'A', 'C', 'F', or 'K'"
434512
)
435-
R = dpt.usm_ndarray(
436-
usm_ary.shape,
437-
dtype=target_dtype,
438-
buffer=usm_ary.usm_type,
439-
order=copy_order,
440-
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
441-
)
442-
if order == "K" and (not c_contig and not f_contig):
443-
original_strides = usm_ary.strides
444-
ind = sorted(
445-
range(usm_ary.ndim),
446-
key=lambda i: abs(original_strides[i]),
447-
reverse=True,
448-
)
449-
new_strides = tuple(R.strides[ind[i]] for i in ind)
513+
if order == "K":
514+
R = _empty_like_orderK(usm_ary, target_dtype)
515+
else:
450516
R = dpt.usm_ndarray(
451517
usm_ary.shape,
452518
dtype=target_dtype,
453-
buffer=R.usm_data,
454-
strides=new_strides,
519+
buffer=usm_ary.usm_type,
520+
order=copy_order,
521+
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
455522
)
456523
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
457524
return R
@@ -492,6 +559,8 @@ def _extract_impl(ary, ary_mask, axis=0):
492559
dst = dpt.empty(
493560
dst_shape, dtype=ary.dtype, usm_type=ary.usm_type, device=ary.device
494561
)
562+
if dst.size == 0:
563+
return dst
495564
hev, _ = ti._extract(
496565
src=ary,
497566
cumsum=cumsum,

dpctl/tensor/_elementwise_common.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@
2626
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
2727
from dpctl.utils import ExecutionPlacementError
2828

29+
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
2930
from ._type_utils import (
3031
_acceptance_fn_default,
31-
_empty_like_orderK,
32-
_empty_like_pair_orderK,
3332
_find_buf_dtype,
3433
_find_buf_dtype2,
3534
_find_inplace_dtype,

dpctl/tensor/_manipulation_functions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import dpctl.tensor._tensor_impl as ti
2626
import dpctl.utils as dputils
2727

28+
from ._type_utils import _to_device_supported_dtype
29+
2830
__doc__ = (
2931
"Implementation module for array manipulation "
3032
"functions in :module:`dpctl.tensor`"
@@ -504,8 +506,10 @@ def _arrays_validation(arrays, check_ndim=True):
504506
_supported_dtype(Xi.dtype for Xi in arrays)
505507

506508
res_dtype = X0.dtype
509+
dev = exec_q.sycl_device
507510
for i in range(1, n):
508511
res_dtype = np.promote_types(res_dtype, arrays[i])
512+
res_dtype = _to_device_supported_dtype(res_dtype, dev)
509513

510514
if check_ndim:
511515
for i in range(1, n):
@@ -554,8 +558,13 @@ def _concat_axis_None(arrays):
554558
sycl_queue=exec_q,
555559
)
556560
else:
561+
src_ = array
562+
# _copy_usm_ndarray_for_reshape requires src and dst to have
563+
# the same data type
564+
if not array.dtype == res_dtype:
565+
src_ = dpt.astype(src_, res_dtype)
557566
hev, _ = ti._copy_usm_ndarray_for_reshape(
558-
src=array,
567+
src=src_,
559568
dst=res[fill_start:fill_end],
560569
shift=0,
561570
sycl_queue=exec_q,

dpctl/tensor/_slicing.pxi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,13 @@ cdef Py_ssize_t _slice_len(
3333
if sl_start == sl_stop:
3434
return 0
3535
if sl_step > 0:
36+
if sl_start > sl_stop:
37+
return 0
3638
# 1 + argmax k such htat sl_start + sl_step*k < sl_stop
3739
return 1 + ((sl_stop - sl_start - 1) // sl_step)
3840
else:
41+
if sl_start < sl_stop:
42+
return 0
3943
return 1 + ((sl_stop - sl_start + 1) // sl_step)
4044

4145

@@ -221,6 +225,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
221225
k_new = k + ellipses_count
222226
new_shape.extend(shape[k:k_new])
223227
new_strides.extend(strides[k:k_new])
228+
if any(dim == 0 for dim in shape[k:k_new]):
229+
is_empty = True
230+
new_offset = offset
224231
k = k_new
225232
elif ind_i is None:
226233
new_shape.append(1)
@@ -236,6 +243,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
236243
new_offset = new_offset + sl_start * strides[k]
237244
if sh_i == 0:
238245
is_empty = True
246+
new_offset = offset
239247
k = k_new
240248
elif _is_boolean(ind_i):
241249
new_shape.append(1 if ind_i else 0)

dpctl/tensor/_stride_utils.pxi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ cdef int _from_input_shape_strides(
6464
cdef int j
6565
cdef bint all_incr = 1
6666
cdef bint all_decr = 1
67+
cdef bint all_incr_modified = 0
68+
cdef bint all_decr_modified = 0
6769
cdef Py_ssize_t elem_count = 1
6870
cdef Py_ssize_t min_shift = 0
6971
cdef Py_ssize_t max_shift = 0
@@ -166,12 +168,14 @@ cdef int _from_input_shape_strides(
166168
j = j + 1
167169
if j < nd:
168170
if all_incr:
171+
all_incr_modified = 1
169172
all_incr = (
170173
(strides_arr[i] > 0) and
171174
(strides_arr[j] > 0) and
172175
(strides_arr[i] <= strides_arr[j])
173176
)
174177
if all_decr:
178+
all_decr_modified = 1
175179
all_decr = (
176180
(strides_arr[i] > 0) and
177181
(strides_arr[j] > 0) and
@@ -180,6 +184,10 @@ cdef int _from_input_shape_strides(
180184
i = j
181185
else:
182186
break
187+
# should only set contig flags on actually obtained
188+
# values, rather than default values
189+
all_incr = all_incr and all_incr_modified
190+
all_decr = all_decr and all_decr_modified
183191
if all_incr and all_decr:
184192
contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)
185193
elif all_incr:

0 commit comments

Comments
 (0)