This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit e4b126a
[wip] support float8 weight caching for gradient accumulation/PP
Summary:

In cases where the optimizer update does not happen after every forward, such as microbatching/PP, we can save the casted weight to trade some time for memory. For now I'm just testing out performance and accuracy; we can improve on the API in future PRs.

In terms of accuracy, there should be no change; I will validate this further if we want to land this.

For performance, on @drisspg's LLaMa 7B pretrain script, with bsz == 128 and micro_bsz == 1:

1. baseline bf16 + compile: 2.38 it/s
2. delayed scaling + compile: 2.80 it/s (1.18x over baseline)
3. delayed scaling + compile + this PR: 3.04 it/s (1.28x over baseline)

Test Plan:

```
pytest test/test_base.py -s -k test_weight_caching
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent: d0c6760

File tree

5 files changed: +107 −15 lines

float8_experimental/config.py

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+# If True, allocates buffers for float8 weight cache
+allocate_float8_weight_cache_buffers = False
+
+# A global flag for controlling the weight cache, off by default. Intended
+# usage is for users to modify this from their training loop directly
+# according to their microbatching/pipeline parallel setup.
+# Note: this is currently a global flag for simplicity and dynamo performance.
+weight_cache_enabled = False
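The intended-usage comment above is prose only; here is a minimal sketch of how a training loop with gradient accumulation might drive these two flags. The model, optimizer, sizes, and loop structure are hypothetical; only `Float8Linear.from_float`, `sync_float8_amax_and_scale_history`, and the two config flags come from this codebase, and the placement of the amax/scale sync mirrors the new test below rather than a settled recipe.

```python
import torch
import torch.nn as nn

import float8_experimental.config as config
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

# must be set before constructing Float8Linear so the cache buffer is registered
config.allocate_float8_weight_cache_buffers = True

model = Float8Linear.from_float(
    nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16), emulate=False
)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
num_microbatches = 8  # hypothetical gradient accumulation factor

for step in range(10):
    sync_float8_amax_and_scale_history(model)
    for i in range(num_microbatches):
        # recompute and cache the fp8 weight cast on the first microbatch,
        # then reuse the cached cast while the weights are unchanged
        config.weight_cache_enabled = i > 0
        x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
        model(x).sum().backward()
    optimizer.step()
    optimizer.zero_grad()
```

The key constraint is ordering: `allocate_float8_weight_cache_buffers` has to be set before the module is converted (the buffer is registered in `__init__`), while `weight_cache_enabled` is toggled per microbatch.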

float8_experimental/float8_linear.py

Lines changed: 34 additions & 2 deletions

@@ -22,7 +22,10 @@
     _maybe_initialize_amaxes_scales_for_float8_cast,
 )

-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    calculate_amax_and_cast_to_float8,
+)

 from float8_experimental.float8_utils import (
     E4M3_MAX_POS,
@@ -31,6 +34,8 @@
     to_fp8_saturated,
 )

+import float8_experimental.config as config
+

 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     """
@@ -148,6 +153,15 @@ def __init__(self, *args, **kwargs):
         # will access the scale when it has ensured that it is on GPU.
         self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)

+        if config.allocate_float8_weight_cache_buffers:
+            # this is a buffer to get `to(dtype)` for free
+            # TODO(future): hide this from serialization
+            # TODO(future): force this to stay in float8_e4m3fn
+            self.register_buffer(
+                'cached_fp8_weight',
+                torch.empty(self.weight.shape, dtype=torch.float8_e4m3fn),
+            )
+
     def register_always_float32_buffer(
         self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
     ) -> None:
@@ -204,8 +218,26 @@ def cast_w_to_float8(
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
+
+        if config.weight_cache_enabled:
+            assert config.allocate_float8_weight_cache_buffers
+            w_bits_fp8 = self.cached_fp8_weight
+        else:
+            # manual calculation of fp8 bits:
+            # 1. calculate the bits without Float8Tensor, without grad
+            # 2. store the bits here
+            # 3. create Float8Tensor from the bits calculated in 2
+            # motivation: this will take care of saving the bits without
+            # interacting with tensor subclasses, as w_fp8._data is not
+            # currently traceable by dynamo
+            w_bits_fp8 = calculate_amax_and_cast_to_float8(
+                self.weight, self.fp8_scale_w, torch.float8_e4m3fn,
+                self.fp8_amax_w)
+            if config.allocate_float8_weight_cache_buffers:
+                self.cached_fp8_weight.copy_(w_bits_fp8)
         w_fp8 = Float8Tensor.to_float8(
-            w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w, self.emulate
+            w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w,
+            self.emulate, cached_casted_weight=w_bits_fp8,
         )
         return w_fp8
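The two TODOs in the buffer registration above (hide the cache from serialization, pin its dtype) follow from how `nn.Module` buffers behave; a quick standalone illustration with a made-up `Demo` module, using only stock PyTorch:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        # registering as a buffer is what gives `to(dtype)`/`to(device)` for free
        self.register_buffer("cached_weight", torch.empty(4, 4, dtype=torch.bfloat16))

m = Demo().to(torch.float16)
# floating-point buffers follow module-wide dtype conversions, which is why the
# real code wants to force cached_fp8_weight to stay in float8_e4m3fn
print(m.cached_weight.dtype)              # torch.float16
# buffers land in state_dict by default; register_buffer(..., persistent=False)
# would hide them, hence the serialization TODO
print("cached_weight" in m.state_dict())  # True
```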

float8_experimental/float8_linear_utils.py

Lines changed: 3 additions & 0 deletions

@@ -149,6 +149,9 @@ def sync_float8_amax_and_scale_history(
         if not isinstance(child, fp8_classes):
             continue

+        # TODO(future): enable skipping weight related syncing if weight cache
+        # is on
+
         #
         # 1. in distributed contexts, syncs amax values across workers
         #

float8_experimental/float8_tensor.py

Lines changed: 22 additions & 13 deletions

@@ -11,6 +11,15 @@

 aten = torch.ops.aten

+@torch.no_grad()
+def calculate_amax_and_cast_to_float8(tensor, scale, float8_dtype, amax_buffer):
+    if amax_buffer is not None:
+        amax_buffer.fill_(tensor_to_amax(tensor))
+
+    tensor_scaled = tensor * scale
+    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+    return bits_fp8
+

 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
@@ -25,24 +34,21 @@ def forward(
         float8_dtype=torch.float8_e4m3fn,
         amax_buffer=None,
         emulate: bool = False,
+        cached_casted_weight=None,
     ):
-        # In TransformerEngine, the casts to float8 are fused with calculating
-        # the new amax value. In this codebase, the eager mode code for those
-        # two things is colocated in this function. We expect PT2.0 to fuse it
-        # for us.
-        if amax_buffer is not None:
-            amax_buffer.fill_(tensor_to_amax(tensor))
-
-        tensor_scaled = tensor * scale
-        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+        if cached_casted_weight is not None:
+            return Float8Tensor(cached_casted_weight, scale, tensor.dtype,
+                                emulate=emulate)
+        bits_fp8 = calculate_amax_and_cast_to_float8(
+            tensor, scale, float8_dtype, amax_buffer)
         return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)

     @staticmethod
     def backward(ctx, g):
         if isinstance(g, Float8Tensor):
-            return g.to_original_precision(), None, None, None, None
+            return g.to_original_precision(), None, None, None, None, None
         else:
-            return g, None, None, None, None
+            return g, None, None, None, None, None


 class FromFloat8ConstrFunc(torch.autograd.Function):
@@ -123,6 +129,9 @@ def __tensor_flatten__(self):

     @staticmethod
     def __tensor_unflatten__(inner_tensors: Dict, metadata):
+        # TODO(TBD): this seems unused, and it's out of date after
+        # the new args in https://github.com/pytorch/pytorch/pull/114311
+        # we should just delete it
         assert len(inner_tensors) == 2
         return Float8Tensor(
             inner_tensors["_data"],
@@ -136,7 +145,7 @@ def to_original_precision(self):

     @staticmethod
     @torch._dynamo.allow_in_graph
-    def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
+    def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False, cached_casted_weight=None):
         """Converts a higher precision tensor to float8 in a differentiable way.

         Args:
@@ -149,7 +158,7 @@ def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = Fal
             Float8Tensor: a float8 tensor
         """
         return ToFloat8ConstrFunc.apply(
-            tensor, scale, float8_dtype, amax_buffer, emulate
+            tensor, scale, float8_dtype, amax_buffer, emulate, cached_casted_weight,
         )

     @classmethod
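One note on the `backward` change above: `torch.autograd.Function.backward` must return exactly one value per `forward` input, with `None` for inputs that don't need gradients, so adding `cached_casted_weight` to `forward` requires one more trailing `None`. A tiny self-contained example of that rule (the `ScaleByConst` class is made up for illustration):

```python
import torch

class ScaleByConst(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, some_flag):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # three forward inputs -> three return values; the non-tensor
        # inputs (scale, some_flag) get None instead of a gradient
        return grad_out * ctx.scale, None, None

x = torch.randn(4, requires_grad=True)
ScaleByConst.apply(x, 2.0, True).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))
```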

test/test_base.py

Lines changed: 34 additions & 0 deletions

@@ -14,6 +14,7 @@

 import torch
 import torch.nn as nn
+import float8_experimental.float8_linear as float8_linear
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     get_float8_linear,
@@ -32,6 +33,8 @@
     tensor_to_scale,
 )

+import float8_experimental.config as config
+
 random.seed(0)
 torch.manual_seed(0)

@@ -231,6 +234,37 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
             y.dtype == torch.bfloat16
         ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"

+    @pytest.mark.parametrize("use_compile", [False, True])
+    def test_weight_caching(self, use_compile):
+        M, K, N = 16, 32, 64
+        dtype = torch.bfloat16
+        config.allocate_float8_weight_cache_buffers = True
+
+        x = torch.randn(M, K, device="cuda", dtype=dtype)
+        m_ref = nn.Linear(K, N, bias=True, device="cuda", dtype=dtype)
+        m = Float8Linear.from_float(copy.deepcopy(m_ref), emulate=False)
+
+        if use_compile:
+            m = torch.compile(m)
+
+        config.weight_cache_enabled = False
+
+        y1 = m(x)
+        y1.sum().backward()
+        grad1 = m.weight.grad.clone().detach()
+
+        config.weight_cache_enabled = True
+        sync_float8_amax_and_scale_history(m)
+
+        y2 = m(x)
+        y2.sum().backward()
+        grad2 = m.weight.grad.clone().detach()
+
+        torch.testing.assert_close(grad2, grad1 * 2)
+
+        config.allocate_float8_weight_cache_buffers = False
+

 class TestScaledMM:
     @unittest.skipIf(
