Skip to content

Commit 1586ea1

Browse files
jagadish-amd and dnikolaev-amd
authored and committed
ROCm: Enable tf32 testing on test_nn (#55)
* Add trailing comma for consistency in gfx architecture list Signed-off-by: Jagadish Krishnamoorthy <[email protected]> * ROCm: Enable tf32 testing on test_nn Signed-off-by: Jagadish Krishnamoorthy <[email protected]> --------- Signed-off-by: Jagadish Krishnamoorthy <[email protected]> (cherry picked from commit 00a0d8b3ff035b560c320b082ac3e0158e4ee1c4)
1 parent a2df696 commit 1586ea1

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

torch/cuda/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,12 @@ def _check_bf16_tensor_supported(device: _device_t):
217217

218218
def is_tf32_supported() -> bool:
219219
r"""Return a bool indicating if the current CUDA/ROCm device supports dtype tf32."""
220-
# Check for ROCm. If true, return false, since PyTorch does not currently support
221-
# tf32 on ROCm.
222220
if torch.version.hip:
221+
prop_name = torch.cuda.get_device_properties().gcnArchName
222+
archs = ("gfx94", "gfx95")
223+
for arch in archs:
224+
if arch in prop_name:
225+
return True
223226
return False
224227

225228
# Otherwise, tf32 is supported on CUDA platforms that natively (i.e. no emulation)

torch/testing/_internal/common_cuda.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ def tf32_off():
145145

146146
@contextlib.contextmanager
147147
def tf32_on(self, tf32_precision=1e-5):
148+
if torch.version.hip:
149+
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
150+
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
148151
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
149152
old_precision = self.precision
150153
try:
@@ -153,6 +156,11 @@ def tf32_on(self, tf32_precision=1e-5):
153156
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
154157
yield
155158
finally:
159+
if torch.version.hip:
160+
if hip_allow_tf32 is not None:
161+
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
162+
else:
163+
del os.environ["HIPBLASLT_ALLOW_TF32"]
156164
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
157165
self.precision = old_precision
158166

0 commit comments

Comments (0)