3
3
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4
4
5
5
6
+ from warnings import warn
7
+
6
8
from cuda .core .experimental ._utils import driver , get_binding_version , handle_return , precondition
7
9
8
10
# Driver entry points keyed by backend name, then by operation ("file", "data", "kernel", "attribute").
# The "old" table below maps each operation to the cuModule*/cuFunc* driver call.
_backend = {
9
11
"old" : {
10
12
"file" : driver .cuModuleLoad ,
11
13
"data" : driver .cuModuleLoadDataEx ,
12
14
"kernel" : driver .cuModuleGetFunction ,
15
+ "attribute" : driver .cuFuncGetAttribute ,
13
16
},
14
17
}
15
18
@@ -34,6 +37,7 @@ def _lazy_init():
34
37
"file" : driver .cuLibraryLoadFromFile ,
35
38
"data" : driver .cuLibraryLoadData ,
36
39
"kernel" : driver .cuLibraryGetKernel ,
40
+ "attribute" : driver .cuKernelGetAttribute ,
37
41
}
38
42
_kernel_ctypes = (driver .CUfunction , driver .CUkernel )
39
43
else :
@@ -42,6 +46,136 @@ def _lazy_init():
42
46
_inited = True
43
47
44
48
49
class KernelAttributes:
    """Cached, read-only view of the driver attributes of one kernel handle.

    Instances are created through :meth:`_init`; calling the constructor
    raises ``RuntimeError``.

    Every accessor takes an optional ``device_id``. On the "new" backend it is
    forwarded to the driver query. On the "old" backend the query takes no
    device argument, so the value is ignored and a ``RuntimeWarning`` is
    issued when a device ID was explicitly supplied. Successful queries are
    cached per ``device_id`` and attribute, so they are not repeated.
    """

    # Was misspelled ``slots``, which silently left instances with a __dict__.
    __slots__ = ("_handle", "_cache", "_backend_version", "_loader")

    def __init__(self):
        raise RuntimeError("KernelAttributes should not be instantiated directly")

    @staticmethod
    def _init(handle):
        """Build a ``KernelAttributes`` for ``handle``, bypassing ``__init__``.

        The backend ("new" or "old") is chosen from the module-level binding
        and driver versions, and the matching driver-call table is stored.
        """
        self = KernelAttributes.__new__(KernelAttributes)
        self._handle = handle
        # Maps device_id -> {attribute -> value}.
        self._cache = {}

        self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
        self._loader = _backend[self._backend_version]
        return self

    def _get_cached_attribute(self, device_id: int, attribute: "driver.CUfunction_attribute") -> int:
        """Return ``attribute`` for ``device_id``, querying the driver on a cache miss.

        Parameters
        ----------
        device_id : int
            Device to query. Only passed to the driver on the "new" backend.
        attribute : driver.CUfunction_attribute
            The attribute to fetch.

        Returns
        -------
        int
            The value produced by ``handle_return`` for the driver query.
        """
        per_device = self._cache.get(device_id)
        if per_device is not None and attribute in per_device:
            return per_device[attribute]

        query = self._loader["attribute"]
        if self._backend_version == "new":
            result = handle_return(query(attribute, self._handle, device_id))
        else:  # "old" backend
            # Only warn when the caller actually supplied a device ID; the
            # default (None) means there is nothing being ignored.
            if device_id is not None:
                warn(
                    "Device ID argument is ignored when getting attribute from kernel when cuda version < 12. ",
                    RuntimeWarning,
                    # user code -> public accessor -> this helper -> warn
                    stacklevel=3,
                )
            result = handle_return(query(attribute, self._handle))

        self._cache.setdefault(device_id, {})[attribute] = result
        return result

    def max_threads_per_block(self, device_id: int = None) -> int:
        """Return the maximum number of threads per block.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(
            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK
        )

    def shared_size_bytes(self, device_id: int = None) -> int:
        """Return the size in bytes of statically-allocated shared memory required by this function.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)

    def const_size_bytes(self, device_id: int = None) -> int:
        """Return the size in bytes of user-allocated constant memory required by this function.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES)

    def local_size_bytes(self, device_id: int = None) -> int:
        """Return the size in bytes of local memory used by each thread of this function.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)

    def num_regs(self, device_id: int = None) -> int:
        """Return the number of registers used by each thread of this function.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS)

    def ptx_version(self, device_id: int = None) -> int:
        """Return the PTX virtual architecture version for which the function was compiled.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_PTX_VERSION)

    def binary_version(self, device_id: int = None) -> int:
        """Return the binary architecture version for which the function was compiled.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_BINARY_VERSION)

    def cache_mode_ca(self, device_id: int = None) -> bool:
        """Return whether the function was compiled with user specified option "-Xptxas --dlcm=ca" set.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return bool(self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CACHE_MODE_CA))

    def max_dynamic_shared_size_bytes(self, device_id: int = None) -> int:
        """Return the maximum size in bytes of dynamically-allocated shared memory usable by this function.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(
            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
        )

    def preferred_shared_memory_carveout(self, device_id: int = None) -> int:
        """Return the shared memory carveout preference, in percent of the total shared memory.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(
            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT
        )

    def cluster_size_must_be_set(self, device_id: int = None) -> bool:
        """Return whether the kernel must launch with a valid cluster size specified.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return bool(
            self._get_cached_attribute(
                device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET
            )
        )

    def required_cluster_width(self, device_id: int = None) -> int:
        """Return the required cluster width in blocks.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(
            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH
        )

    def required_cluster_height(self, device_id: int = None) -> int:
        """Return the required cluster height in blocks.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(
            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT
        )

    def required_cluster_depth(self, device_id: int = None) -> int:
        """Return the required cluster depth in blocks.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(
            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH
        )

    def non_portable_cluster_size_allowed(self, device_id: int = None) -> bool:
        """Return whether the function can be launched with non-portable cluster size.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return bool(
            self._get_cached_attribute(
                device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED
            )
        )

    def cluster_scheduling_policy_preference(self, device_id: int = None) -> int:
        """Return the block scheduling policy of the function.

        ``device_id`` selects the device to query (see the class docstring).
        """
        return self._get_cached_attribute(
            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE
        )
177
+
178
+
45
179
class Kernel :
46
180
"""Represent a compiled kernel that had been loaded onto the device.
47
181
@@ -53,13 +187,10 @@ class Kernel:
53
187
54
188
"""
55
189
56
- __slots__ = (
57
- "_handle" ,
58
- "_module" ,
59
- )
190
+ __slots__ = ("_handle" , "_module" , "_attributes" )
60
191
61
192
def __init__ (self ):
62
- raise NotImplementedError ("directly constructing a Kernel instance is not supported" )
193
+ raise RuntimeError ("directly constructing a Kernel instance is not supported" )
63
194
64
195
@staticmethod
65
196
def _from_obj (obj , mod ):
@@ -68,8 +199,16 @@ def _from_obj(obj, mod):
68
199
ker = Kernel .__new__ (Kernel )
69
200
ker ._handle = obj
70
201
ker ._module = mod
202
+ ker ._attributes = None
71
203
return ker
72
204
205
+ @property
206
+ def attributes (self ):
207
+ """Get the read-only attributes of this kernel."""
208
+ if self ._attributes is None :
209
+ self ._attributes = KernelAttributes ._init (self ._handle )
210
+ return self ._attributes
211
+
73
212
# TODO: implement from_handle()
74
213
75
214
0 commit comments