3
3
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4
4
5
5
6
+ from warnings import warn
7
+
6
8
from cuda .core .experimental ._utils import driver , get_binding_version , handle_return , precondition
7
9
8
10
# Table of driver entry points, keyed first by backend name and then by the
# operation being performed ("file", "data", "kernel", "attribute").
_backend = {}

# The "old" backend uses the cuModule*/cuFunc* driver functions.
_backend["old"] = {
    "file": driver.cuModuleLoad,
    "data": driver.cuModuleLoadDataEx,
    "kernel": driver.cuModuleGetFunction,
    "attribute": driver.cuFuncGetAttribute,
}
15
18
@@ -34,6 +37,7 @@ def _lazy_init():
34
37
"file" : driver .cuLibraryLoadFromFile ,
35
38
"data" : driver .cuLibraryLoadData ,
36
39
"kernel" : driver .cuLibraryGetKernel ,
40
+ "attribute" : driver .cuKernelGetAttribute ,
37
41
}
38
42
_kernel_ctypes = (driver .CUfunction , driver .CUkernel )
39
43
else :
@@ -46,19 +50,30 @@ class KernelAttributes:
46
50
def __init__ (self ):
47
51
raise RuntimeError ("KernelAttributes should not be instantiated directly" )
48
52
49
# Must be spelled ``__slots__``: a plain ``slots`` attribute is ignored by
# Python and would leave instances with an ordinary per-instance ``__dict__``.
# The names listed are exactly the attributes assigned in ``_init``.
__slots__ = ("_handle", "_cache", "_backend_version", "_loader")
50
54
51
55
def _init(handle):
    """Create a ``KernelAttributes`` for ``handle`` without calling ``__init__``.

    ``__init__`` raises unconditionally, so the instance is allocated with
    ``__new__`` and populated by hand.

    Parameters
    ----------
    handle
        Kernel handle stored on the instance and later passed to the
        backend's "attribute" function.

    Returns
    -------
    KernelAttributes
        Instance with an empty attribute cache and the backend table
        selected from the module-level version values.
    """
    instance = KernelAttributes.__new__(KernelAttributes)
    instance._handle = handle
    instance._cache = {}

    # Both module-level version checks must pass to select the "new" backend;
    # anything else falls back to "old".
    if _py_major_ver >= 12 and _driver_ver >= 12000:
        version = "new"
    else:
        version = "old"
    instance._backend_version = version
    instance._loader = _backend[version]
    return instance
56
63
57
64
def _get_cached_attribute (self , device_id : int , attribute : driver .CUfunction_attribute ) -> int :
58
65
"""Helper function to get a cached attribute or fetch and cache it if not present."""
59
66
if device_id in self ._cache and attribute in self ._cache [device_id ]:
60
67
return self ._cache [device_id ][attribute ]
61
- result = handle_return (driver .cuKernelGetAttribute (attribute , self ._handle , device_id ))
68
+ if self ._backend_version == "new" :
69
+ result = handle_return (self ._loader ["attribute" ](attribute , self ._handle , device_id ))
70
+ else : # "old" backend
71
+ warn (
72
+ "Device ID argument is ignored when getting attribute from kernel when cuda version < 12. " ,
73
+ RuntimeWarning ,
74
+ stacklevel = 2 ,
75
+ )
76
+ result = handle_return (self ._loader ["attribute" ](attribute , self ._handle ))
62
77
if device_id not in self ._cache :
63
78
self ._cache [device_id ] = {}
64
79
self ._cache [device_id ][attribute ] = result
0 commit comments