diff --git a/README.md b/README.md
index b094c4f0..838aa8e5 100644
--- a/README.md
+++ b/README.md
@@ -64,6 +64,28 @@ model.foo.bar.fc2.sequence_parallel = True
 # the rest of the flow is the same as the single GPU flow
 ```
 
+## weight caching (very experimental)
+
+```python
+import float8_experimental.config as config
+
+m = Model(...)
+# before converting to `Float8Linear`, turn on weight cache buffer allocation
+config.allocate_float8_weight_cache_buffers = True
+
+# in the training loop, manually control the global weight caching setting
+for idx in range(N_ITER):
+    ...
+    if idx % n_microbatch == 0:
+        # if we are in the first pass of a new microbatch, repopulate the cache
+        config.weight_cache_enabled = False
+    elif idx % n_microbatch == 1:
+        # if we are in the second pass of a new microbatch, use cached weight
+        # this persists until `idx % n_microbatch == 0` again
+        config.weight_cache_enabled = True
+    ...
+```
+
 # high level technical design
 
 ## UX
diff --git a/float8_experimental/config.py b/float8_experimental/config.py
new file mode 100644
index 00000000..0f8b96be
--- /dev/null
+++ b/float8_experimental/config.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+#
+# Weight caching.
+#
+
+# If True, allocates buffers for float8 weight cache
+allocate_float8_weight_cache_buffers = False
+
+# A global flag for controlling the weight cache, off by default. Intended
+# usage is for users to modify this from their training loop directly
+# according to their microbatching/pipeline parallel setup.
+# Note: this is currently a global flag for simplicity and dynamo performance.
+weight_cache_enabled = False
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index cae09649..82a8ccbf 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -16,9 +16,14 @@
 
 from typing import Optional
 
+import float8_experimental.config as config
+
 import torch
 
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import (
+    calculate_amax_and_cast_to_float8,
+    Float8Tensor,
+)
 
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -172,6 +177,15 @@ def __init__(self, *args, **kwargs):
         # will access the scale when it has ensured that it is on GPU.
         self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)
 
+        if config.allocate_float8_weight_cache_buffers:
+            # this is a buffer to get `to(dtype)` for free
+            # TODO(future): hide this from serialization
+            # TODO(future): force this to stay in float8_e4m3fn
+            self.register_buffer(
+                "cached_fp8_weight",
+                torch.empty(self.weight.shape, dtype=torch.float8_e4m3fn),
+            )
+
     def register_always_float32_buffer(
         self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
     ) -> None:
@@ -228,8 +242,33 @@ def cast_w_to_float8(
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
+
+        if config.weight_cache_enabled:
+            assert config.allocate_float8_weight_cache_buffers, (
+                "float8 weight cache buffer must be allocated using "
+                + "`allocate_float8_weight_cache_buffers` to use the weight cache"
+            )
+            w_bits_fp8 = self.cached_fp8_weight
+        else:
+            # manual calculation of fp8 bits:
+            # 1. calculate the bits without Float8Tensor, without grad
+            # 2. store the bits here
+            # 3. create Float8Tensor from the bits calculated in 2
+            # motivation: this will take care of saving the bits without
+            # interacting with tensor subclasses, as w_fp8._data is not
+            # currently traceable by dynamo
+            w_bits_fp8 = calculate_amax_and_cast_to_float8(
+                self.weight, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w
+            )
+            if config.allocate_float8_weight_cache_buffers:
+                self.cached_fp8_weight.copy_(w_bits_fp8)
         w_fp8 = Float8Tensor.to_float8(
-            w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w, self.emulate
+            w,
+            self.fp8_scale_w,
+            torch.float8_e4m3fn,
+            self.fp8_amax_w,
+            self.emulate,
+            cached_casted_weight=w_bits_fp8,
         )
         return w_fp8
 
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index 013f45b1..7514629f 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -156,6 +156,9 @@ def sync_float8_amax_and_scale_history(
     for idx in range(len(fp8_layers)):
         child = fp8_layers[idx]
 
+        # TODO(future): enable skipping weight related syncing if weight cache
+        # is on
+
         #
         # 1. in distributed contexts, syncs amax values across workers
         #
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 66457836..00c5f6aa 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -12,6 +12,16 @@
 aten = torch.ops.aten
 
 
+@torch.no_grad()
+def calculate_amax_and_cast_to_float8(tensor, scale, float8_dtype, amax_buffer):
+    if amax_buffer is not None:
+        amax_buffer.fill_(tensor_to_amax(tensor))
+
+    tensor_scaled = tensor * scale
+    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+    return bits_fp8
+
+
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
     A differentiable conversion to fp8
@@ -25,24 +35,23 @@ def forward(
         float8_dtype=torch.float8_e4m3fn,
         amax_buffer=None,
         emulate: bool = False,
+        cached_casted_weight=None,
     ):
-        # In TransformerEngine, the casts to float8 are fused with calculating
-        # the new amax value. In this codebase, the eager mode code for those
-        # two things is colocated in this function. We expect PT2.0 to fuse it
-        # for us.
-        if amax_buffer is not None:
-            amax_buffer.fill_(tensor_to_amax(tensor))
-
-        tensor_scaled = tensor * scale
-        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+        if cached_casted_weight is not None:
+            return Float8Tensor(
+                cached_casted_weight, scale, tensor.dtype, emulate=emulate
+            )
+        bits_fp8 = calculate_amax_and_cast_to_float8(
+            tensor, scale, float8_dtype, amax_buffer
+        )
         return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)
 
     @staticmethod
     def backward(ctx, g):
         if isinstance(g, Float8Tensor):
-            return g.to_original_precision(), None, None, None, None
+            return g.to_original_precision(), None, None, None, None, None
         else:
-            return g, None, None, None, None
+            return g, None, None, None, None, None
 
 
 class FromFloat8ConstrFunc(torch.autograd.Function):
@@ -122,7 +131,7 @@ def __tensor_flatten__(self):
         return ["_data", "_scale"], ctx
 
     @staticmethod
-    def __tensor_unflatten__(inner_tensors: Dict, metadata):
+    def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
         assert len(inner_tensors) == 2
         return Float8Tensor(
             inner_tensors["_data"],
@@ -136,7 +145,14 @@ def to_original_precision(self):
 
     @staticmethod
     @torch._dynamo.allow_in_graph
-    def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
+    def to_float8(
+        tensor,
+        scale,
+        float8_dtype,
+        amax_buffer=None,
+        emulate: bool = False,
+        cached_casted_weight=None,
+    ):
         """Converts a higher precision tensor to float8 in a differentiable way.
 
         Args:
@@ -149,7 +165,12 @@ def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = Fal
             Float8Tensor: a float8 tensor
         """
         return ToFloat8ConstrFunc.apply(
-            tensor, scale, float8_dtype, amax_buffer, emulate
+            tensor,
+            scale,
+            float8_dtype,
+            amax_buffer,
+            emulate,
+            cached_casted_weight,
         )
 
     @classmethod
diff --git a/test/test_base.py b/test/test_base.py
index 79f6817b..97a085e9 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -10,6 +10,9 @@
 import warnings
 from enum import Enum
 
+import float8_experimental.config as config
+import float8_experimental.float8_linear as float8_linear
+
 import pytest
 
 import torch
@@ -231,6 +234,36 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
             y.dtype == torch.bfloat16
         ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
 
+    @pytest.mark.parametrize("use_compile", [False, True])
+    def test_weight_caching(self, use_compile):
+        M, K, N = 16, 32, 64
+        dtype = torch.bfloat16
+        config.allocate_float8_weight_cache_buffers = True
+
+        x = torch.randn(M, K, device="cuda", dtype=dtype)
+        m_ref = nn.Linear(K, N, bias=True, device="cuda", dtype=dtype)
+        m = Float8Linear.from_float(copy.deepcopy(m_ref), emulate=False)
+
+        if use_compile:
+            m = torch.compile(m)
+
+        config.weight_cache_enabled = False
+
+        y1 = m(x)
+        y1.sum().backward()
+        grad1 = m.weight.grad.clone().detach()
+
+        config.weight_cache_enabled = True
+        sync_float8_amax_and_scale_history(m)
+
+        y2 = m(x)
+        y2.sum().backward()
+        grad2 = m.weight.grad.clone().detach()
+
+        torch.testing.assert_close(grad2, grad1 * 2)
+
+        config.allocate_float8_weight_cache_buffers = False
+
 
 class TestScaledMM:
     @unittest.skipIf(