recognize and issue error if GPU does not support bf16 #1344
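
For context, the change below decides bf16 eligibility from the device's CUDA compute capability. PyTorch exposes this via torch.cuda.get_device_capability(), which the diff uses, and also offers torch.cuda.is_bf16_supported() as a direct query. A minimal probe, purely illustrative and not part of this PR:

    import torch

    if torch.cuda.is_available():
        # Compute capability, e.g. (8, 0) on an A100 or (9, 0) on an H100
        major, minor = torch.cuda.get_device_capability()
        print(f"compute capability: {major}.{minor}")
        print(f"bf16 supported: {torch.cuda.is_bf16_supported()}")
    else:
        print("no CUDA device available")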


Closed · wants to merge 3 commits
37 changes: 33 additions & 4 deletions torchchat/utils/build_utils.py
@@ -148,22 +148,51 @@ def get_precision():
 ### dtype name to torch.dtype mapping ###
 
 
+def get_cuda_architecture(device=None):
+    device_str = get_device_str(device)
+    if "cuda" in device_str and torch.cuda.is_available():
+        # Get the compute capability as (major, minor) tuple
+        capability = torch.cuda.get_device_capability(device)
+        return capability[0], capability[1]
+    else:
+        return 0, 0
+
+
+##########################################################################
+### dtype name to torch.dtype mapping ###
+
+
 def name_to_dtype(name, device):
+    device_str = get_device_str(device)
+    # if it's CUDA, the architecture level indicates whether we can use bf16
+    major, minor = get_cuda_architecture(device)
+
     if (name == "fast") or (name == "fast16"):
         # MacOS now supports bfloat16
         import platform
 
         if platform.processor() == "arm":
-            device = get_device_str(device)
             # ARM CPU is faster with float16, MPS with bf16 if supported
-            if device == "cpu" or int(platform.mac_ver()[0].split(".")[0]) < 14:
+            if device_str == "cpu" or int(platform.mac_ver()[0].split(".")[0]) < 14:
                 return torch.float16
             return torch.bfloat16
+
+        # if it's not CUDA, we know it's bfloat16
+        if "cuda" not in device_str:
+            return torch.bfloat16
+
+        if major >= 9:
+            return torch.bfloat16
+        else:
+            return torch.float16
+
     try:
-        return name_to_dtype_dict[name]
+        dtype = name_to_dtype_dict[name]
     except KeyError:
         raise RuntimeError(f"unsupported dtype name {name} specified")
+
+    if ("cuda" in device_str) and (dtype == torch.bfloat16) and (major < 9):
+        raise RuntimeError(f"target device {device_str} does not support the bfloat16 data type")
+    return dtype
 
 
 def allowable_dtype_names() -> List[str]:
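
With this check in place, asking for bfloat16 on a CUDA device whose major compute capability is below 9 fails fast instead of returning a dtype the hardware cannot use. An illustrative call (assuming "bf16" is one of the names in name_to_dtype_dict and that a CUDA device string such as "cuda" is passed):

    from torchchat.utils.build_utils import name_to_dtype

    try:
        dtype = name_to_dtype("bf16", "cuda")
    except RuntimeError as err:
        # On a pre-sm_90 GPU this now reports something like:
        #   target device cuda does not support the bfloat16 data type
        print(err)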