From a4886787ddc7485565ddbbc29846233988e7c074 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 9 Jul 2024 13:36:26 -0700
Subject: [PATCH] move WeightWithDynamicFloat8CastTensor to fsdp_utils.py

Summary:

Refactor in preparation for adding delayed scaling support to float8
all-gather; no logic change.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 float8_experimental/float8_dynamic_utils.py | 106 ------------------
 float8_experimental/float8_linear.py        |   3 +-
 float8_experimental/fsdp_utils.py           | 112 ++++++++++++++++++++
 test/test_fsdp2/test_fsdp2_eager.py         |   2 +-
 4 files changed, 115 insertions(+), 108 deletions(-)
 create mode 100644 float8_experimental/fsdp_utils.py

diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py
index f48424c..215a394 100644
--- a/float8_experimental/float8_dynamic_utils.py
+++ b/float8_experimental/float8_dynamic_utils.py
@@ -3,27 +3,16 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-"""
-A wrapper around a `torch.nn.Linear` module which does fp8 compute.
-"""
-
-from typing import Any, Optional, Tuple
-
-import float8_experimental.config as config
 
 import torch
-import torch.nn as nn
-import torch.utils._pytree as pytree
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
-    merge_mm_configs,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
 from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
-from torch._prims_common import suggest_memory_format
 
 
 @torch._dynamo.allow_in_graph
@@ -66,98 +55,3 @@ def cast_to_float8_e5m2_dynamic_bw(
     gradY: torch.Tensor, mm_config: ScaledMMConfig
 ) -> torch.Tensor:
     return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
-
-
-# FSDP pads its local tensor on dim-0. The subclass should be preserved such
-# that the padded local tensor (and any transformations like copying to GPU)
-# is of the subclass as well.
-_ops_to_preserve_subclass = {
-    torch.ops.aten.empty_like.default,
-    torch.ops.aten.new_zeros.default,
-    torch.ops.aten.slice.Tensor,
-    torch.ops.aten.copy_.default,
-    torch.ops.aten.view.default,
-    torch.ops.aten.as_strided.default,
-    torch.ops.aten._to_copy.default,
-    torch.ops.aten._pin_memory.default,
-}
-
-
-class WeightWithDynamicFloat8CastTensor(torch.Tensor):
-    @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
-        return torch.Tensor._make_wrapper_subclass(
-            cls,
-            tensor.size(),
-            strides=tensor.stride(),
-            storage_offset=tensor.storage_offset(),
-            memory_format=suggest_memory_format(tensor),
-            dtype=tensor.dtype,
-            layout=tensor.layout,
-            device=tensor.device,
-            pin_memory=tensor.is_pinned(),
-            requires_grad=tensor.requires_grad,
-        )
-
-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
-        self._tensor = tensor
-        self._mm_config = mm_config
-
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs=None):
-        if func == torch.ops.aten.detach.default:
-            return WeightWithDynamicFloat8CastTensor(
-                args[0]._tensor, args[0]._mm_config
-            )
-        mm_config: Optional[ScaledMMConfig] = None
-
-        def unwrap(t):
-            nonlocal mm_config
-            if mm_config is None:
-                mm_config = t._mm_config
-            else:
-                mm_config = merge_mm_configs(mm_config, t._mm_config)
-            return t._tensor
-
-        args, kwargs = pytree.tree_map_only(
-            WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
-        )
-        out = func(*args, **kwargs)
-        if func not in _ops_to_preserve_subclass:
-            return out
-        return pytree.tree_map_only(
-            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
-        )
-
-    def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
-
-    @staticmethod
-    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
-
-    def __repr__(self):
-        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
-
-    def fsdp_pre_all_gather(self, mesh):
-        float8_tensor = cast_to_float8_e4m3_dynamic(
-            self._tensor, self._mm_config, reduce_amax=True
-        )
-        return (float8_tensor._data,), (float8_tensor._scale,)
-
-    def fsdp_post_all_gather(
-        self,
-        all_gather_outputs: Tuple[torch.Tensor, ...],
-        metadata: Any,
-        param_dtype: torch.dtype,
-        *,
-        out: Optional[torch.Tensor] = None,
-    ):
-        (data,) = all_gather_outputs
-        (scale,) = metadata
-        if out is not None:
-            assert isinstance(out, Float8Tensor), f"{type(out)}"
-            out._scale = scale
-            return
-        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index a7dd2d2..23d6f5c 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -19,7 +19,6 @@
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
-    WeightWithDynamicFloat8CastTensor,
 )
 
 from float8_experimental.float8_tensor import (
@@ -35,6 +34,8 @@
     tensor_to_amax,
 )
 
+from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
+
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
     x,
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
new file mode 100644
index 0000000..41871d8
--- /dev/null
+++ b/float8_experimental/fsdp_utils.py
@@ -0,0 +1,112 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.utils._pytree as pytree
+from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
+
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    merge_mm_configs,
+    ScaledMMConfig,
+)
+from torch._prims_common import suggest_memory_format
+
+# FSDP pads its local tensor on dim-0. The subclass should be preserved such
+# that the padded local tensor (and any transformations like copying to GPU)
+# is of the subclass as well.
+_ops_to_preserve_subclass = {
+    torch.ops.aten.empty_like.default,
+    torch.ops.aten.new_zeros.default,
+    torch.ops.aten.slice.Tensor,
+    torch.ops.aten.copy_.default,
+    torch.ops.aten.view.default,
+    torch.ops.aten.as_strided.default,
+    torch.ops.aten._to_copy.default,
+    torch.ops.aten._pin_memory.default,
+}
+
+
+class WeightWithDynamicFloat8CastTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            strides=tensor.stride(),
+            storage_offset=tensor.storage_offset(),
+            memory_format=suggest_memory_format(tensor),
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            device=tensor.device,
+            pin_memory=tensor.is_pinned(),
+            requires_grad=tensor.requires_grad,
+        )
+
+    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+        self._tensor = tensor
+        self._mm_config = mm_config
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func == torch.ops.aten.detach.default:
+            return WeightWithDynamicFloat8CastTensor(
+                args[0]._tensor, args[0]._mm_config
+            )
+        mm_config: Optional[ScaledMMConfig] = None
+
+        def unwrap(t):
+            nonlocal mm_config
+            if mm_config is None:
+                mm_config = t._mm_config
+            else:
+                mm_config = merge_mm_configs(mm_config, t._mm_config)
+            return t._tensor
+
+        args, kwargs = pytree.tree_map_only(
+            WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
+        )
+        out = func(*args, **kwargs)
+        if func not in _ops_to_preserve_subclass:
+            return out
+        return pytree.tree_map_only(
+            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+        )
+
+    def __tensor_flatten__(self):
+        return ["_tensor"], self._mm_config
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        mm_config = flatten_spec
+        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+
+    def __repr__(self):
+        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
+
+    def fsdp_pre_all_gather(self, mesh):
+        float8_tensor = cast_to_float8_e4m3_dynamic(
+            self._tensor, self._mm_config, reduce_amax=True
+        )
+        return (float8_tensor._data,), (float8_tensor._scale,)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        (data,) = all_gather_outputs
+        (scale,) = metadata
+        if out is not None:
+            assert isinstance(out, Float8Tensor), f"{type(out)}"
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py
index 57123cd..de627a6 100644
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2_eager.py
@@ -8,9 +8,9 @@
 import torch._dynamo.testing
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from test_fsdp2_common import (
     check_parity_bf16_mp,
     check_parity_no_mp,
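Usage sketch (illustrative only, not part of the patch): WeightWithDynamicFloat8CastTensor, now imported from fsdp_utils as shown in the float8_linear.py and test hunks above, is the subclass that wraps a linear weight so that FSDP2 all-gathers it in float8 via the fsdp_pre_all_gather / fsdp_post_all_gather hooks. The snippet below wraps a plain Linear weight by hand and assumes ScaledMMConfig can be constructed with its default fields; the sizes are arbitrary.

    import torch.nn as nn
    from float8_experimental.float8_tensor import ScaledMMConfig
    from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor

    lin = nn.Linear(16, 16, bias=False)
    # Wrap the weight in the subclass; under FSDP2 (fully_shard), the
    # fsdp_pre_all_gather hook casts it to float8 before the all-gather and
    # fsdp_post_all_gather reassembles a Float8Tensor on the other side.
    lin.weight = nn.Parameter(
        WeightWithDynamicFloat8CastTensor(lin.weight, ScaledMMConfig())
    )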