 
 # Note: Performance
 # Float8 experimental is intended to be run under `torch.compile` for competitive performance
-import functools
-from typing import Optional
 
 import torch
 import torch.nn as nn
-from torch._logging import warning_once
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
+from torchtitan.parallelisms import ParallelDims
 
 
-@functools.lru_cache(None)
 def is_sm90_or_later():
     # Float8 is only supported on H100+ GPUs
     return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 
-def maybe_build_fp8_linear(
-    model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
-):
-    """
-    This function converts the linear layers to `Float8Linear`. Note that today,
-    only dynamic tensor scaling (the default) is supported.
-
-    This will mutate the model inplace.
-    """
-    enable_float8_linear = job_config.training.enable_float8_linear
-    if not enable_float8_linear:
-        return
-    if not is_sm90_or_later():
-        warning_once(
-            logger,
-            "Failed to swap to Float8Linear because SM90 or later is not available",
-        )
-        return
-    try:
-        from torchao.float8 import (
-            CastConfig,
-            convert_to_float8_training,
-            Float8LinearConfig,
-            ScalingType,
-        )
+class Float8Handler:
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        self.enabled = False
+
+        float8_config = job_config.float8
+        if not float8_config.enable_float8_linear:
+            return
+        if not is_sm90_or_later():
+            logger.warning(
+                "Failed to swap to Float8Linear because SM90 or later is not available",
+            )
+            return
+        try:
+            from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
+        except ImportError as e:
+            raise ImportError(
+                "torchao is not installed. Please install it to use fp8 linear layers."
+            ) from e
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
-            job_config.training.enable_fsdp_float8_all_gather and dp_enabled
-        )
-        scaling_type_input = ScalingType(job_config.training.float8_scaling_type_input)
-        scaling_type_weight = ScalingType(
-            job_config.training.float8_scaling_type_weight
+            parallel_dims.dp_enabled
+            and parallel_dims.dp_type == "fsdp"
+            and float8_config.enable_fsdp_float8_all_gather
         )
-        scaling_type_grad_output = ScalingType(
-            job_config.training.float8_scaling_type_grad_output
-        )
-        float8_config = Float8LinearConfig(
+        scaling_type_input = ScalingType(float8_config.scaling_type_input)
+        scaling_type_weight = ScalingType(float8_config.scaling_type_weight)
+        scaling_type_grad_output = ScalingType(float8_config.scaling_type_grad_output)
+        self.config = Float8LinearConfig(
             enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             cast_config_input=CastConfig(scaling_type=scaling_type_input),
             cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
             cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
             enable_pre_and_post_forward=False,
         )
+
+        self.enabled = True
+
+        # for precompute_fp8_dynamic_scale_for_fsdp
+        self.precompute_scale = (
+            enable_fsdp_float8_all_gather
+            and float8_config.precompute_float8_dynamic_scale_for_fsdp
+        )
+
+        # for sync_float8_amax_and_scale_history
+        self.delayed_scaling = (
+            scaling_type_input is ScalingType.DELAYED
+            or scaling_type_weight is ScalingType.DELAYED
+            or scaling_type_grad_output is ScalingType.DELAYED
+        )
+        self._sync_float8_amax_and_scale_history = None
+        self.compile = job_config.training.compile
+
+        logger.info("Float8 training active")
+
+    def convert_to_float8_training(self, model: nn.Module):
+        """
+        This function converts the linear layers of `model` to `Float8Linear`.
+        Note that today, only dynamic tensor scaling (the default) is supported.
+        This will mutate the model inplace.
+        """
+        if not self.enabled:
+            return
+
+        from torchao.float8 import convert_to_float8_training
+
+        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
         convert_to_float8_training(
             model,
-            config=float8_config,
+            config=self.config,
             module_filter_fn=lambda mod, fqn: fqn != "output",
         )
         logger.info(
-            f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
+            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
+            f"{self.config.enable_fsdp_float8_all_gather}"
         )
-    except ImportError as exc:
-        raise ImportError(
-            "torchao is not installed. Please install it to use fp8 linear layers."
-        ) from exc
-
-
-def maybe_precompute_fp8_dynamic_scale_for_fsdp(
-    model: nn.Module, job_config: JobConfig
-):
-    if not (
-        job_config.training.enable_float8_linear
-        and job_config.training.enable_fsdp_float8_all_gather
-        and job_config.training.precompute_float8_dynamic_scale_for_fsdp
-    ):
-        return
-    if not is_sm90_or_later():
-        warning_once(
-            logger,
-            "Skipped precomputing fp8 scales because SM90 or later is not available",
-        )
-        return
-    from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 
-    precompute_float8_dynamic_scale_for_fsdp(model)
+    def precompute_fp8_dynamic_scale_for_fsdp(self, model: nn.Module):
+        if not self.enabled:
+            return
 
+        if not self.precompute_scale:
+            return
 
-_sync_float8_amax_and_scale_history = None
+        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 
+        precompute_float8_dynamic_scale_for_fsdp(model)
 
-def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobConfig):
-    if not (
-        job_config.training.enable_float8_linear
-        and (
-            job_config.training.float8_scaling_type_input == "delayed"
-            or job_config.training.float8_scaling_type_weight == "delayed"
-            or job_config.training.float8_scaling_type_grad_output == "delayed"
-        )
-    ):
-        return
+    def sync_float8_amax_and_scale_history(self, model: nn.Module):
+        if not self.enabled:
+            return
 
-    from torchao.float8 import sync_float8_amax_and_scale_history
+        if not self.delayed_scaling:
+            return
 
-    # TODO(future): see if precalculating the modules to sync over is going to
-    # meaningfully help performance
+        from torchao.float8 import sync_float8_amax_and_scale_history
 
-    global _sync_float8_amax_and_scale_history
-    if _sync_float8_amax_and_scale_history is None:
-        if job_config.training.compile:
-            _sync_float8_amax_and_scale_history = torch.compile(
-                sync_float8_amax_and_scale_history
-            )
-        else:
-            _sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history
+        # TODO(vkuzo): see if precalculating the modules to sync over is going to
+        # meaningfully help performance
+
+        if self._sync_float8_amax_and_scale_history is None:
+            if self.compile:
+                self._sync_float8_amax_and_scale_history = torch.compile(
+                    sync_float8_amax_and_scale_history
+                )
+            else:
+                self._sync_float8_amax_and_scale_history = (
+                    sync_float8_amax_and_scale_history
+                )
 
-    sync_float8_amax_and_scale_history(model)
+        self._sync_float8_amax_and_scale_history(model)
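
For context, a minimal sketch of how the new Float8Handler API could be wired into a training loop, following the usual float8 recipe (convert before FSDP wrapping / torch.compile, sync amax history before the optimizer step, precompute dynamic scales after it). This is illustrative and not part of the diff: the function name train_with_float8 and the model, optimizer, data_loader, and loss_fn arguments are hypothetical stand-ins for whatever the surrounding trainer provides.

import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.parallelisms import ParallelDims


def train_with_float8(
    job_config: JobConfig,
    parallel_dims: ParallelDims,
    model: nn.Module,
    optimizer,
    data_loader,
    loss_fn,
):
    # Hypothetical trainer wiring; Float8Handler is the class added in this diff.
    float8_handler = Float8Handler(job_config, parallel_dims)

    # Swap nn.Linear -> Float8Linear before FSDP wrapping / torch.compile so the
    # converted modules are what get sharded and compiled.
    float8_handler.convert_to_float8_training(model)

    for inputs, labels in data_loader:
        loss = loss_fn(model(inputs), labels)
        loss.backward()

        # If any delayed scaling is configured, sync amax/scale history before
        # the optimizer step (a no-op otherwise).
        float8_handler.sync_float8_amax_and_scale_history(model)

        optimizer.step()
        optimizer.zero_grad()

        # With FSDP float8 all-gather enabled, precompute dynamic scales for all
        # parameters after the optimizer step (a no-op otherwise).
        float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)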