Skip to content

Commit a0baf71

Browse files
committed
Ensure mod_path is always defined when used. Make DLL search order consistent between all three cases.
1 parent 21a7fb4 commit a0baf71

File tree

3 files changed

+67
-69
lines changed

3 files changed

+67
-69
lines changed

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -60,38 +60,37 @@ cdef int cuPythonInit() except -1 nogil:
6060
except:
6161
handle = None
6262

63-
# Else try default search
64-
if not handle:
65-
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
66-
try:
67-
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
68-
except:
69-
pass
70-
71-
# Final check if DLLs can be found within pip installations
63+
# Check if DLLs can be found within pip installations
7264
if not handle:
7365
site_packages = [site.getusersitepackages()] + site.getsitepackages()
7466
for sp in site_packages:
7567
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
76-
if not os.path.isdir(mod_path):
77-
continue
78-
else:
68+
if os.path.isdir(mod_path):
7969
os.add_dll_directory(mod_path)
80-
break
81-
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
82-
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
83-
try:
84-
handle = win32api.LoadLibraryEx(
85-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
86-
os.path.join(mod_path, "nvrtc64_120_0.dll"),
87-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
88-
89-
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
90-
# located in the same mod_path.
91-
# Update PATH environ so that the two dlls can find each other
92-
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
93-
except:
94-
pass
70+
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
71+
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
72+
try:
73+
handle = win32api.LoadLibraryEx(
74+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
75+
os.path.join(mod_path, "nvrtc64_120_0.dll"),
76+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
77+
78+
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
79+
# located in the same mod_path.
80+
# Update PATH environ so that the two dlls can find each other
81+
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
82+
except:
83+
pass
84+
else:
85+
break
86+
else:
87+
# Else try default search
88+
# Only reached if DLL wasn't found in any site-package path
89+
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
90+
try:
91+
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
92+
except:
93+
pass
9594

9695
if not handle:
9796
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -65,28 +65,26 @@ cdef load_library(const int driver_ver):
6565
# Next, check if DLLs are installed via pip
6666
for sp in get_site_packages():
6767
mod_path = os.path.join(sp, "nvidia", "nvJitLink", "bin")
68-
if not os.path.isdir(mod_path):
69-
continue
70-
else:
68+
if os.path.isdir(mod_path):
7169
os.add_dll_directory(mod_path)
72-
break
73-
try:
74-
handle = win32api.LoadLibraryEx(
75-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
76-
os.path.join(mod_path, dll_name),
77-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
78-
except:
79-
pass
80-
else:
81-
break
82-
83-
# Finally, try default search
84-
try:
85-
handle = win32api.LoadLibrary(dll_name)
86-
except:
87-
pass
70+
try:
71+
handle = win32api.LoadLibraryEx(
72+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
73+
os.path.join(mod_path, dll_name),
74+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
75+
except:
76+
pass
77+
else:
78+
break
8879
else:
89-
break
80+
# Finally, try default search
81+
# Only reached if DLL wasn't found in any site-package path
82+
try:
83+
handle = win32api.LoadLibrary(dll_name)
84+
except:
85+
pass
86+
else:
87+
break
9088
else:
9189
raise RuntimeError('Failed to load nvJitLink')
9290

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -62,33 +62,34 @@ cdef load_library(const int driver_ver):
6262

6363
# Next, check if DLLs are installed via pip or conda
6464
for sp in get_site_packages():
65-
if sp == "conda" and "CONDA_PREFIX" in os.environ:
65+
if sp == "conda":
6666
# nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path
67-
mod_path = os.path.join(os.environ["CONDA_PREFIX"], "Library", "nvvm", "bin")
67+
conda_prefix = os.environ.get("CONDA_PREFIX")
68+
if conda_prefix is None:
69+
continue
70+
mod_path = os.path.join(conda_prefix, "Library", "nvvm", "bin")
6871
else:
6972
mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin")
70-
if not os.path.isdir(mod_path):
71-
continue
72-
else:
73+
if os.path.isdir(mod_path):
7374
os.add_dll_directory(mod_path)
74-
break
75-
try:
76-
handle = win32api.LoadLibraryEx(
77-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
78-
os.path.join(mod_path, dll_name),
79-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
80-
except:
81-
pass
82-
else:
83-
break
84-
85-
# Finally, try default search
86-
try:
87-
handle = win32api.LoadLibrary(dll_name)
88-
except:
89-
pass
75+
try:
76+
handle = win32api.LoadLibraryEx(
77+
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
78+
os.path.join(mod_path, dll_name),
79+
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
80+
except:
81+
pass
82+
else:
83+
break
9084
else:
91-
break
85+
# Finally, try default search
86+
# Only reached if DLL wasn't found in any site-package path
87+
try:
88+
handle = win32api.LoadLibrary(dll_name)
89+
except:
90+
pass
91+
else:
92+
break
9293
else:
9394
raise RuntimeError('Failed to load nvvm')
9495

0 commit comments

Comments
 (0)