Skip to content

Commit 544cd87

Browse files
vtavanandgrigorian
andauthored
Updating repr() function (#2067)
* print shape for large arrays * updating CHANGELOG.md * remove empty line * Update dpctl/tensor/_print.py Co-authored-by: ndgrigorian <[email protected]> --------- Co-authored-by: ndgrigorian <[email protected]>
1 parent fa4eaa7 commit 544cd87

File tree

3 files changed

+36
-18
lines changed

3 files changed

+36
-18
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
### Changed
1414

1515
* Support for Boolean data-type is added to `dpctl.tensor.ceil`, `dpctl.tensor.floor`, and `dpctl.tensor.trunc` [gh-2033](https://github.com/IntelPython/dpctl/pull/2033)
16-
* Changed implementation of `DPCTLPlatform_GetDefaultContext` from using deprecated `ext_oneapi_get_default_context` to `khr_get_default_context` [#2042](https://github.com/IntelPython/dpctl/pull/2042).
16+
* Changed implementation of `DPCTLPlatform_GetDefaultContext` from using deprecated `ext_oneapi_get_default_context` to `khr_get_default_context` [#2042](https://github.com/IntelPython/dpctl/pull/2042)
17+
* Updated `repr` to show the shape of the abbreviated arrays and show the shape and data type of zero-size arrays [#2067](https://github.com/IntelPython/dpctl/pull/2067)
1718

1819
### Fixed
1920

dpctl/tensor/_print.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@
4040
}
4141

4242

43+
def _move_to_next_line(string, s, line_width, prefix):
44+
"""
45+
Move string to next line if it doesn't fit in the current line.
46+
"""
47+
bottom_len = len(s) - (s.rfind("\n") + 1)
48+
next_line = bottom_len + len(string) + 1 > line_width
49+
string = ",\n" + " " * len(prefix) + string if next_line else ", " + string
50+
51+
return string
52+
53+
4354
def _options_dict(
4455
linewidth=None,
4556
edgeitems=None,
@@ -463,16 +474,18 @@ def usm_ndarray_repr(
463474
suffix=suffix,
464475
)
465476

466-
if show_dtype:
467-
dtype_str = "dtype={}".format(x.dtype.name)
468-
bottom_len = len(s) - (s.rfind("\n") + 1)
469-
next_line = bottom_len + len(dtype_str) + 1 > line_width
470-
dtype_str = (
471-
",\n" + " " * len(prefix) + dtype_str
472-
if next_line
473-
else ", " + dtype_str
474-
)
477+
if show_dtype or x.size == 0:
478+
dtype_str = f"dtype={x.dtype.name}"
479+
dtype_str = _move_to_next_line(dtype_str, s, line_width, prefix)
475480
else:
476481
dtype_str = ""
477482

478-
return prefix + s + dtype_str + suffix
483+
options = get_print_options()
484+
threshold = options["threshold"]
485+
if (x.size == 0 and x.shape != (0,)) or x.size > threshold:
486+
shape_str = f"shape={x.shape}"
487+
shape_str = _move_to_next_line(shape_str, s, line_width, prefix)
488+
else:
489+
shape_str = ""
490+
491+
return prefix + s + shape_str + dtype_str + suffix

dpctl/tests/test_usm_ndarray_print.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,7 @@ def test_print_repr(self):
282282
)
283283

284284
x = dpt.arange(4, dtype="i4", sycl_queue=q)
285-
x.sycl_queue.wait()
286-
r = repr(x)
287-
assert r == "usm_ndarray([0, 1, 2, 3], dtype=int32)"
285+
assert repr(x) == "usm_ndarray([0, 1, 2, 3], dtype=int32)"
288286

289287
dpt.set_print_options(linewidth=1)
290288
np.testing.assert_equal(
@@ -296,30 +294,35 @@ def test_print_repr(self):
296294
"\n dtype=int32)",
297295
)
298296

297+
# zero-size array
298+
dpt.set_print_options(linewidth=75)
299+
x = dpt.ones((9, 0), dtype="i4", sycl_queue=q)
300+
assert repr(x) == "usm_ndarray([], shape=(9, 0), dtype=int32)"
301+
299302
def test_print_repr_abbreviated(self):
300303
q = get_queue_or_skip()
301304

302305
dpt.set_print_options(threshold=0, edgeitems=1)
303306
x = dpt.arange(9, dtype="int64", sycl_queue=q)
304-
assert repr(x) == "usm_ndarray([0, ..., 8])"
307+
assert repr(x) == "usm_ndarray([0, ..., 8], shape=(9,))"
305308

306309
y = dpt.asarray(x, dtype="i4", copy=True)
307-
assert repr(y) == "usm_ndarray([0, ..., 8], dtype=int32)"
310+
assert repr(y) == "usm_ndarray([0, ..., 8], shape=(9,), dtype=int32)"
308311

309312
x = dpt.reshape(x, (3, 3))
310313
np.testing.assert_equal(
311314
repr(x),
312315
"usm_ndarray([[0, ..., 2],"
313316
"\n ...,"
314-
"\n [6, ..., 8]])",
317+
"\n [6, ..., 8]], shape=(3, 3))",
315318
)
316319

317320
y = dpt.reshape(y, (3, 3))
318321
np.testing.assert_equal(
319322
repr(y),
320323
"usm_ndarray([[0, ..., 2],"
321324
"\n ...,"
322-
"\n [6, ..., 8]], dtype=int32)",
325+
"\n [6, ..., 8]], shape=(3, 3), dtype=int32)",
323326
)
324327

325328
dpt.set_print_options(linewidth=1)
@@ -332,6 +335,7 @@ def test_print_repr_abbreviated(self):
332335
"\n [6,"
333336
"\n ...,"
334337
"\n 8]],"
338+
"\n shape=(3, 3),"
335339
"\n dtype=int32)",
336340
)
337341

0 commit comments

Comments
 (0)