Skip to content

Commit 1290efe

Browse files
jagadish-amdpruthvistony
authored and committed
ROCm: Enable tf32 testing on test_nn (#55)
* Add trailing comma for consistency in gfx architecture list Signed-off-by: Jagadish Krishnamoorthy <[email protected]> * ROCm: Enable tf32 testing on test_nn Signed-off-by: Jagadish Krishnamoorthy <[email protected]> --------- Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent 8ccfc47 commit 1290efe

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

torch/cuda/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,21 @@ def _check_bf16_tensor_supported(device: _device_t):
168168
return False
169169

170170

171+
def is_tf32_supported() -> bool:
    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype tf32."""
    if torch.version.hip:
        # On ROCm, tf32 is available only on gfx94*/gfx95* accelerators;
        # check the current device's GCN architecture name for those prefixes.
        gcn_arch = torch.cuda.get_device_properties().gcnArchName
        return any(prefix in gcn_arch for prefix in ("gfx94", "gfx95"))

    # Otherwise, tf32 is supported on CUDA platforms that natively (i.e. no emulation)
    # support bfloat16.
    return is_bf16_supported(including_emulation=False)
184+
185+
171186
def _sleep(cycles):
172187
torch._C._cuda_sleep(cycles)
173188

torch/testing/_internal/common_cuda.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ def tf32_off():
150150

151151
@contextlib.contextmanager
152152
def tf32_on(self, tf32_precision=1e-5):
153+
if torch.version.hip:
154+
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
155+
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
153156
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
154157
old_precision = self.precision
155158
try:
@@ -158,6 +161,11 @@ def tf32_on(self, tf32_precision=1e-5):
158161
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
159162
yield
160163
finally:
164+
if torch.version.hip:
165+
if hip_allow_tf32 is not None:
166+
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
167+
else:
168+
del os.environ["HIPBLASLT_ALLOW_TF32"]
161169
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
162170
self.precision = old_precision
163171

0 commit comments

Comments
 (0)