Skip to content

Commit a38e0f2

Browse files
authored
Merge pull request #2003 from IntelPython/fix-possible-overflow-itemsize
Fix possible `itemsize` overflows in `usm_ndarray` from-pointer constructors
2 parents 7624075 + ba7c090 commit a38e0f2

File tree

2 files changed

+77
-4
lines changed

2 files changed

+77
-4
lines changed

dpctl/tensor/_usmarray.pyx

+8-4
Original file line numberDiff line numberDiff line change
@@ -1777,8 +1777,10 @@ cdef api object UsmNDArray_MakeSimpleFromPtr(
17771777
Returns:
17781778
Created usm_ndarray instance
17791779
"""
1780-
cdef size_t itemsize = type_bytesize(typenum)
1781-
cdef size_t nbytes = itemsize * nelems
1780+
cdef int itemsize = type_bytesize(typenum)
1781+
if (itemsize < 1):
1782+
raise ValueError("dtype with typenum=" + str(typenum) + " is not supported.")
1783+
cdef size_t nbytes = (<size_t> itemsize) * nelems
17821784
cdef c_dpmem._Memory mobj = c_dpmem._Memory.create_from_usm_pointer_size_qref(
17831785
ptr, nbytes, QRef, memory_owner=owner
17841786
)
@@ -1817,7 +1819,7 @@ cdef api object UsmNDArray_MakeFromPtr(
18171819
Returns:
18181820
Created usm_ndarray instance
18191821
"""
1820-
cdef size_t itemsize = type_bytesize(typenum)
1822+
cdef int itemsize = type_bytesize(typenum)
18211823
cdef int err = 0
18221824
cdef size_t nelems = 1
18231825
cdef Py_ssize_t min_disp = 0
@@ -1830,6 +1832,8 @@ cdef api object UsmNDArray_MakeFromPtr(
18301832
cdef object obj_shape
18311833
cdef object obj_strides
18321834

1835+
if (itemsize < 1):
1836+
raise ValueError("dtype with typenum=" + str(typenum) + " is not supported.")
18331837
if (nd < 0):
18341838
raise ValueError("Dimensionality must be non-negative")
18351839
if (ptr is NULL or QRef is NULL):
@@ -1881,7 +1885,7 @@ cdef api object UsmNDArray_MakeFromPtr(
18811885
raise ValueError(
18821886
"Given shape, strides and offset reference out-of-bound memory"
18831887
)
1884-
nbytes = itemsize * (offset + max_disp + 1)
1888+
nbytes = (<size_t> itemsize) * (offset + max_disp + 1)
18851889
mobj = c_dpmem._Memory.create_from_usm_pointer_size_qref(
18861890
ptr, nbytes, QRef, memory_owner=owner
18871891
)

dpctl/tests/test_usm_ndarray_ctor.py

+69
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,75 @@ def test_pyx_capi_make_general():
963963
assert zd_arr._pointer == mat._pointer
964964

965965

966+
def test_pyx_capi_make_fns_invalid_typenum():
967+
q = get_queue_or_skip()
968+
usm_ndarray = dpt.empty(tuple(), dtype="i4", sycl_queue=q)
969+
970+
make_simple_from_ptr = _pyx_capi_fnptr_to_callable(
971+
usm_ndarray,
972+
"UsmNDArray_MakeSimpleFromPtr",
973+
b"PyObject *(size_t, int, DPCTLSyclUSMRef, "
974+
b"DPCTLSyclQueueRef, PyObject *)",
975+
fn_restype=ctypes.py_object,
976+
fn_argtypes=(
977+
ctypes.c_size_t,
978+
ctypes.c_int,
979+
ctypes.c_void_p,
980+
ctypes.c_void_p,
981+
ctypes.py_object,
982+
),
983+
)
984+
985+
nelems = 10
986+
dtype = dpt.int64
987+
arr = dpt.arange(nelems, dtype=dtype, sycl_queue=q)
988+
989+
with pytest.raises(ValueError):
990+
make_simple_from_ptr(
991+
ctypes.c_size_t(nelems),
992+
-1,
993+
arr._pointer,
994+
arr.sycl_queue.addressof_ref(),
995+
arr,
996+
)
997+
998+
make_from_ptr = _pyx_capi_fnptr_to_callable(
999+
usm_ndarray,
1000+
"UsmNDArray_MakeFromPtr",
1001+
b"PyObject *(int, Py_ssize_t const *, int, Py_ssize_t const *, "
1002+
b"DPCTLSyclUSMRef, DPCTLSyclQueueRef, Py_ssize_t, PyObject *)",
1003+
fn_restype=ctypes.py_object,
1004+
fn_argtypes=(
1005+
ctypes.c_int,
1006+
ctypes.POINTER(ctypes.c_ssize_t),
1007+
ctypes.c_int,
1008+
ctypes.POINTER(ctypes.c_ssize_t),
1009+
ctypes.c_void_p,
1010+
ctypes.c_void_p,
1011+
ctypes.c_ssize_t,
1012+
ctypes.py_object,
1013+
),
1014+
)
1015+
c_shape = (ctypes.c_ssize_t * 1)(
1016+
nelems,
1017+
)
1018+
c_strides = (ctypes.c_ssize_t * 1)(
1019+
1,
1020+
)
1021+
with pytest.raises(ValueError):
1022+
make_from_ptr(
1023+
ctypes.c_int(1),
1024+
c_shape,
1025+
-1,
1026+
c_strides,
1027+
arr._pointer,
1028+
arr.sycl_queue.addressof_ref(),
1029+
ctypes.c_ssize_t(0),
1030+
arr,
1031+
)
1032+
del arr
1033+
1034+
9661035
def _pyx_capi_int(X, pyx_capi_name, caps_name=b"int", val_restype=ctypes.c_int):
9671036
import sys
9681037

0 commit comments

Comments
 (0)