@@ -19,7 +19,7 @@ from cuda.core._utils import handle_return
19
19
20
20
21
21
@cython.dataclasses.dataclass
22
- cdef class GPUMemoryView :
22
+ cdef class StridedMemoryView :
23
23
24
24
# TODO: switch to use Cython's cdef typing?
25
25
ptr: int = None
@@ -43,14 +43,14 @@ cdef class GPUMemoryView:
43
43
pass
44
44
45
45
def __repr__ (self ):
46
- return (f" GPUMemoryView (ptr={self.ptr},\n "
47
- + f" shape={self.shape},\n "
48
- + f" strides={self.strides},\n "
49
- + f" dtype={get_simple_repr(self.dtype)},\n "
50
- + f" device_id={self.device_id},\n "
51
- + f" device_accessible={self.device_accessible},\n "
52
- + f" readonly={self.readonly},\n "
53
- + f" obj={get_simple_repr(self.obj)})" )
46
+ return (f" StridedMemoryView (ptr={self.ptr},\n "
47
+ + f" shape={self.shape},\n "
48
+ + f" strides={self.strides},\n "
49
+ + f" dtype={get_simple_repr(self.dtype)},\n "
50
+ + f" device_id={self.device_id},\n "
51
+ + f" device_accessible={self.device_accessible},\n "
52
+ + f" readonly={self.readonly},\n "
53
+ + f" obj={get_simple_repr(self.obj)})" )
54
54
55
55
56
56
cdef str get_simple_repr(obj):
@@ -80,7 +80,7 @@ cdef bint check_has_dlpack(obj) except*:
80
80
return has_dlpack
81
81
82
82
83
- cdef class _GPUMemoryViewProxy :
83
+ cdef class _StridedMemoryViewProxy :
84
84
85
85
cdef:
86
86
object obj
@@ -90,14 +90,14 @@ cdef class _GPUMemoryViewProxy:
90
90
self .obj = obj
91
91
self .has_dlpack = check_has_dlpack(obj)
92
92
93
- cpdef GPUMemoryView view(self , stream_ptr = None ):
93
+ cpdef StridedMemoryView view(self , stream_ptr = None ):
94
94
if self .has_dlpack:
95
95
return view_as_dlpack(self .obj, stream_ptr)
96
96
else :
97
97
return view_as_cai(self .obj, stream_ptr)
98
98
99
99
100
- cdef GPUMemoryView view_as_dlpack(obj, stream_ptr, view = None ):
100
+ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view = None ):
101
101
cdef int dldevice, device_id, i
102
102
cdef bint device_accessible, versioned, is_readonly
103
103
dldevice, device_id = obj.__dlpack_device__()
@@ -160,7 +160,7 @@ cdef GPUMemoryView view_as_dlpack(obj, stream_ptr, view=None):
160
160
dl_tensor = & dlm_tensor.dl_tensor
161
161
is_readonly = False
162
162
163
- cdef GPUMemoryView buf = GPUMemoryView () if view is None else view
163
+ cdef StridedMemoryView buf = StridedMemoryView () if view is None else view
164
164
buf.ptr = < intptr_t> (dl_tensor.data)
165
165
buf.shape = tuple (int (dl_tensor.shape[i]) for i in range (dl_tensor.ndim))
166
166
if dl_tensor.strides:
@@ -242,7 +242,7 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
242
242
return numpy.dtype(np_dtype)
243
243
244
244
245
- cdef GPUMemoryView view_as_cai(obj, stream_ptr, view = None ):
245
+ cdef StridedMemoryView view_as_cai(obj, stream_ptr, view = None ):
246
246
cdef dict cai_data = obj.__cuda_array_interface__
247
247
if cai_data[" version" ] < 3 :
248
248
raise BufferError(" only CUDA Array Interface v3 or above is supported" )
@@ -251,7 +251,7 @@ cdef GPUMemoryView view_as_cai(obj, stream_ptr, view=None):
251
251
if stream_ptr is None :
252
252
raise BufferError(" stream=None is ambiguous with view()" )
253
253
254
- cdef GPUMemoryView buf = GPUMemoryView () if view is None else view
254
+ cdef StridedMemoryView buf = StridedMemoryView () if view is None else view
255
255
buf.obj = obj
256
256
buf.ptr, buf.readonly = cai_data[" data" ]
257
257
buf.shape = cai_data[" shape" ]
@@ -291,7 +291,7 @@ def viewable(tuple arg_indices):
291
291
args = list (args)
292
292
cdef int idx
293
293
for idx in arg_indices:
294
- args[idx] = _GPUMemoryViewProxy (args[idx])
294
+ args[idx] = _StridedMemoryViewProxy (args[idx])
295
295
return func(* args, ** kwargs)
296
296
return wrapped_func
297
297
return wrapped_func_with_indices
0 commit comments