diff --git a/README.md b/README.md
index 838aa8e5..b094c4f0 100644
--- a/README.md
+++ b/README.md
@@ -64,28 +64,6 @@ model.foo.bar.fc2.sequence_parallel = True
 # the rest of the flow is the same as the single GPU flow
 ```

-## weight caching (very experimental)
-
-```python
-import float8_experimental.config as config
-
-m = Model(...)
-# before converting to `Float8Linear`, turn on weight cache buffer allocation
-config.allocate_float8_weight_cache_buffers = True
-
-# in the training loop, manually control the global weight caching setting
-for idx in N_ITER:
-    ...
-    if idx % n_microbatch == 0:
-        # if we are in the first pass of a new microbatch, repopulate the cache
-        config.weight_cache_enabled = False
-    elif idx % n_microbatch == 1:
-        # if we are in the second pass of a new microbatch, use cached weight
-        # this persists until `idx % n_microbatch == 0` again
-        config.weight_cache_enabled = True
-    ...
-```
-
 # high level technical design

 ## UX
diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index 4543de2c..f0ba914f 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -4,23 +4,6 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-#
-# Weight caching.
-#
-
-# If True, allocates buffers for float8 weight cache
-allocate_float8_weight_cache_buffers = False
-
-# A global flag for controlling the weight cache, off by default. Intended
-# usage is for users to modify this from their training loop directly
-# according to their microbatching/pipeline parallel setup.
-# Note: this is currently a global flag for simplicity and dynamo performance.
-weight_cache_enabled = False
-
-#
-# Other
-#
-
 # If True, on the first iteration of Float8Linear the amaxes will be
 # initialized with the incoming data. As of 2023-12-30, this doesn't work
 # with autocast + torch.compile + FSDP. Enabling this option is nice for
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 642b7dde..a0822153 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -20,10 +20,7 @@

 import torch

-from float8_experimental.float8_tensor import (
-    calculate_amax_and_cast_to_float8,
-    Float8Tensor,
-)
+from float8_experimental.float8_tensor import Float8Tensor

 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -182,15 +179,6 @@ def __init__(self, *args, **kwargs):
         # and torch.compile, this option can disable them
         self.enable_pre_and_post_forward = config.enable_pre_and_post_forward

-        if config.allocate_float8_weight_cache_buffers:
-            # this is a buffer to get `to(dtype)` for free
-            # TODO(future): hide this from serialization
-            # TODO(future): force this to stay in float8_e4m3fn
-            self.register_buffer(
-                "cached_fp8_weight",
-                torch.empty(self.weight.shape, dtype=torch.float8_e4m3fn),
-            )
-
     def register_always_float32_buffer(
         self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
     ) -> None:
@@ -247,32 +235,12 @@ def cast_w_to_float8(
             is_amax_initialized,
         )
-        if config.weight_cache_enabled:
-            assert config.allocate_float8_weight_cache_buffers, (
-                "float8 weight cache buffer must be allocated using "
-                + "`allocate_float8_weight_cache_buffers` to use the weight cache"
-            )
-            w_bits_fp8 = self.cached_fp8_weight
-        else:
-            # manual calculation of fp8 bits:
-            # 1. calculate the bits without Float8Tensor, without grad
-            # 2. store the bits here
-            # 3. create Float8Tensor from the bits calculated in 2
-            # motivation: this will take care of saving the bits without
-            # interacting with tensor subclasses, as w_fp8._data is not
-            # currently traceable by dynamo
-            w_bits_fp8 = calculate_amax_and_cast_to_float8(
-                self.weight, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w
-            )
-            if config.allocate_float8_weight_cache_buffers:
-                self.cached_fp8_weight.copy_(w_bits_fp8)

         w_fp8 = Float8Tensor.to_float8(
             w,
             self.fp8_scale_w,
             torch.float8_e4m3fn,
             self.fp8_amax_w,
             self.emulate,
-            cached_casted_weight=w_bits_fp8,
         )
         return w_fp8

diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index f969c98f..415c3bd0 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -156,8 +156,6 @@ def sync_float8_amax_and_scale_history(
     for idx in range(len(fp8_layers)):
         child = fp8_layers[idx]

-        # TODO(future): enable skipping weight related syncing if weight cache
-        # is on
-
         #
         # 1. in distributed contexts, syncs amax values across workers
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index e050d334..4450fce8 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -12,16 +12,6 @@

 aten = torch.ops.aten

-@torch.no_grad()
-def calculate_amax_and_cast_to_float8(tensor, scale, float8_dtype, amax_buffer):
-    if amax_buffer is not None:
-        amax_buffer.fill_(tensor_to_amax(tensor))
-
-    tensor_scaled = tensor * scale
-    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
-    return bits_fp8
-
-
 @torch._dynamo.allow_in_graph
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
@@ -36,23 +26,20 @@ def forward(
         float8_dtype=torch.float8_e4m3fn,
         amax_buffer=None,
         emulate: bool = False,
-        cached_casted_weight=None,
     ):
-        if cached_casted_weight is not None:
-            return Float8Tensor(
-                cached_casted_weight, scale, tensor.dtype, emulate=emulate
-            )
-        bits_fp8 = calculate_amax_and_cast_to_float8(
-            tensor, scale, float8_dtype, amax_buffer
-        )
+        if amax_buffer is not None:
+            amax_buffer.fill_(tensor_to_amax(tensor))
+
+        tensor_scaled = tensor * scale
+        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
         return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)

     @staticmethod
     def backward(ctx, g):
         if isinstance(g, Float8Tensor):
-            return g.to_original_precision(), None, None, None, None, None
+            return g.to_original_precision(), None, None, None, None
         else:
-            return g, None, None, None, None, None
+            return g, None, None, None, None


 @torch._dynamo.allow_in_graph
@@ -147,14 +134,7 @@ def to_original_precision(self):

     @staticmethod
     @torch._dynamo.allow_in_graph
-    def to_float8(
-        tensor,
-        scale,
-        float8_dtype,
-        amax_buffer=None,
-        emulate: bool = False,
-        cached_casted_weight=None,
-    ):
+    def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
         """Converts a higher precision tensor to float8 in a differentiable way.

         Args:
@@ -172,7 +152,6 @@ def to_float8(
             float8_dtype,
             amax_buffer,
             emulate,
-            cached_casted_weight,
         )

     @classmethod
diff --git a/test/test_base.py b/test/test_base.py
index 724d1c2c..2a486942 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -10,9 +10,6 @@
 import warnings
 from enum import Enum

-import float8_experimental.config as config
-import float8_experimental.float8_linear as float8_linear
-
 import pytest
 import torch

@@ -234,36 +231,6 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
             y.dtype == torch.bfloat16
         ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"

-    @pytest.mark.parametrize("use_compile", [False, True])
-    def test_weight_caching(self, use_compile):
-        M, K, N = 16, 32, 64
-        dtype = torch.bfloat16
-        config.allocate_float8_weight_cache_buffers = True
-
-        x = torch.randn(M, K, device="cuda", dtype=dtype)
-        m_ref = nn.Linear(K, N, bias=True, device="cuda", dtype=dtype)
-        m = Float8Linear.from_float(copy.deepcopy(m_ref), emulate=False)
-
-        if use_compile:
-            m = torch.compile(m)
-
-        config.weight_cache_enabled = False
-
-        y1 = m(x)
-        y1.sum().backward()
-        grad1 = m.weight.grad.clone().detach()
-
-        config.weight_cache_enabled = True
-        sync_float8_amax_and_scale_history(m)
-
-        y2 = m(x)
-        y2.sum().backward()
-        grad2 = m.weight.grad.clone().detach()
-
-        torch.testing.assert_close(grad2, grad1 * 2)
-
-        config.allocate_float8_weight_cache_buffers = False
-

 class TestScaledMM:
     @unittest.skipIf(
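For reference, after this change the cast path is unconditional: `Float8Tensor.to_float8(tensor, scale, float8_dtype, amax_buffer, emulate)` recomputes the float8 bits on every call, with no `cached_casted_weight` fast path. Below is a minimal sketch (not part of the patch) of the simplified call, assuming an illustrative per-tensor scale rather than the library's amax-history (delayed scaling) recipe.

```python
import torch

from float8_experimental.float8_tensor import Float8Tensor

# bf16 tensor standing in for a linear layer's weight (shape is arbitrary)
w = torch.randn(64, 32, dtype=torch.bfloat16)

# illustrative scale: map the observed amax onto the e4m3 representable range
e4m3_max = torch.finfo(torch.float8_e4m3fn).max
scale = e4m3_max / torch.clamp(w.abs().max().float(), min=1e-12)

# optional buffer; to_float8 fills it with the observed amax as a side effect
amax_buffer = torch.zeros(1)

# emulate=True is illustrative; it only affects how downstream matmuls run
w_fp8 = Float8Tensor.to_float8(
    w, scale, torch.float8_e4m3fn, amax_buffer, emulate=True
)
w_hp = w_fp8.to_original_precision()  # round-trips back to higher precision
```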