diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py
index b355098..9ad76f7 100644
--- a/float8_experimental/float8_dynamic_utils.py
+++ b/float8_experimental/float8_dynamic_utils.py
@@ -3,24 +3,18 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-"""
-A wrapper around a `torch.nn.Linear` module which does fp8 compute.
-"""
 from typing import Any, Optional, Tuple
 
 import torch
-import torch.utils._pytree as pytree
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
-    merge_mm_configs,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
 
 from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
-from torch._prims_common import suggest_memory_format
 
 
 @torch._dynamo.allow_in_graph
@@ -63,127 +57,3 @@ def cast_to_float8_e5m2_dynamic_bw(
     gradY: torch.Tensor, mm_config: ScaledMMConfig
 ) -> torch.Tensor:
     return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
-
-
-# FSDP pads its local tensor on dim-0. The subclass should be preserved such
-# that the padded local tensor (and any transformations like copying to GPU)
-# is of the subclass as well.
-_ops_to_preserve_subclass = {
-    torch.ops.aten.empty_like.default,
-    torch.ops.aten.new_zeros.default,
-    torch.ops.aten.slice.Tensor,
-    torch.ops.aten.copy_.default,
-    torch.ops.aten.view.default,
-    torch.ops.aten.as_strided.default,
-    torch.ops.aten._to_copy.default,
-    torch.ops.aten._pin_memory.default,
-}
-
-
-class WeightWithDynamicFloat8CastTensor(torch.Tensor):
-    @staticmethod
-    def __new__(
-        cls,
-        tensor: torch.Tensor,
-        mm_config: ScaledMMConfig,
-        precomputed_scale: Optional[torch.Tensor] = None,
-    ):
-        return torch.Tensor._make_wrapper_subclass(
-            cls,
-            tensor.size(),
-            strides=tensor.stride(),
-            storage_offset=tensor.storage_offset(),
-            memory_format=suggest_memory_format(tensor),
-            dtype=tensor.dtype,
-            layout=tensor.layout,
-            device=tensor.device,
-            pin_memory=tensor.is_pinned(),
-            requires_grad=tensor.requires_grad,
-        )
-
-    def __init__(
-        self,
-        tensor: torch.Tensor,
-        mm_config: ScaledMMConfig,
-        precomputed_scale: Optional[torch.Tensor] = None,
-    ):
-        self._tensor = tensor
-        self._mm_config = mm_config
-        # for dynamic scaling
-        # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
-        # for all float8 parameters after optimizer step
-        self._precomputed_scale = precomputed_scale
-
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs=None):
-        if func == torch.ops.aten.detach.default:
-            return WeightWithDynamicFloat8CastTensor(
-                args[0]._tensor, args[0]._mm_config
-            )
-        mm_config: Optional[ScaledMMConfig] = None
-
-        def unwrap(t):
-            nonlocal mm_config
-            if mm_config is None:
-                mm_config = t._mm_config
-            else:
-                mm_config = merge_mm_configs(mm_config, t._mm_config)
-            return t._tensor
-
-        args, kwargs = pytree.tree_map_only(
-            WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
-        )
-        out = func(*args, **kwargs)
-        if func not in _ops_to_preserve_subclass:
-            return out
-        return pytree.tree_map_only(
-            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
-        )
-
-    def __tensor_flatten__(self):
-        if self._precomputed_scale:
-            return ["_tensor", "_precomputed_scale"], self._mm_config
-        else:
-            return ["_tensor"], self._mm_config
-
-    @staticmethod
-    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(
-            inner_tensors["_tensor"],
-            mm_config,
-            getattr(inner_tensors, "_precomputed_scale", None),
-        )
-
-    def __repr__(self):
-        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
-
-    def fsdp_pre_all_gather(self, mesh):
-        if self._precomputed_scale is not None:
-            float8_tensor = Float8Tensor.to_float8(
-                self._tensor,
-                self._precomputed_scale,
-                torch.float8_e4m3fn,
-                mm_config=self._mm_config,
-            )
-        else:
-            float8_tensor = cast_to_float8_e4m3_dynamic(
-                self._tensor, self._mm_config, reduce_amax=True
-            )
-        return (float8_tensor._data,), (float8_tensor._scale,)
-
-    def fsdp_post_all_gather(
-        self,
-        all_gather_outputs: Tuple[torch.Tensor, ...],
-        metadata: Any,
-        param_dtype: torch.dtype,
-        *,
-        out: Optional[torch.Tensor] = None,
-    ):
-        (data,) = all_gather_outputs
-        (scale,) = metadata
-        if out is not None:
-            assert isinstance(out, Float8Tensor), f"{type(out)}"
-            out._scale = scale
-            return
-        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index a7dd2d2..23d6f5c 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -19,7 +19,6 @@
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
-    WeightWithDynamicFloat8CastTensor,
 )
 
 from float8_experimental.float8_tensor import (
@@ -35,6 +34,8 @@
     tensor_to_amax,
 )
 
+from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
+
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
     x,
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index 0ade173..2f23a3b 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -1,11 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 import math
-from typing import List
+from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+import torch.utils._pytree as pytree
+from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
+
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    merge_mm_configs,
+    ScaledMMConfig,
+)
+
 from float8_experimental.float8_utils import EPS
+from torch._prims_common import suggest_memory_format
 
 
 @torch.no_grad()
@@ -19,6 +33,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         optim.step()
         precompute_float8_dynamic_scale_for_fsdp(model)
     """
+    from float8_experimental.float8_linear import Float8Linear, TensorScalingType
     from torch.distributed._tensor import DTensor
 
     if any(
@@ -50,3 +65,127 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         scales = torch.split(scale_tensor, 1)  # Replicate
         for scale, float8_linear in zip(scales, float8_linears):
             float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
+
+
+# FSDP pads its local tensor on dim-0. The subclass should be preserved such
+# that the padded local tensor (and any transformations like copying to GPU)
+# is of the subclass as well.
+_ops_to_preserve_subclass = {
+    torch.ops.aten.empty_like.default,
+    torch.ops.aten.new_zeros.default,
+    torch.ops.aten.slice.Tensor,
+    torch.ops.aten.copy_.default,
+    torch.ops.aten.view.default,
+    torch.ops.aten.as_strided.default,
+    torch.ops.aten._to_copy.default,
+    torch.ops.aten._pin_memory.default,
+}
+
+
+class WeightWithDynamicFloat8CastTensor(torch.Tensor):
+    @staticmethod
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            strides=tensor.stride(),
+            storage_offset=tensor.storage_offset(),
+            memory_format=suggest_memory_format(tensor),
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            device=tensor.device,
+            pin_memory=tensor.is_pinned(),
+            requires_grad=tensor.requires_grad,
+        )
+
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
+        self._tensor = tensor
+        self._mm_config = mm_config
+        # for dynamic scaling
+        # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
+        # for all float8 parameters after optimizer step
+        self._precomputed_scale = precomputed_scale
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func == torch.ops.aten.detach.default:
+            return WeightWithDynamicFloat8CastTensor(
+                args[0]._tensor, args[0]._mm_config
+            )
+        mm_config: Optional[ScaledMMConfig] = None
+
+        def unwrap(t):
+            nonlocal mm_config
+            if mm_config is None:
+                mm_config = t._mm_config
+            else:
+                mm_config = merge_mm_configs(mm_config, t._mm_config)
+            return t._tensor
+
+        args, kwargs = pytree.tree_map_only(
+            WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
+        )
+        out = func(*args, **kwargs)
+        if func not in _ops_to_preserve_subclass:
+            return out
+        return pytree.tree_map_only(
+            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+        )
+
+    def __tensor_flatten__(self):
+        if self._precomputed_scale:
+            return ["_tensor", "_precomputed_scale"], self._mm_config
+        else:
+            return ["_tensor"], self._mm_config
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        mm_config = flatten_spec
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"],
+            mm_config,
+            getattr(inner_tensors, "_precomputed_scale", None),
+        )
+
+    def __repr__(self):
+        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
+
+    def fsdp_pre_all_gather(self, mesh):
+        if self._precomputed_scale is not None:
+            float8_tensor = Float8Tensor.to_float8(
+                self._tensor,
+                self._precomputed_scale,
+                torch.float8_e4m3fn,
+                mm_config=self._mm_config,
+            )
+        else:
+            float8_tensor = cast_to_float8_e4m3_dynamic(
+                self._tensor, self._mm_config, reduce_amax=True
+            )
+        return (float8_tensor._data,), (float8_tensor._scale,)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        (data,) = all_gather_outputs
+        (scale,) = metadata
+        if out is not None:
+            assert isinstance(out, Float8Tensor), f"{type(out)}"
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py
index bdbc878..978ad18 100644
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2_eager.py
@@ -7,9 +7,9 @@
 import torch._dynamo.testing
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
 from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from test_fsdp2_common import (
     check_parity_bf16_mp,
     check_parity_no_mp,
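
Usage sketch (not part of the diff): the snippet below exercises the relocated WeightWithDynamicFloat8CastTensor on a single process. Ops in _ops_to_preserve_subclass keep the wrapper intact on FSDP's padded dim-0 local shard, and the fsdp_pre_all_gather/fsdp_post_all_gather hooks let the all-gather carry fp8 bytes plus a scale, which is why the class now lives in fsdp_utils.py next to precompute_float8_dynamic_scale_for_fsdp. Hypothetical pieces: ScaledMMConfig() built from its defaults, an identity precomputed scale, and an arbitrary weight shape; in real use these come from swap_linear_with_float8_linear and precompute_float8_dynamic_scale_for_fsdp.

# Illustrative sketch only (not part of the diff). Single process, so the
# "all-gather" below is just a pass-through of the fp8 data and scale.
# Assumes ScaledMMConfig() is constructible from its defaults; the weight shape
# and the identity precomputed scale are arbitrary.
import torch

from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor

weight = torch.randn(256, 512)
wrapped = WeightWithDynamicFloat8CastTensor(
    weight, ScaledMMConfig(), precomputed_scale=torch.tensor(1.0)
)

# Ops listed in _ops_to_preserve_subclass (e.g. aten.empty_like, applied to the
# padded dim-0 local shard) keep the subclass and its mm_config.
padded = torch.empty_like(wrapped)
assert isinstance(padded, WeightWithDynamicFloat8CastTensor)

# FSDP2 extension point: cast to fp8 once before the all-gather...
(fp8_data,), (scale,) = wrapped.fsdp_pre_all_gather(mesh=None)
assert fp8_data.dtype == torch.float8_e4m3fn

# ...then rebuild a Float8Tensor from the gathered fp8 bytes plus the scale
# metadata, so the all-gather moves fp8 data rather than bf16/fp32 weights.
fp8_weight, _ = wrapped.fsdp_post_all_gather((fp8_data,), (scale,), torch.bfloat16)
assert isinstance(fp8_weight, Float8Tensor)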