Update README.md for float8 #1090

Closed · wants to merge 1 commit
24 changes: 8 additions & 16 deletions torchao/float8/README.md
```diff
@@ -34,21 +34,12 @@ m = nn.Sequential(
 x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
 optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
 
-# optional: filter modules from being eligible for float8 conversion
-def module_filter_fn(mod: torch.nn.Module, fqn: str):
-    # don't convert the last module
-    if fqn == "1":
-        return False
-    # don't convert linear modules with weight dimensions not divisible by 16
-    if isinstance(mod, torch.nn.Linear):
-        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
-            return False
-    return True
-
-# convert specified `torch.nn.Linear` modules to `Float8Linear`
-convert_to_float8_training(m, module_filter_fn=module_filter_fn)
-
-# enable torch.compile for competitive performance
+# convert specified `torch.nn.Linear` modules to `Float8Linear`, with compute
+# and optionally distributed communications in float8
+convert_to_float8_training(m)
+
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop
```
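For reference, here is a minimal runnable sketch of how the example reads after this change. The `nn.Sequential` layer sizes and the training-loop body are elided in the diff context, so the ones below are assumptions; `module_filter_fn` also appears to remain available as an optional argument to `convert_to_float8_training` for the cases the removed snippet covered.

```python
# sketch of the simplified README flow; layer sizes and loop body are assumed,
# not taken from the diff
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

m = nn.Sequential(
    nn.Linear(2048, 4096),  # dims kept divisible by 16, as float8 matmuls require
    nn.Linear(4096, 128),
).to("cuda", dtype=torch.bfloat16)
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# convert `torch.nn.Linear` modules to `Float8Linear`; a `module_filter_fn=...`
# argument can be passed here to skip modules, as the removed snippet did
convert_to_float8_training(m)

# compile to fuse the float8 scaling and casting kernels
m = torch.compile(m)

# toy training loop (body assumed)
for _ in range(10):
    optimizer.zero_grad()
    m(x).sum().backward()
    optimizer.step()
```

The removed filter logic is not lost functionality; the simplified call just converts every eligible `torch.nn.Linear` by default, with filtering moved out of the introductory example.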
```diff
@@ -94,7 +85,8 @@ config = Float8LinearConfig(
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
 convert_to_float8_training(m, config=config)
 
-# enable torch.compile for competitive performance
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop
```
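The second hunk touches the custom-scaling variant of the same example. The actual `Float8LinearConfig` fields are elided in this diff, so the sketch below constructs the config with defaults as a stand-in:

```python
# sketch of the custom-scaling variant; the real README passes custom scaling
# options to Float8LinearConfig, which are elided in this diff
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

m = nn.Sequential(nn.Linear(2048, 4096), nn.Linear(4096, 128)).to("cuda", dtype=torch.bfloat16)

# placeholder: default settings stand in for the elided custom scaling options
config = Float8LinearConfig()

# convert all `torch.nn.Linear` modules to `Float8Linear` with this config
convert_to_float8_training(m, config=config)

# compile to fuse the float8 scaling and casting kernels
m = torch.compile(m)
```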