Skip to content

Commit d4dfb9c

Browse files
Define LocalAccessor type to use to specify local accessor kernel arguments
LocalAccessor(ndim, elemental_type_str, dim0, dim1, dim2) The elemental type can be one of the following: "i1", "u1", "i2", "u2", "i4", "u4", "i8", "u8", "f4", "f8"
1 parent 4fe79f9 commit d4dfb9c

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

dpctl/_sycl_queue.pyx

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ from ._backend cimport ( # noqa: E211
5656
DPCTLSyclEventRef,
5757
_arg_data_type,
5858
_backend_type,
59+
_md_local_accessor,
5960
_queue_property_type,
6061
)
6162
from .memory._memory cimport _Memory
@@ -121,6 +122,47 @@ cdef class kernel_arg_type_attribute:
121122
return self.attr_value
122123

123124

125+
cdef class LocalAccessor:
126+
cdef _md_local_accessor lacc
127+
128+
def __cinit__(self, size_t ndim, str type, size_t dim0, size_t dim1, size_t dim2):
129+
self.lacc.ndim = ndim
130+
self.lacc.dim0 = dim0
131+
self.lacc.dim1 = dim1
132+
self.lacc.dim2 = dim2
133+
134+
if ndim < 1 or ndim > 3:
135+
raise ValueError
136+
if type == 'i1':
137+
self.lacc.dpctl_type_id = _arg_data_type._INT8_T
138+
elif type == 'u1':
139+
self.lacc.dpctl_type_id = _arg_data_type._UINT8_T
140+
elif type == 'i2':
141+
self.lacc.dpctl_type_id = _arg_data_type._INT16_T
142+
elif type == 'u2':
143+
self.lacc.dpctl_type_id = _arg_data_type._UINT16_T
144+
elif type == 'i4':
145+
self.lacc.dpctl_type_id = _arg_data_type._INT32_T
146+
elif type == 'u4':
147+
self.lacc.dpctl_type_id = _arg_data_type._UINT32_T
148+
elif type == 'i8':
149+
self.lacc.dpctl_type_id = _arg_data_type._INT64_T
150+
elif type == 'u8':
151+
self.lacc.dpctl_type_id = _arg_data_type._UINT64_T
152+
elif type == 'f4':
153+
self.lacc.dpctl_type_id = _arg_data_type._FLOAT
154+
elif type == 'f8':
155+
self.lacc.dpctl_type_id = _arg_data_type._DOUBLE
156+
else:
157+
raise ValueError(f"Unrecornigzed type value: '{type}'")
158+
159+
def __repr__(self):
160+
return "LocalAccessor(" + self.ndim + ")"
161+
162+
cdef size_t addressof(self):
163+
return <size_t>&self.lacc
164+
165+
124166
cdef class _kernel_arg_type:
125167
"""
126168
An enumeration of supported kernel argument types in
@@ -849,6 +891,9 @@ cdef class SyclQueue(_SyclQueue):
849891
elif isinstance(arg, _Memory):
850892
kargs[idx]= <void*>(<size_t>arg._pointer)
851893
kargty[idx] = _arg_data_type._VOID_PTR
894+
elif isinstance(arg, LocalAccessor):
895+
kargs[idx] = <void*>((<LocalAccessor>arg).addressof())
896+
kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
852897
else:
853898
ret = -1
854899
return ret

0 commit comments

Comments
 (0)