This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 73eadae

Thread through the scaling type argument to float8 constructors

ghstack-source-id: b09361e
Pull Request resolved: #301

1 parent 36405a7 · commit 73eadae

File tree

10 files changed: +275 −81 lines

.pre-commit-config.yaml

Lines changed: 0 additions & 2 deletions
@@ -10,8 +10,6 @@ repos:
       - id: trailing-whitespace
       - id: check-ast
       - id: check-merge-conflict
-      - id: no-commit-to-branch
-        args: ['--branch=main']
       - id: check-added-large-files
         args: ['--maxkb=500']
       - id: end-of-file-fixer

float8_experimental/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -5,11 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    ScaledMMConfig,
+    ScalingStrategy,
+)

 # Needed to load Float8Tensor with weights_only = True
 from torch.serialization import add_safe_globals

-add_safe_globals([Float8Tensor, ScaledMMConfig])
+add_safe_globals([Float8Tensor, ScaledMMConfig, ScalingStrategy])

 __all__ = ["Float8Tensor", "Float8Linear"]

float8_experimental/float8_dynamic_linear.py

Lines changed: 75 additions & 23 deletions
@@ -19,6 +19,7 @@
     Float8Tensor,
     merge_mm_configs,
     ScaledMMConfig,
+    ScalingStrategy,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
@@ -36,21 +37,27 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
-        tensor,
+        tensor: torch.Tensor,
         mm_config: ScaledMMConfig,
+        scaling_strategy: ScalingStrategy,
     ):
         ctx.mm_config = mm_config
+        ctx.scaling_strategy = scaling_strategy
         return tensor

     @staticmethod
-    def backward(ctx, gradY):
+    def backward(ctx, gradY: torch.Tensor):
         if tensor_already_casted_to_fp8(gradY):
-            return gradY, None
+            return gradY, None, None
         gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
+            gradY,
+            gradY_scale,
+            e5m2_dtype,
+            mm_config=ctx.mm_config,
+            scaling_strategy=ctx.scaling_strategy,
         )
-        return fp8_tensor, None
+        return fp8_tensor, None, None


 class Float8DynamicLinear(torch.nn.Linear):
@@ -63,13 +70,15 @@ def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config, self.scaling_strategy)
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
-            w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config)
+            w_fp8 = cast_to_float8_e4m3fn(
+                self.weight, self.forward_config, self.scaling_strategy
+            )
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
-        y = cast_to_float8_e5m2_bw(y, self.backward_config)
+        y = cast_to_float8_e5m2_bw(y, self.backward_config, self.scaling_strategy)
         return y

     @classmethod
@@ -101,9 +110,14 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             fp8_output=False,
             pad_inner_dim=config.pad_inner_dim,
         )
+        # TODO: For now hardcode TensorWise scaling
+        new_mod.scaling_strategy = ScalingStrategy.TensorWise
+
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
-                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
+                WeightWithDynamicFloat8CastTensor(
+                    mod.weight, new_mod.forward_config, new_mod.scaling_strategy
+                )
             )
         else:
             new_mod.weight = mod.weight
@@ -112,18 +126,27 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":


 def cast_to_float8_e4m3fn(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    scaling_strategy: ScalingStrategy,
+    reduce_amax: bool = False,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
     scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
+    return Float8Tensor.to_float8(
+        inpt_tensor,
+        scale,
+        e4m3_dtype,
+        mm_config=mm_config,
+        scaling_strategy=scaling_strategy,
+    )


 def cast_to_float8_e5m2_bw(
-    gradY: torch.Tensor, mm_config: ScaledMMConfig
+    gradY: torch.Tensor, mm_config: ScaledMMConfig, scaling_strategy: ScalingStrategy
 ) -> torch.Tensor:
-    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
+    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_strategy)


 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -143,7 +166,12 @@ def cast_to_float8_e5m2_bw(

 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        scaling_strategy: ScalingStrategy,
+    ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -157,24 +185,38 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             requires_grad=tensor.requires_grad,
         )

-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        scaling_strategy: ScalingStrategy,
+    ):
         self._tensor = tensor
         self._mm_config = mm_config
+        self._scaling_strategy = scaling_strategy

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         if func == torch.ops.aten.detach.default:
             return WeightWithDynamicFloat8CastTensor(
-                args[0]._tensor, args[0]._mm_config
+                args[0]._tensor, args[0]._mm_config, args[0]._scaling_strategy
             )
         mm_config: Optional[ScaledMMConfig] = None
+        scaling_strategy: Optional[ScalingStrategy] = None

         def unwrap(t):
             nonlocal mm_config
+            nonlocal scaling_strategy
             if mm_config is None:
                 mm_config = t._mm_config
             else:
                 mm_config = merge_mm_configs(mm_config, t._mm_config)
+
+            if scaling_strategy is None:
+                scaling_strategy = t._scaling_strategy
+            else:
+                # TODO For now we assume that the scaling strategy is same across all tensors
+                assert scaling_strategy == t._scaling_strategy
             return t._tensor

         args, kwargs = pytree.tree_map_only(
@@ -184,23 +226,31 @@ def unwrap(t):
         if func not in _ops_to_preserve_subclass:
             return out
         return pytree.tree_map_only(
-            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+            torch.Tensor,
+            lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config, scaling_strategy),
+            out,
         )

     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        return ["_tensor"], {
+            "_mm_config": self._mm_config,
+            "_scaling_strategy": self._scaling_strategy,
+        }

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+        mm_config = flatten_spec["_mm_config"]
+        scaling_strategy = flatten_spec["_scaling_strategy"]
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"], mm_config, scaling_strategy
+        )

     def __repr__(self):
-        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
+        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config}, scaling_strategy={self._scaling_strategy})"

     def fsdp_pre_all_gather(self, mesh):
         float8_tensor = cast_to_float8_e4m3fn(
-            self._tensor, self._mm_config, reduce_amax=True
+            self._tensor, self._mm_config, self._scaling_strategy, reduce_amax=True
         )
         return (float8_tensor._data,), (float8_tensor._scale,)

@@ -218,4 +268,6 @@ def fsdp_post_all_gather(
             assert isinstance(out, Float8Tensor), f"{type(out)}"
             out._scale = scale
             return
-        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
+        return Float8Tensor(
+            data, scale, param_dtype, self._mm_config, self._scaling_strategy
+        ), (data,)
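To make the new signatures concrete, here is a small usage sketch of the updated dynamic-cast helpers. It is illustrative only: it assumes the package is importable and that a CUDA device with float8 support is available, and the tensor shape is arbitrary.

import torch

from float8_experimental.float8_dynamic_linear import (
    cast_to_float8_e4m3fn,
    cast_to_float8_e5m2_bw,
)
from float8_experimental.float8_tensor import ScaledMMConfig, ScalingStrategy

x = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
mm_config = ScaledMMConfig()  # default config, as in Float8Linear.__init__

# The scaling strategy is now an explicit, required positional argument.
x_fp8 = cast_to_float8_e4m3fn(x, mm_config, ScalingStrategy.TensorWise)

# The e5m2 backward cast threads the same argument into NoopFwToFloat8E5M2Bw,
# which is why its backward now returns one extra None alongside the gradient.
y = cast_to_float8_e5m2_bw(x, mm_config, ScalingStrategy.TensorWise)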

float8_experimental/float8_linear.py

Lines changed: 15 additions & 2 deletions
@@ -18,6 +18,7 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     ScaledMMConfig,
+    ScalingStrategy,
     to_fp8_no_autograd,
 )

@@ -75,11 +76,13 @@ def forward(
         scale_fn_name,
         is_amax_initialized,
         mm_config: ScaledMMConfig,
+        scaling_strategy: ScalingStrategy,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
         ctx.mm_config = mm_config
+        ctx.scaling_strategy = scaling_strategy
         return tensor

     @staticmethod
@@ -102,9 +105,13 @@ def backward(ctx, go):
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))

         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            mm_config=ctx.mm_config,
+            scaling_strategy=ctx.scaling_strategy,
         )
-        empty_grads = None, None, None, None, None, None
+        empty_grads = None, None, None, None, None, None, None
         return res, *empty_grads


@@ -150,6 +157,9 @@ def __init__(self, *args, **kwargs):
         self.forward_config = ScaledMMConfig()
         self.backward_config = ScaledMMConfig()

+        # Defines the scaling strategy for the forward and backwards pass
+        self.scaling_strategy = ScalingStrategy.TensorWise
+
         # Note: is_amax_initialized is not a buffer to avoid data dependent
         # control flow visible to dynamo
         # TODO(future PR): add serialization for this flag
@@ -288,6 +298,7 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
             scale_fn_name,
             self.is_amax_initialized,
             self.backward_config,
+            self.scaling_strategy,
         )
         return y

@@ -353,4 +364,6 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.backward_config = ScaledMMConfig(
             emulate, False, False, config.pad_inner_dim
         )
+        # TODO: For now hardcode TensorWise scaling
+        new_mod.scaling_strategy = ScalingStrategy.TensorWise
         return new_mod
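A rough end-to-end sketch of how the hardcoded strategy surfaces after conversion; emulate=True is assumed here so that no float8-capable scaled_mm kernel is needed, and the layer size is arbitrary.

import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_tensor import ScalingStrategy

mod = nn.Linear(64, 64)
fp8_mod = Float8Linear.from_float(mod, emulate=True)

# Both __init__ and from_float currently pin the strategy (see the TODOs above),
# so every converted module reports TensorWise until this becomes configurable.
assert fp8_mod.scaling_strategy == ScalingStrategy.TensorWise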

0 commit comments