Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 27 additions & 26 deletions cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in
Original file line number Diff line number Diff line change
Expand Up @@ -60,36 +60,37 @@ cdef int cuPythonInit() except -1 nogil:
except:
handle = None

# Else try default search
if not handle:
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
try:
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
except:
pass

# Final check if DLLs can be found within pip installations
# Check if DLLs can be found within pip installations
if not handle:
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
site_packages = [site.getusersitepackages()] + site.getsitepackages()
for sp in site_packages:
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
if not os.path.isdir(mod_path):
continue
os.add_dll_directory(mod_path)
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
try:
handle = win32api.LoadLibraryEx(
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
os.path.join(mod_path, "nvrtc64_120_0.dll"),
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)

# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
# located in the same mod_path.
# Update PATH environ so that the two dlls can find each other
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
except:
pass
if os.path.isdir(mod_path):
os.add_dll_directory(mod_path)
try:
handle = win32api.LoadLibraryEx(
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
os.path.join(mod_path, "nvrtc64_120_0.dll"),
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)

# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
# located in the same mod_path.
# Update PATH environ so that the two dlls can find each other
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
Comment on lines +78 to +81
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be manually loading this dll instead of modifying PATH? Regardless, we don't need to fix this in this PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvrtc is very special, because of the builtins runtime dependency.

In one of my ChatGPT chats about this a few days ago, it strongly suggested doing both: updating os.environ["PATH"] and calling os.add_dll_directory(mod_path). Therefore I'm carrying that approach into the path_finder.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot to add: I somehow have it in my mind that @leofang made a remark that we shouldn't load the builtins ourselves. Leo, does that make sense?

Copy link
Member Author

@leofang leofang Apr 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pre-loading a DLL works (it's what I did in nvmath). My objection to pre-loading is not that it doesn't work, but that I am not willing to maintain other libraries' implementation details (the dependency is loaded via dlopen, which leaves no DT_NEEDED entry or package dependency metadata for us to inspect). This can easily go out of date as these libraries evolve.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Leo, I linked your comment here:

1344621

We can weigh the pros and cons of "soft dependency pre-loading" vs. os.environ["PATH"] & os.add_dll_directory() management when the path_finder code is complete.

except:
pass
else:
break
else:
# Else try default search
# Only reached if DLL wasn't found in any site-package path
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
try:
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
except:
pass

if not handle:
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')
Expand Down
36 changes: 13 additions & 23 deletions cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -56,40 +56,30 @@ cdef load_library(const int driver_ver):

# First check if the DLL has been loaded by 3rd parties
try:
handle = win32api.GetModuleHandle(dll_name)
return win32api.GetModuleHandle(dll_name)
except:
pass
else:
break

# Next, check if DLLs are installed via pip
for sp in get_site_packages():
mod_path = os.path.join(sp, "nvidia", "nvJitLink", "bin")
if not os.path.isdir(mod_path):
continue
os.add_dll_directory(mod_path)
try:
handle = win32api.LoadLibraryEx(
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
os.path.join(mod_path, dll_name),
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
except:
pass
else:
break

if os.path.isdir(mod_path):
os.add_dll_directory(mod_path)
try:
return win32api.LoadLibraryEx(
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
os.path.join(mod_path, dll_name),
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
except:
pass
# Finally, try default search
# Only reached if DLL wasn't found in any site-package path
try:
handle = win32api.LoadLibrary(dll_name)
return win32api.LoadLibrary(dll_name)
except:
pass
else:
break
else:
raise RuntimeError('Failed to load nvJitLink')

assert handle != 0
return handle
raise RuntimeError('Failed to load nvJitLink')


cdef int _check_or_init_nvjitlink() except -1 nogil:
Expand Down
50 changes: 24 additions & 26 deletions cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ cdef void* __nvvmGetProgramLog = NULL


cdef inline list get_site_packages():
return [site.getusersitepackages()] + site.getsitepackages()
return [site.getusersitepackages()] + site.getsitepackages() + ["conda"]


cdef load_library(const int driver_ver):
Expand All @@ -50,44 +50,42 @@ cdef load_library(const int driver_ver):
for suffix in get_nvvm_dso_version_suffix(driver_ver):
if len(suffix) == 0:
continue
dll_name = "nvvm64_40_0"
dll_name = "nvvm64_40_0.dll"

# First check if the DLL has been loaded by 3rd parties
try:
handle = win32api.GetModuleHandle(dll_name)
return win32api.GetModuleHandle(dll_name)
except:
pass
else:
break

# Next, check if DLLs are installed via pip
# Next, check if DLLs are installed via pip or conda
for sp in get_site_packages():
mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin")
if not os.path.isdir(mod_path):
continue
os.add_dll_directory(mod_path)
try:
handle = win32api.LoadLibraryEx(
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
os.path.join(mod_path, dll_name),
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
except:
pass
else:
break
if sp == "conda":
# nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path
conda_prefix = os.environ.get("CONDA_PREFIX")
if conda_prefix is None:
continue
mod_path = os.path.join(conda_prefix, "Library", "nvvm", "bin")
else:
mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin")
if os.path.isdir(mod_path):
os.add_dll_directory(mod_path)
try:
return win32api.LoadLibraryEx(
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
os.path.join(mod_path, dll_name),
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
except:
pass

# Finally, try default search
# Only reached if DLL wasn't found in any site-package path
try:
handle = win32api.LoadLibrary(dll_name)
return win32api.LoadLibrary(dll_name)
except:
pass
else:
break
else:
raise RuntimeError('Failed to load nvvm')

assert handle != 0
return handle
raise RuntimeError('Failed to load nvvm')


cdef int _check_or_init_nvvm() except -1 nogil:
Expand Down
2 changes: 1 addition & 1 deletion cuda_bindings/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def build_extension(self, ext):
# to <loc>/site-packages/nvidia/cuda_nvcc/nvvm/lib64/
rel1 = "$ORIGIN/../../../nvidia/cuda_nvcc/nvvm/lib64"
# from <loc>/lib/python3.*/site-packages/cuda/bindings/_internal/
# to <loc>/lib/nvvm/lib64/
# to <loc>/nvvm/lib64/
rel2 = "$ORIGIN/../../../../../../nvvm/lib64"
ldflag = f"-Wl,--disable-new-dtags,-rpath,{rel1},-rpath,{rel2}"
else:
Expand Down
Loading