|
13 | 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
| 16 | +import builtins |
16 | 17 | import operator
|
17 | 18 |
|
18 | 19 | import numpy as np
|
@@ -289,6 +290,96 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
|
289 | 290 | _copy_same_shape(dst, src_same_shape)
|
290 | 291 |
|
291 | 292 |
|
| 293 | +def _empty_like_orderK(X, dt, usm_type=None, dev=None): |
| 294 | + """Returns empty array like `x`, using order='K' |
| 295 | +
|
| 296 | + For an array `x` that was obtained by permutation of a contiguous |
| 297 | + array the returned array will have the same shape and the same |
| 298 | + strides as `x`. |
| 299 | + """ |
| 300 | + if not isinstance(X, dpt.usm_ndarray): |
| 301 | + raise TypeError(f"Expected usm_ndarray, got {type(X)}") |
| 302 | + if usm_type is None: |
| 303 | + usm_type = X.usm_type |
| 304 | + if dev is None: |
| 305 | + dev = X.device |
| 306 | + fl = X.flags |
| 307 | + if fl["C"] or X.size <= 1: |
| 308 | + return dpt.empty_like( |
| 309 | + X, dtype=dt, usm_type=usm_type, device=dev, order="C" |
| 310 | + ) |
| 311 | + elif fl["F"]: |
| 312 | + return dpt.empty_like( |
| 313 | + X, dtype=dt, usm_type=usm_type, device=dev, order="F" |
| 314 | + ) |
| 315 | + st = list(X.strides) |
| 316 | + perm = sorted( |
| 317 | + range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True |
| 318 | + ) |
| 319 | + inv_perm = sorted(range(X.ndim), key=lambda i: perm[i]) |
| 320 | + st_sorted = [st[i] for i in perm] |
| 321 | + sh = X.shape |
| 322 | + sh_sorted = tuple(sh[i] for i in perm) |
| 323 | + R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") |
| 324 | + if min(st_sorted) < 0: |
| 325 | + sl = tuple( |
| 326 | + slice(None, None, -1) |
| 327 | + if st_sorted[i] < 0 |
| 328 | + else slice(None, None, None) |
| 329 | + for i in range(X.ndim) |
| 330 | + ) |
| 331 | + R = R[sl] |
| 332 | + return dpt.permute_dims(R, inv_perm) |
| 333 | + |
| 334 | + |
| 335 | +def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev): |
| 336 | + if not isinstance(X1, dpt.usm_ndarray): |
| 337 | + raise TypeError(f"Expected usm_ndarray, got {type(X1)}") |
| 338 | + if not isinstance(X2, dpt.usm_ndarray): |
| 339 | + raise TypeError(f"Expected usm_ndarray, got {type(X2)}") |
| 340 | + nd1 = X1.ndim |
| 341 | + nd2 = X2.ndim |
| 342 | + if nd1 > nd2 and X1.shape == res_shape: |
| 343 | + return _empty_like_orderK(X1, dt, usm_type, dev) |
| 344 | + elif nd1 < nd2 and X2.shape == res_shape: |
| 345 | + return _empty_like_orderK(X2, dt, usm_type, dev) |
| 346 | + fl1 = X1.flags |
| 347 | + fl2 = X2.flags |
| 348 | + if fl1["C"] or fl2["C"]: |
| 349 | + return dpt.empty( |
| 350 | + res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C" |
| 351 | + ) |
| 352 | + if fl1["F"] and fl2["F"]: |
| 353 | + return dpt.empty( |
| 354 | + res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F" |
| 355 | + ) |
| 356 | + st1 = list(X1.strides) |
| 357 | + st2 = list(X2.strides) |
| 358 | + max_ndim = max(nd1, nd2) |
| 359 | + st1 += [0] * (max_ndim - len(st1)) |
| 360 | + st2 += [0] * (max_ndim - len(st2)) |
| 361 | + perm = sorted( |
| 362 | + range(max_ndim), |
| 363 | + key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])), |
| 364 | + reverse=True, |
| 365 | + ) |
| 366 | + inv_perm = sorted(range(max_ndim), key=lambda i: perm[i]) |
| 367 | + st1_sorted = [st1[i] for i in perm] |
| 368 | + st2_sorted = [st2[i] for i in perm] |
| 369 | + sh = res_shape |
| 370 | + sh_sorted = tuple(sh[i] for i in perm) |
| 371 | + R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") |
| 372 | + if max(min(st1_sorted), min(st2_sorted)) < 0: |
| 373 | + sl = tuple( |
| 374 | + slice(None, None, -1) |
| 375 | + if (st1_sorted[i] < 0 and st2_sorted[i] < 0) |
| 376 | + else slice(None, None, None) |
| 377 | + for i in range(nd1) |
| 378 | + ) |
| 379 | + R = R[sl] |
| 380 | + return dpt.permute_dims(R, inv_perm) |
| 381 | + |
| 382 | + |
292 | 383 | def copy(usm_ary, order="K"):
|
293 | 384 | """copy(ary, order="K")
|
294 | 385 |
|
@@ -334,28 +425,15 @@ def copy(usm_ary, order="K"):
|
334 | 425 | "Unrecognized value of the order keyword. "
|
335 | 426 | "Recognized values are 'A', 'C', 'F', or 'K'"
|
336 | 427 | )
|
337 |
| - c_contig = usm_ary.flags.c_contiguous |
338 |
| - f_contig = usm_ary.flags.f_contiguous |
339 |
| - R = dpt.usm_ndarray( |
340 |
| - usm_ary.shape, |
341 |
| - dtype=usm_ary.dtype, |
342 |
| - buffer=usm_ary.usm_type, |
343 |
| - order=copy_order, |
344 |
| - buffer_ctor_kwargs={"queue": usm_ary.sycl_queue}, |
345 |
| - ) |
346 |
| - if order == "K" and (not c_contig and not f_contig): |
347 |
| - original_strides = usm_ary.strides |
348 |
| - ind = sorted( |
349 |
| - range(usm_ary.ndim), |
350 |
| - key=lambda i: abs(original_strides[i]), |
351 |
| - reverse=True, |
352 |
| - ) |
353 |
| - new_strides = tuple(R.strides[ind[i]] for i in ind) |
| 428 | + if order == "K": |
| 429 | + R = _empty_like_orderK(usm_ary, usm_ary.dtype) |
| 430 | + else: |
354 | 431 | R = dpt.usm_ndarray(
|
355 | 432 | usm_ary.shape,
|
356 | 433 | dtype=usm_ary.dtype,
|
357 |
| - buffer=R.usm_data, |
358 |
| - strides=new_strides, |
| 434 | + buffer=usm_ary.usm_type, |
| 435 | + order=copy_order, |
| 436 | + buffer_ctor_kwargs={"queue": usm_ary.sycl_queue}, |
359 | 437 | )
|
360 | 438 | _copy_same_shape(R, usm_ary)
|
361 | 439 | return R
|
@@ -432,26 +510,15 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
|
432 | 510 | "Unrecognized value of the order keyword. "
|
433 | 511 | "Recognized values are 'A', 'C', 'F', or 'K'"
|
434 | 512 | )
|
435 |
| - R = dpt.usm_ndarray( |
436 |
| - usm_ary.shape, |
437 |
| - dtype=target_dtype, |
438 |
| - buffer=usm_ary.usm_type, |
439 |
| - order=copy_order, |
440 |
| - buffer_ctor_kwargs={"queue": usm_ary.sycl_queue}, |
441 |
| - ) |
442 |
| - if order == "K" and (not c_contig and not f_contig): |
443 |
| - original_strides = usm_ary.strides |
444 |
| - ind = sorted( |
445 |
| - range(usm_ary.ndim), |
446 |
| - key=lambda i: abs(original_strides[i]), |
447 |
| - reverse=True, |
448 |
| - ) |
449 |
| - new_strides = tuple(R.strides[ind[i]] for i in ind) |
| 513 | + if order == "K": |
| 514 | + R = _empty_like_orderK(usm_ary, target_dtype) |
| 515 | + else: |
450 | 516 | R = dpt.usm_ndarray(
|
451 | 517 | usm_ary.shape,
|
452 | 518 | dtype=target_dtype,
|
453 |
| - buffer=R.usm_data, |
454 |
| - strides=new_strides, |
| 519 | + buffer=usm_ary.usm_type, |
| 520 | + order=copy_order, |
| 521 | + buffer_ctor_kwargs={"queue": usm_ary.sycl_queue}, |
455 | 522 | )
|
456 | 523 | _copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
|
457 | 524 | return R
|
@@ -492,6 +559,8 @@ def _extract_impl(ary, ary_mask, axis=0):
|
492 | 559 | dst = dpt.empty(
|
493 | 560 | dst_shape, dtype=ary.dtype, usm_type=ary.usm_type, device=ary.device
|
494 | 561 | )
|
| 562 | + if dst.size == 0: |
| 563 | + return dst |
495 | 564 | hev, _ = ti._extract(
|
496 | 565 | src=ary,
|
497 | 566 | cumsum=cumsum,
|
|
0 commit comments