Skip to content

Commit c6381c6

Browse files
ezyangsoumith
authored andcommitted
Add function to explicitly initialize PyTorch CUDA state. (pytorch#4180)
Signed-off-by: Edward Z. Yang <[email protected]>
1 parent d450895 commit c6381c6

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

torch/cuda/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,19 @@ class DeferredCudaCallError(Exception):
122122
pass
123123

124124

125+
def init():
126+
"""Initialize PyTorch's CUDA state. You may need to call
127+
this explicitly if you are interacting with PyTorch via
128+
its C API, as Python bindings for CUDA functionality will not
129+
be until this initialization takes place. Ordinary users
130+
should not need this, as all of PyTorch's CUDA methods
131+
automatically initialize CUDA state on-demand.
132+
133+
Does nothing if the CUDA state is already initialized.
134+
"""
135+
_lazy_init()
136+
137+
125138
def _lazy_init():
126139
global _initialized, _cudart, _original_pid, _queued_calls
127140
if _initialized:

0 commit comments

Comments
 (0)