Skip to content

Commit 86a536a

Browse files
committed
fallback to cuFuncAPI
1 parent 2ae7cfb commit 86a536a

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

55

6+
from warnings import warn
7+
68
from cuda.core.experimental._utils import driver, get_binding_version, handle_return, precondition
79

810
_backend = {
911
"old": {
1012
"file": driver.cuModuleLoad,
1113
"data": driver.cuModuleLoadDataEx,
1214
"kernel": driver.cuModuleGetFunction,
15+
"attribute": driver.cuFuncGetAttribute,
1316
},
1417
}
1518

@@ -34,6 +37,7 @@ def _lazy_init():
3437
"file": driver.cuLibraryLoadFromFile,
3538
"data": driver.cuLibraryLoadData,
3639
"kernel": driver.cuLibraryGetKernel,
40+
"attribute": driver.cuKernelGetAttribute,
3741
}
3842
_kernel_ctypes = (driver.CUfunction, driver.CUkernel)
3943
else:
@@ -46,19 +50,30 @@ class KernelAttributes:
4650
def __init__(self):
4751
raise RuntimeError("KernelAttributes should not be instantiated directly")
4852

49-
slots = ("_handle", "_cache")
53+
slots = ("_handle", "_cache", "_backend_version", "_loader")
5054

5155
def _init(handle):
5256
self = KernelAttributes.__new__(KernelAttributes)
5357
self._handle = handle
5458
self._cache = {}
59+
60+
self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
61+
self._loader = _backend[self._backend_version]
5562
return self
5663

5764
def _get_cached_attribute(self, device_id: int, attribute: driver.CUfunction_attribute) -> int:
5865
"""Helper function to get a cached attribute or fetch and cache it if not present."""
5966
if device_id in self._cache and attribute in self._cache[device_id]:
6067
return self._cache[device_id][attribute]
61-
result = handle_return(driver.cuKernelGetAttribute(attribute, self._handle, device_id))
68+
if self._backend_version == "new":
69+
result = handle_return(self._loader["attribute"](attribute, self._handle, device_id))
70+
else: # "old" backend
71+
warn(
72+
"Device ID argument is ignored when getting attribute from kernel when cuda version < 12. ",
73+
RuntimeWarning,
74+
stacklevel=2,
75+
)
76+
result = handle_return(self._loader["attribute"](attribute, self._handle))
6277
if device_id not in self._cache:
6378
self._cache[device_id] = {}
6479
self._cache[device_id][attribute] = result

cuda_core/tests/test_module.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,7 @@
1010
import pytest
1111
from conftest import can_load_generated_ptx
1212

13-
try:
14-
from cuda.bindings import driver
15-
except ImportError:
16-
from cuda import cuda as driver
17-
1813
from cuda.core.experimental import Program, ProgramOptions, system
19-
from cuda.core.experimental._utils import get_binding_version, handle_return
20-
21-
22-
@pytest.fixture(scope="module")
23-
def cuda_version():
24-
# binding availability depends on cuda-python version
25-
_py_major_ver, _ = get_binding_version()
26-
_driver_ver = handle_return(driver.cuDriverGetVersion())
27-
return _py_major_ver, _driver_ver
2814

2915

3016
@pytest.fixture(scope="function")
@@ -85,9 +71,7 @@ def test_get_kernel(init_cuda):
8571
("cluster_scheduling_policy_preference", int),
8672
],
8773
)
88-
def test_read_only_kernel_attributes(get_saxpy_kernel, attr, expected_type, cuda_version):
89-
if cuda_version[0] < 12 and cuda_version[1] >= 12000:
90-
pytest.skip("CUDA version is less than 12, and doesn't support kernel attribute access")
74+
def test_read_only_kernel_attributes(get_saxpy_kernel, attr, expected_type):
9175
kernel = get_saxpy_kernel
9276
method = getattr(kernel.attributes, attr)
9377
# get the value without providing a device ordinal

0 commit comments

Comments
 (0)