From bfab3f108571213d1ca1f4be93d383b8e66e481f Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Tue, 5 Nov 2024 11:42:20 -0800
Subject: [PATCH 1/3] recognize and issue error if GPU does not support bf16

Address pytorch/torchchat#1298, which causes models to fail on T4 (and other
pre-9.0 arch level GPUs), by selecting an alternate dtype when possible and
issuing an error otherwise.
---
 torchchat/utils/build_utils.py | 37 ++++++++++++++++++++++++++++++----
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py
index 1b649ffbc..f61d7f411 100644
--- a/torchchat/utils/build_utils.py
+++ b/torchchat/utils/build_utils.py
@@ -148,22 +148,51 @@ def get_precision():
 ### dtype name to torch.dtype mapping ###
 
 
+def get_cuda_architecture(device=None):
+    device_str = get_device_str(device)
+    if "cuda" is in device_str and torch.cuda.is_available():
+        # Get the compute capability as (major, minor) tuple
+        capability = torch.cuda.get_device_capability(device)
+        return capability[0], capability[1]
+    else:
+        return 0, 0
+
+
+##########################################################################
+### dtype name to torch.dtype mapping ###
+
+
 def name_to_dtype(name, device):
+    device_str = get_device_str(device)
+    # if it's CUDA, the architecture level indicates whether we can use bf16
+    major, minor = get_cuda_architecture(device)
+
     if (name == "fast") or (name == "fast16"):
         # MacOS now supports bfloat16
         import platform
 
         if platform.processor() == "arm":
-            device = get_device_str(device)
             # ARM CPU is faster with float16, MPS with bf16 if supported
-            if device == "cpu" or int(platform.mac_ver()[0].split(".")[0]) < 14:
+            if device_str == "cpu" or int(platform.mac_ver()[0].split(".")[0]) < 14:
                 return torch.float16
-        return torch.bfloat16
+
+        # if it's not CUDA, we know it's bfloat16
+        if "cuda" is not in device_str:
+            return torch.bfloat16
+
+        if major >= 9:
+            return torch.bfloat16
+        else:
+            return torch.float16
 
     try:
-        return name_to_dtype_dict[name]
+        dtype = name_to_dtype_dict[name]
     except KeyError:
         raise RuntimeError(f"unsupported dtype name {name} specified")
 
+    if ("cuda" is in device_str) and (dtype == torch.bfloat16) and (major < 9):
+        raise RuntimeError(f"target device {device_str} does not support the bfloat16 data type")
+
+    return dtype
 
 
 def allowable_dtype_names() -> List[str]:

From ae694284869145273f40037258c2360d28128134 Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Tue, 5 Nov 2024 12:20:51 -0800
Subject: [PATCH 2/3] Update build_utils.py

typo
---
 torchchat/utils/build_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py
index f61d7f411..5ebb2080c 100644
--- a/torchchat/utils/build_utils.py
+++ b/torchchat/utils/build_utils.py
@@ -150,7 +150,7 @@ def get_precision():
 
 def get_cuda_architecture(device=None):
     device_str = get_device_str(device)
-    if "cuda" is in device_str and torch.cuda.is_available():
+    if "cuda" in device_str and torch.cuda.is_available():
         # Get the compute capability as (major, minor) tuple
         capability = torch.cuda.get_device_capability(device)
         return capability[0], capability[1]
@@ -190,7 +190,7 @@ def name_to_dtype(name, device):
     except KeyError:
         raise RuntimeError(f"unsupported dtype name {name} specified")
 
-    if ("cuda" is in device_str) and (dtype == torch.bfloat16) and (major < 9):
+    if ("cuda" in device_str) and (dtype == torch.bfloat16) and (major < 9):
         raise RuntimeError(f"target device {device_str} does not support the bfloat16 data type")
 
     return dtype

From fd0f53e2d58b1900a2696681482fc18624fed7ca Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Tue, 5 Nov 2024 13:36:13 -0800
Subject: [PATCH 3/3] Update build_utils.py

typo
---
 torchchat/utils/build_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py
index 5ebb2080c..0b52b0835 100644
--- a/torchchat/utils/build_utils.py
+++ b/torchchat/utils/build_utils.py
@@ -177,7 +177,7 @@ def name_to_dtype(name, device):
                 return torch.float16
 
         # if it's not CUDA, we know it's bfloat16
-        if "cuda" is not in device_str:
+        if "cuda" not in device_str:
             return torch.bfloat16
 
         if major >= 9:
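
Note for reviewers who want to exercise the change locally: the sketch below is not part of the patches above. It assumes a CUDA build of PyTorch, the patched torchchat/utils/build_utils.py on the import path, that get_device_str() accepts the literal string "cuda", and that "bfloat16" is a key in name_to_dtype_dict; everything else follows directly from the diffs.

    import torch
    from torchchat.utils.build_utils import get_cuda_architecture, name_to_dtype

    # A T4 reports compute capability (7, 5), an H100 reports (9, 0), and a
    # non-CUDA device reports (0, 0) per the patched get_cuda_architecture().
    major, minor = get_cuda_architecture("cuda")
    print(f"compute capability: {major}.{minor}")

    # "fast" / "fast16" now fall back to float16 on pre-9.0 GPUs ...
    print(name_to_dtype("fast16", "cuda"))

    # ... while an explicit bfloat16 request fails early with a RuntimeError
    # instead of failing later inside the model.
    try:
        print(name_to_dtype("bfloat16", "cuda"))
    except RuntimeError as err:
        print(err)

On a pre-9.0 GPU such as a T4, this prints torch.float16 for the "fast16" request and the "target device ... does not support the bfloat16 data type" error for the explicit bfloat16 request, which is the fallback-or-error behavior described in the PATCH 1/3 commit message.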