Commit 40f79d7
[BE][3/n] wrap fp8 logic using Float8Handler
ghstack-source-id: e94c7f6 Pull Request resolved: #496
Parent: bf90710

11 files changed: +163 −150 lines

estimation.py: 7 additions & 7 deletions
```diff
@@ -16,10 +16,7 @@

 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_tokenizer
-from torchtitan.float8_linear import (
-    maybe_build_fp8_linear,
-    maybe_precompute_fp8_dynamic_scale_for_fsdp,
-)
+from torchtitan.float8_linear import Float8Handler
 from torchtitan.logging import init_logger, logger
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
@@ -127,8 +124,10 @@ def loss_fn(pred, labels):
     with torch.device("meta"):
         whole_model = model_cls.from_model_args(model_config)

+    # a no-op handler if fp8 is not enabled
+    float8_handler = Float8Handler(job_config, parallel_dims)
     # swap to Float8Linear based on fp8 config
-    maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
+    float8_handler.convert_to_float8_training(whole_model)

     # apply PT-D DP/TP parallelisms and activation checkpointing
     model_parts = [whole_model]
@@ -184,13 +183,14 @@ def loss_fn(pred, labels):
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), job_config.training.max_norm, foreach=True
            )
+           # sync float8 amaxes and scales
+           float8_handler.sync_float8_amax_and_scale_history(model)
            # optimizer step
            optimizers.step()
            lr_schedulers.step()
-           # when fp8 config is on,
            # calculate float8 dynamic amax/scale for all-parameter for FSDP2
            # it issues a single all-reduce for all parameters at once for better performance
-           maybe_precompute_fp8_dynamic_scale_for_fsdp(whole_model, job_config)
+           float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
            optimizers.zero_grad()
            print(f"Peak Memory at iter: {iter_idx}")
            fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
```
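Taken together, the estimation.py changes reduce to the following pattern. This is a condensed, hypothetical sketch of how the new Float8Handler is driven from a training step; the function and its arguments are stand-ins for objects that estimation.py/train.py already build, not code from this commit.

```python
import torch
import torch.nn as nn

from torchtitan.float8_linear import Float8Handler


def run_training_steps(
    job_config, parallel_dims, model: nn.Module, data_loader, optimizer, lr_scheduler
):
    # The handler is a no-op when float8 is disabled, so the call sites below need no branching.
    float8_handler = Float8Handler(job_config, parallel_dims)
    float8_handler.convert_to_float8_training(model)  # swaps eligible nn.Linear -> Float8Linear

    loss_fn = nn.CrossEntropyLoss()
    for inputs, labels in data_loader:
        loss = loss_fn(model(inputs).flatten(0, 1), labels.flatten())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), job_config.training.max_norm)
        # delayed scaling only: sync float8 amaxes and scales before the optimizer step
        float8_handler.sync_float8_amax_and_scale_history(model)
        optimizer.step()
        lr_scheduler.step()
        # FSDP2 float8 all-gather only: a single all-reduce refreshes all dynamic scales
        float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
        optimizer.zero_grad()
```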

torchtitan/config_manager.py: 43 additions & 40 deletions
```diff
@@ -348,46 +348,6 @@ def __init__(self):
             action="store_true",
             help="Whether to compile the model",
         )
-        self.parser.add_argument(
-            "--training.enable_float8_linear",
-            action="store_true",
-            help="""
-                If true, swaps `torch.nn.Linear` with `Float8Linear`.
-                This feature requires you to install 'torchao' which can be found
-                here: https://github.com/pytorch/ao
-            """,
-        )
-        self.parser.add_argument(
-            "--training.enable_fsdp_float8_all_gather",
-            action="store_true",
-            default=False,
-            help="Whether enable float8 all-gather in FSDP",
-        )
-        self.parser.add_argument(
-            "--training.precompute_float8_dynamic_scale_for_fsdp",
-            action="store_true",
-            default=False,
-            help="Whether precompute float8 scales dynamically for FSDP",
-        )
-        self.parser.add_argument(
-            "--training.float8_scaling_type_input",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-            choices=["dynamic", "delayed"],
-        )
-        self.parser.add_argument(
-            "--training.float8_scaling_type_weight",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-        )
-        self.parser.add_argument(
-            "--training.float8_scaling_type_grad_output",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-        )
         self.parser.add_argument(
             "--training.gc_freq",
             type=int,
@@ -483,6 +443,7 @@ def __init__(self):
                 0 is the default value.
             """,
         )
+
         # activation checkpointing configs
         self.parser.add_argument(
             "--activation_checkpoint.mode",
@@ -500,6 +461,48 @@ def __init__(self):
             """,
         )

+        # float8 configs
+        self.parser.add_argument(
+            "--float8.enable_float8_linear",
+            action="store_true",
+            help="""
+                If true, swaps `torch.nn.Linear` with `Float8Linear`.
+                This feature requires you to install 'torchao' which can be found
+                here: https://github.com/pytorch/ao
+            """,
+        )
+        self.parser.add_argument(
+            "--float8.enable_fsdp_float8_all_gather",
+            action="store_true",
+            default=False,
+            help="Whether enable float8 all-gather in FSDP",
+        )
+        self.parser.add_argument(
+            "--float8.precompute_float8_dynamic_scale_for_fsdp",
+            action="store_true",
+            default=False,
+            help="Whether precompute float8 scales dynamically for FSDP",
+        )
+        self.parser.add_argument(
+            "--float8.scaling_type_input",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for input, dynamic (default) or delayed",
+            choices=["dynamic", "delayed"],
+        )
+        self.parser.add_argument(
+            "--float8.scaling_type_weight",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for weight, dynamic (default) or delayed",
+        )
+        self.parser.add_argument(
+            "--float8.scaling_type_grad_output",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for grad_output, dynamic (default) or delayed",
+        )
+
         # communications library settings
         self.parser.add_argument(
             "--comm.init_timeout_seconds",
```

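The only behavioral change here is that the float8 flags move from the `--training.*` namespace to a new `--float8.*` section, which is why call sites now read `job_config.float8.*`. Below is a minimal, hypothetical sketch of the dotted-flag-to-section convention these arguments rely on; it is not torchtitan's actual config plumbing, just an illustration of how such flags can surface as `job_config.float8.enable_float8_linear`.

```python
import argparse
from collections import defaultdict
from types import SimpleNamespace

parser = argparse.ArgumentParser()
parser.add_argument("--float8.enable_float8_linear", action="store_true")
parser.add_argument("--float8.scaling_type_input", type=str, default="dynamic")
parser.add_argument("--training.compile", action="store_true")

args = parser.parse_args(["--float8.enable_float8_linear"])

# Group "section.key" dests into per-section namespaces.
sections = defaultdict(dict)
for dotted_name, value in vars(args).items():
    section, key = dotted_name.split(".", maxsplit=1)
    sections[section][key] = value

job_config = SimpleNamespace(**{s: SimpleNamespace(**kv) for s, kv in sections.items()})
print(job_config.float8.enable_float8_linear)  # True
print(job_config.float8.scaling_type_input)    # "dynamic"
```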
torchtitan/float8_linear.py: 87 additions & 86 deletions
```diff
@@ -12,127 +12,128 @@

 # Note: Performance
 # Float8 experimental is intended to be run under `torch.compile` for competitive performance
-import functools
-from typing import Optional

 import torch
 import torch.nn as nn
-from torch._logging import warning_once

 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
+from torchtitan.parallelisms import ParallelDims


-@functools.lru_cache(None)
 def is_sm90_or_later():
     # Float8 is only supported on H100+ GPUs
     return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


-def maybe_build_fp8_linear(
-    model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
-):
-    """
-    This function converts the linear layers to `Float8Linear`. Note that today,
-    only dynamic tensor scaling (the default) is supported.
-
-    This will mutate the model inplace.
-    """
-    enable_float8_linear = job_config.training.enable_float8_linear
-    if not enable_float8_linear:
-        return
-    if not is_sm90_or_later():
-        warning_once(
-            logger,
-            "Failed to swap to Float8Linear because SM90 or later is not available",
-        )
-        return
-    try:
-        from torchao.float8 import (
-            CastConfig,
-            convert_to_float8_training,
-            Float8LinearConfig,
-            ScalingType,
-        )
+class Float8Handler:
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        self.enabled = False
+
+        float8_config = job_config.float8
+        if not float8_config.enable_float8_linear:
+            return
+        if not is_sm90_or_later():
+            logger.warning(
+                "Failed to swap to Float8Linear because SM90 or later is not available",
+            )
+            return
+        try:
+            from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
+        except ImportError as e:
+            raise ImportError(
+                "torchao is not installed. Please install it to use fp8 linear layers."
+            ) from e

         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
-            job_config.training.enable_fsdp_float8_all_gather and dp_enabled
-        )
-        scaling_type_input = ScalingType(job_config.training.float8_scaling_type_input)
-        scaling_type_weight = ScalingType(
-            job_config.training.float8_scaling_type_weight
+            parallel_dims.dp_enabled
+            and parallel_dims.dp_type == "fsdp"
+            and float8_config.enable_fsdp_float8_all_gather
         )
-        scaling_type_grad_output = ScalingType(
-            job_config.training.float8_scaling_type_grad_output
-        )
-        float8_config = Float8LinearConfig(
+        scaling_type_input = ScalingType(float8_config.scaling_type_input)
+        scaling_type_weight = ScalingType(float8_config.scaling_type_weight)
+        scaling_type_grad_output = ScalingType(float8_config.scaling_type_grad_output)
+        self.config = Float8LinearConfig(
             enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             cast_config_input=CastConfig(scaling_type=scaling_type_input),
             cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
             cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
             enable_pre_and_post_forward=False,
         )
+
+        self.enabled = True
+
+        # for precompute_fp8_dynamic_scale_for_fsdp
+        self.precompute_scale = (
+            enable_fsdp_float8_all_gather
+            and float8_config.precompute_float8_dynamic_scale_for_fsdp
+        )
+
+        # for sync_float8_amax_and_scale_history
+        self.delayed_scaling = (
+            scaling_type_input == "delayed"
+            or scaling_type_weight == "delayed"
+            or scaling_type_grad_output == "delayed"
+        )
+        self._sync_float8_amax_and_scale_history = None
+        self.compile = job_config.training.compile
+
+        logger.info("Float8 training active")
+
+    def convert_to_float8_training(self, model: nn.Module):
+        """
+        This function converts the linear layers of `model` to `Float8Linear`.
+        Note that today, only dynamic tensor scaling (the default) is supported.
+        This will mutate the model inplace.
+        """
+        if not self.enabled:
+            return
+
+        from torchao.float8 import convert_to_float8_training
+
+        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
         convert_to_float8_training(
             model,
-            config=float8_config,
+            config=self.config,
             module_filter_fn=lambda mod, fqn: fqn != "output",
         )
         logger.info(
-            f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
+            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
+            f"{self.config.enable_fsdp_float8_all_gather}"
         )
-    except ImportError as exc:
-        raise ImportError(
-            "torchao is not installed. Please install it to use fp8 linear layers."
-        ) from exc
-
-
-def maybe_precompute_fp8_dynamic_scale_for_fsdp(
-    model: nn.Module, job_config: JobConfig
-):
-    if not (
-        job_config.training.enable_float8_linear
-        and job_config.training.enable_fsdp_float8_all_gather
-        and job_config.training.precompute_float8_dynamic_scale_for_fsdp
-    ):
-        return
-    if not is_sm90_or_later():
-        warning_once(
-            logger,
-            "Skipped precomputing fp8 scales because SM90 or later is not available",
-        )
-        return
-    from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

-    precompute_float8_dynamic_scale_for_fsdp(model)
+    def precompute_fp8_dynamic_scale_for_fsdp(self, model: nn.Module):
+        if not self.enabled:
+            return

+        if not self.precompute_scale:
+            return

-_sync_float8_amax_and_scale_history = None
+        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

+        precompute_float8_dynamic_scale_for_fsdp(model)

-def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobConfig):
-    if not (
-        job_config.training.enable_float8_linear
-        and (
-            job_config.training.float8_scaling_type_input == "delayed"
-            or job_config.training.float8_scaling_type_weight == "delayed"
-            or job_config.training.float8_scaling_type_grad_output == "delayed"
-        )
-    ):
-        return
+    def sync_float8_amax_and_scale_history(self, model: nn.Module):
+        if not self.enabled:
+            return

-    from torchao.float8 import sync_float8_amax_and_scale_history
+        if not self.delayed_scaling:
+            return

-    # TODO(future): see if precalculating the modules to sync over is going to
-    # meaningfully help performance
+        from torchao.float8 import sync_float8_amax_and_scale_history

-    global _sync_float8_amax_and_scale_history
-    if _sync_float8_amax_and_scale_history is None:
-        if job_config.training.compile:
-            _sync_float8_amax_and_scale_history = torch.compile(
-                sync_float8_amax_and_scale_history
-            )
-        else:
-            _sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history
+        # TODO(vkuzo): see if precalculating the modules to sync over is going to
+        # meaningfully help performance
+
+        if self._sync_float8_amax_and_scale_history is None:
+            if self.compile:
+                self._sync_float8_amax_and_scale_history = torch.compile(
+                    sync_float8_amax_and_scale_history
+                )
+            else:
+                self._sync_float8_amax_and_scale_history = (
+                    sync_float8_amax_and_scale_history
+                )

-    sync_float8_amax_and_scale_history(model)
+        self._sync_float8_amax_and_scale_history(model)
```
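One detail carried over from the old function to the new method: both pass `module_filter_fn=lambda mod, fqn: fqn != "output"` to torchao's `convert_to_float8_training`, which keeps the final output projection in high precision while every other linear layer is swapped. The sketch below illustrates what that filter decides, using a made-up toy module tree rather than torchtitan's actual model.

```python
import torch.nn as nn

# Toy module tree; the top-level "output" name mirrors the FQN the filter excludes.
toy = nn.ModuleDict(
    {
        "attention": nn.ModuleDict({"wq": nn.Linear(16, 16), "wo": nn.Linear(16, 16)}),
        "output": nn.Linear(16, 32),
    }
)


def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # Same predicate the diff passes to convert_to_float8_training.
    return fqn != "output"


for fqn, mod in toy.named_modules():
    if isinstance(mod, nn.Linear):
        action = "swap to Float8Linear" if module_filter_fn(mod, fqn) else "keep as nn.Linear"
        print(f"{fqn}: {action}")
# attention.wq: swap to Float8Linear
# attention.wo: swap to Float8Linear
# output: keep as nn.Linear
```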

torchtitan/parallelisms/parallelize_llama.py: 1 addition & 1 deletion
```diff
@@ -541,7 +541,7 @@ def parallelize_llama(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
-           enable_float8=job_config.training.enable_float8_linear,
+           enable_float8=job_config.float8.enable_float8_linear,
            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
        )
```
