Commit 7387715

Merge pull request #360 from ksimpson-work/kernel-attributes
Kernel attributes
2 parents: 2981bfd + 86a536a

File tree: 3 files changed (+208, −7 lines)
cuda_core/cuda/core/experimental/_module.py

Lines changed: 144 additions & 5 deletions
@@ -3,13 +3,16 @@
 # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
 
 
+from warnings import warn
+
 from cuda.core.experimental._utils import driver, get_binding_version, handle_return, precondition
 
 _backend = {
     "old": {
         "file": driver.cuModuleLoad,
         "data": driver.cuModuleLoadDataEx,
         "kernel": driver.cuModuleGetFunction,
+        "attribute": driver.cuFuncGetAttribute,
     },
 }
 
@@ -34,6 +37,7 @@ def _lazy_init():
             "file": driver.cuLibraryLoadFromFile,
             "data": driver.cuLibraryLoadData,
             "kernel": driver.cuLibraryGetKernel,
+            "attribute": driver.cuKernelGetAttribute,
         }
         _kernel_ctypes = (driver.CUfunction, driver.CUkernel)
     else:
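The two loader tables differ in call shape: the legacy cuFuncGetAttribute takes only the attribute and function handle, while the CUDA 12+ cuKernelGetAttribute additionally takes a device ordinal. A minimal sketch of the dispatch this enables, mirroring the _get_cached_attribute logic added below (the helper name and its arguments are illustrative, not part of this PR):

def _query_attr(loader, attribute, handle, device_id, backend_version):
    # CUDA 12+ entry point (cuKernelGetAttribute) resolves the attribute per device.
    if backend_version == "new":
        return handle_return(loader["attribute"](attribute, handle, device_id))
    # Legacy entry point (cuFuncGetAttribute): the CUfunction handle is already
    # bound to a context, so no device ordinal is accepted.
    return handle_return(loader["attribute"](attribute, handle))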
@@ -42,6 +46,136 @@ def _lazy_init():
     _inited = True
 
 
+class KernelAttributes:
+    def __init__(self):
+        raise RuntimeError("KernelAttributes should not be instantiated directly")
+
+    __slots__ = ("_handle", "_cache", "_backend_version", "_loader")
+
+    def _init(handle):
+        self = KernelAttributes.__new__(KernelAttributes)
+        self._handle = handle
+        self._cache = {}
+
+        self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
+        self._loader = _backend[self._backend_version]
+        return self
+
+    def _get_cached_attribute(self, device_id: int, attribute: driver.CUfunction_attribute) -> int:
+        """Helper function to get a cached attribute or fetch and cache it if not present."""
+        if device_id in self._cache and attribute in self._cache[device_id]:
+            return self._cache[device_id][attribute]
+        if self._backend_version == "new":
+            result = handle_return(self._loader["attribute"](attribute, self._handle, device_id))
+        else:  # "old" backend
+            warn(
+                "Device ID argument is ignored when getting attribute from kernel when CUDA version < 12.",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+            result = handle_return(self._loader["attribute"](attribute, self._handle))
+        if device_id not in self._cache:
+            self._cache[device_id] = {}
+        self._cache[device_id][attribute] = result
+        return result
+
+    def max_threads_per_block(self, device_id: int = None) -> int:
+        """int : The maximum number of threads per block.
+        This attribute is read-only."""
+        return self._get_cached_attribute(
+            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK
+        )
+
+    def shared_size_bytes(self, device_id: int = None) -> int:
+        """int : The size in bytes of statically-allocated shared memory required by this function.
+        This attribute is read-only."""
+        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)
+
+    def const_size_bytes(self, device_id: int = None) -> int:
+        """int : The size in bytes of user-allocated constant memory required by this function.
+        This attribute is read-only."""
+        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES)
+
+    def local_size_bytes(self, device_id: int = None) -> int:
+        """int : The size in bytes of local memory used by each thread of this function.
+        This attribute is read-only."""
+        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)
+
+    def num_regs(self, device_id: int = None) -> int:
+        """int : The number of registers used by each thread of this function.
+        This attribute is read-only."""
+        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS)
+
+    def ptx_version(self, device_id: int = None) -> int:
+        """int : The PTX virtual architecture version for which the function was compiled.
+        This attribute is read-only."""
+        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_PTX_VERSION)
+
+    def binary_version(self, device_id: int = None) -> int:
+        """int : The binary architecture version for which the function was compiled.
+        This attribute is read-only."""
+        return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_BINARY_VERSION)
+
+    def cache_mode_ca(self, device_id: int = None) -> bool:
+        """bool : Whether the function has been compiled with user specified option "-Xptxas --dlcm=ca" set.
+        This attribute is read-only."""
+        return bool(self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CACHE_MODE_CA))
+
+    def max_dynamic_shared_size_bytes(self, device_id: int = None) -> int:
+        """int : The maximum size in bytes of dynamically-allocated shared memory that can be used
+        by this function."""
+        return self._get_cached_attribute(
+            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
+        )
+
+    def preferred_shared_memory_carveout(self, device_id: int = None) -> int:
+        """int : The shared memory carveout preference, in percent of the total shared memory."""
+        return self._get_cached_attribute(
+            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT
+        )
+
+    def cluster_size_must_be_set(self, device_id: int = None) -> bool:
+        """bool : The kernel must launch with a valid cluster size specified.
+        This attribute is read-only."""
+        return bool(
+            self._get_cached_attribute(
+                device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET
+            )
+        )
+
+    def required_cluster_width(self, device_id: int = None) -> int:
+        """int : The required cluster width in blocks."""
+        return self._get_cached_attribute(
+            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH
+        )
+
+    def required_cluster_height(self, device_id: int = None) -> int:
+        """int : The required cluster height in blocks."""
+        return self._get_cached_attribute(
+            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT
+        )
+
+    def required_cluster_depth(self, device_id: int = None) -> int:
+        """int : The required cluster depth in blocks."""
+        return self._get_cached_attribute(
+            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH
+        )
+
+    def non_portable_cluster_size_allowed(self, device_id: int = None) -> bool:
+        """bool : Whether the function can be launched with non-portable cluster size."""
+        return bool(
+            self._get_cached_attribute(
+                device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED
+            )
+        )
+
+    def cluster_scheduling_policy_preference(self, device_id: int = None) -> int:
+        """int : The block scheduling policy of a function."""
+        return self._get_cached_attribute(
+            device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE
+        )
+
+
 class Kernel:
     """Represent a compiled kernel that had been loaded onto the device.
 
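Because results land in a per-device, per-attribute cache, only the first query for a given (device, attribute) pair reaches the driver; repeats are plain dictionary lookups. A hedged behavioral sketch, assuming kernel is an already-loaded Kernel instance (the attributes property it relies on is added further down in this diff):

attrs = kernel.attributes
regs = attrs.num_regs(0)          # first call: driver query, result cached for device 0
assert attrs.num_regs(0) == regs  # second call: served from the cache, no driver round-trip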
@@ -53,13 +187,10 @@ class Kernel:
     """
 
-    __slots__ = (
-        "_handle",
-        "_module",
-    )
+    __slots__ = ("_handle", "_module", "_attributes")
 
     def __init__(self):
-        raise NotImplementedError("directly constructing a Kernel instance is not supported")
+        raise RuntimeError("directly constructing a Kernel instance is not supported")
 
     @staticmethod
     def _from_obj(obj, mod):
@@ -68,8 +199,16 @@ def _from_obj(obj, mod):
         ker = Kernel.__new__(Kernel)
         ker._handle = obj
         ker._module = mod
+        ker._attributes = None
         return ker
 
+    @property
+    def attributes(self):
+        """Get the read-only attributes of this kernel."""
+        if self._attributes is None:
+            self._attributes = KernelAttributes._init(self._handle)
+        return self._attributes
+
     # TODO: implement from_handle()
 
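End to end, the new surface looks like this. A minimal sketch assuming a CUDA-capable device; the Device(0).set_current() call stands in for whatever context setup the tests' init_cuda fixture performs (an assumption), and the noop kernel is purely illustrative:

from cuda.core.experimental import Device, Program

Device(0).set_current()  # a context must be current before compiling/loading kernels
code = 'extern "C" __global__ void noop() { }'
mod = Program(code, code_type="c++").compile("cubin")
kernel = mod.get_kernel("noop")

attrs = kernel.attributes
print(attrs.max_threads_per_block(0))  # per-device query; e.g. 1024 on recent GPUs
print(attrs.num_regs(0))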
cuda_core/docs/source/release/0.2.0-notes.rst

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ Highlights
 ----------
 
 - Add :class:`~ProgramOptions` to facilitate the passing of runtime compile options to :obj:`~Program`.
+- Add kernel attributes to :class:`~_module.Kernel`.
 
 Limitations
 -----------
cuda_core/tests/test_module.py

Lines changed: 63 additions & 2 deletions
@@ -10,14 +10,75 @@
 import pytest
 from conftest import can_load_generated_ptx
 
-from cuda.core.experimental import Program, ProgramOptions
+from cuda.core.experimental import Program, ProgramOptions, system
+
+
+@pytest.fixture(scope="function")
+def get_saxpy_kernel(init_cuda):
+    code = """
+    template<typename T>
+    __global__ void saxpy(const T a,
+                          const T* x,
+                          const T* y,
+                          T* out,
+                          size_t N) {
+        const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
+        for (size_t i=tid; i<N; i+=gridDim.x*blockDim.x) {
+            out[i] = a * x[i] + y[i];
+        }
+    }
+    """
+
+    # prepare program
+    prog = Program(code, code_type="c++")
+    mod = prog.compile(
+        "cubin",
+        name_expressions=("saxpy<float>", "saxpy<double>"),
+    )
+
+    # return the single-precision instantiation
+    return mod.get_kernel("saxpy<float>")
 
 
 @pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new")
-def test_get_kernel():
+def test_get_kernel(init_cuda):
     kernel = """extern "C" __global__ void ABC() { }"""
     object_code = Program(kernel, "c++", options=ProgramOptions(relocatable_device_code=True)).compile("ptx")
     assert object_code._handle is None
     kernel = object_code.get_kernel("ABC")
     assert object_code._handle is not None
     assert kernel._handle is not None
+
+
+@pytest.mark.parametrize(
+    "attr, expected_type",
+    [
+        ("max_threads_per_block", int),
+        ("shared_size_bytes", int),
+        ("const_size_bytes", int),
+        ("local_size_bytes", int),
+        ("num_regs", int),
+        ("ptx_version", int),
+        ("binary_version", int),
+        ("cache_mode_ca", bool),
+        ("cluster_size_must_be_set", bool),
+        ("max_dynamic_shared_size_bytes", int),
+        ("preferred_shared_memory_carveout", int),
+        ("required_cluster_width", int),
+        ("required_cluster_height", int),
+        ("required_cluster_depth", int),
+        ("non_portable_cluster_size_allowed", bool),
+        ("cluster_scheduling_policy_preference", int),
+    ],
+)
+def test_read_only_kernel_attributes(get_saxpy_kernel, attr, expected_type):
+    kernel = get_saxpy_kernel
+    method = getattr(kernel.attributes, attr)
+    # get the value without providing a device ordinal
+    value = method()
+    assert value is not None
+
+    # get the value for each device on the system
+    for device in system.devices:
+        value = method(device.device_id)
+        assert isinstance(value, expected_type), f"Expected {attr} to be of type {expected_type}, but got {type(value)}"
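To exercise just this new coverage, pytest's keyword filter selects it, e.g. pytest cuda_core/tests/test_module.py -k read_only_kernel_attributes (requires at least one visible CUDA device).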
