@@ -56,6 +56,7 @@ from ._backend cimport ( # noqa: E211
56
56
DPCTLSyclEventRef,
57
57
_arg_data_type,
58
58
_backend_type,
59
+ _md_local_accessor,
59
60
_queue_property_type,
60
61
)
61
62
from .memory._memory cimport _Memory
@@ -121,6 +122,47 @@ cdef class kernel_arg_type_attribute:
121
122
return self .attr_value
122
123
123
124
125
+ cdef class LocalAccessor:
126
+ cdef _md_local_accessor lacc
127
+
128
+ def __cinit__ (self , size_t ndim , str type , size_t dim0 , size_t dim1 , size_t dim2 ):
129
+ self .lacc.ndim = ndim
130
+ self .lacc.dim0 = dim0
131
+ self .lacc.dim1 = dim1
132
+ self .lacc.dim2 = dim2
133
+
134
+ if ndim < 1 or ndim > 3 :
135
+ raise ValueError
136
+ if type == ' i1' :
137
+ self .lacc.dpctl_type_id = _arg_data_type._INT8_T
138
+ elif type == ' u1' :
139
+ self .lacc.dpctl_type_id = _arg_data_type._UINT8_T
140
+ elif type == ' i2' :
141
+ self .lacc.dpctl_type_id = _arg_data_type._INT16_T
142
+ elif type == ' u2' :
143
+ self .lacc.dpctl_type_id = _arg_data_type._UINT16_T
144
+ elif type == ' i4' :
145
+ self .lacc.dpctl_type_id = _arg_data_type._INT32_T
146
+ elif type == ' u4' :
147
+ self .lacc.dpctl_type_id = _arg_data_type._UINT32_T
148
+ elif type == ' i8' :
149
+ self .lacc.dpctl_type_id = _arg_data_type._INT64_T
150
+ elif type == ' u8' :
151
+ self .lacc.dpctl_type_id = _arg_data_type._UINT64_T
152
+ elif type == ' f4' :
153
+ self .lacc.dpctl_type_id = _arg_data_type._FLOAT
154
+ elif type == ' f8' :
155
+ self .lacc.dpctl_type_id = _arg_data_type._DOUBLE
156
+ else :
157
+ raise ValueError (f" Unrecornigzed type value: '{type}'" )
158
+
159
+ def __repr__ (self ):
160
+ return " LocalAccessor(" + self .ndim + " )"
161
+
162
+ cdef size_t addressof(self ):
163
+ return < size_t> & self .lacc
164
+
165
+
124
166
cdef class _kernel_arg_type:
125
167
"""
126
168
An enumeration of supported kernel argument types in
@@ -849,6 +891,9 @@ cdef class SyclQueue(_SyclQueue):
849
891
elif isinstance (arg, _Memory):
850
892
kargs[idx]= < void * > (< size_t> arg._pointer)
851
893
kargty[idx] = _arg_data_type._VOID_PTR
894
+ elif isinstance (arg, LocalAccessor):
895
+ kargs[idx] = < void * > ((< LocalAccessor> arg).addressof())
896
+ kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
852
897
else :
853
898
ret = - 1
854
899
return ret
0 commit comments