From a1f44e01db6b4e6449aec3503734cf2fd9314c9b Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 1 Jul 2024 13:57:56 -0700 Subject: [PATCH 1/3] [4/x] add tests for DTensor TP/SP + Float8Linear Summary: Makes the DTensor TP/SP tests also test `Float8Linear` with all scaling types configured to be dynamic. We can add support for delayed scaling with float8 all-gather for `x` and `dL_dY` in a future PR, as needed. Test Plan: ``` ./test/test_dtensor.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/float8_tensor_parallel.py | 35 +++++++++--- test/test_dtensor.py | 53 ++++++++++++++++--- 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index b84c2e9..f5cc61d 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -4,6 +4,7 @@ cast_to_float8_e4m3_dynamic, cast_to_float8_e5m2_dynamic_bw, ) +from float8_experimental.float8_linear import TensorScalingType from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.parallel import ( @@ -22,6 +23,15 @@ # NOTE: This only works and tested with the DynamicLinear +def _float8_linear_supports_float8_allgather(m): + # TODO(future PR): add support for delayed scaling for activations + # and gradients + return ( + m.scaling_type_x == TensorScalingType.DYNAMIC + and m.scaling_type_dL_dY == TensorScalingType.DYNAMIC + ) + + class Float8ColwiseParallel(ColwiseParallel): @staticmethod def _prepare_input_fn( @@ -61,11 +71,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: from float8_experimental.float8_dynamic_linear import Float8DynamicLinear + from float8_experimental.float8_linear import Float8Linear - if not isinstance(module, Float8DynamicLinear): + if not isinstance(module, (Float8DynamicLinear, Float8Linear)): raise ValueError( - f"Expecting module to be Float8DynamicLinear but found {type(module)}" + f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}" ) + elif isinstance( + module, Float8Linear + ) and not _float8_linear_supports_float8_allgather(module): + raise AssertionError("unsupported") return super()._apply(module, device_mesh) @@ -107,11 +122,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: from float8_experimental.float8_dynamic_linear import Float8DynamicLinear + from float8_experimental.float8_linear import Float8Linear - if not isinstance(module, Float8DynamicLinear): + if not isinstance(module, (Float8DynamicLinear, Float8Linear)): raise ValueError( - f"Expecting module to be Float8DynamicLinear but found {type(module)}" + f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}" ) + elif isinstance( + module, Float8Linear + ) and not _float8_linear_supports_float8_allgather(module): + raise AssertionError("unsupported") return super()._apply(module, device_mesh) @@ -184,22 +204,23 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: from float8_experimental.float8_dynamic_linear import Float8DynamicLinear + from float8_experimental.float8_linear import Float8Linear fwd_linear_config = None if 
self.fwd_config_submodule_fqn is not None:
             fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
-            assert isinstance(fwd_linear, Float8DynamicLinear)
+            assert isinstance(fwd_linear, (Float8DynamicLinear, Float8Linear))
             fwd_linear_config = fwd_linear.forward_config
         else:
             # search for ScaledMM configs for all the submodules and make sure they are the same
             for mod in module.modules():
-                if isinstance(mod, Float8DynamicLinear):
+                if isinstance(mod, (Float8DynamicLinear, Float8Linear)):
                     if fwd_linear_config is None:
                         fwd_linear_config = mod.forward_config
                     else:
                         assert (
                             fwd_linear_config == mod.forward_config
-                        ), "All the Float8DynamicLinear modules should have same forward config!"
+                        ), "All the Float8DynamicLinear and Float8Linear modules should have same forward config!"

         self.fwd_linear_config = fwd_linear_config
         super()._apply(module, device_mesh)

diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index 354f831..24a5e58 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -18,6 +18,7 @@
     Float8DynamicLinear,
     NoopFwToFloat8E5M2Bw,
 )
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_tensor_parallel import (
@@ -169,23 +170,37 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()


-def test_fp8_mlp_tensor_parallelism_base(
-    mesh: DeviceMesh, size=16, compile: bool = False
+def _test_fp8_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh, size=16, compile: bool = False, use_float8_linear: bool = False
 ):
     device = mesh.device_type
+    # TODO(future): delete Float8DynamicLinear from this test once all the
+    # code is unified
+    float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
+    extra_kwargs = {}
+    if use_float8_linear:
+        # For now, just use Float8Linear with dynamic scaling, which is the
+        # same behavior as Float8DynamicLinear.
+        # TODO(future): add support for float8 all-gather with delayed scaling
+        # for activations and gradients.
+ extra_kwargs = { + "scaling_type_x": TensorScalingType.DYNAMIC, + "scaling_type_w": TensorScalingType.DYNAMIC, + "scaling_type_dL_dY": TensorScalingType.DYNAMIC, + } toy_model = ToyModel().to(device) toy_model_fp8 = swap_linear_with_float8_linear( - toy_model, Float8DynamicLinear, emulate=True + toy_model, float8_cls, emulate=True, **extra_kwargs ) tp_model = copy.deepcopy(toy_model) tp_model = swap_linear_with_float8_linear( - tp_model, Float8DynamicLinear, emulate=True + tp_model, float8_cls, emulate=True, **extra_kwargs ) sp_model = copy.deepcopy(toy_model) sp_model = swap_linear_with_float8_linear( - sp_model, Float8DynamicLinear, emulate=True + sp_model, float8_cls, emulate=True, **extra_kwargs ) # vanilla TP @@ -218,7 +233,7 @@ def test_fp8_mlp_tensor_parallelism_base( # PrepareFloat8ModuleInput with specific submodule fqn sp_model2 = copy.deepcopy(toy_model) sp_model2 = swap_linear_with_float8_linear( - sp_model2, Float8DynamicLinear, emulate=True + sp_model2, Float8DynamicLinear, emulate=True, **extra_kwargs ) sp_model2 = parallelize_module( @@ -271,8 +286,28 @@ def test_fp8_mlp_tensor_parallelism_base( ) +def test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): + _test_fp8_mlp_tensor_parallelism_base( + mesh, size, compile=False, use_float8_linear=False + ) + + +def test_fp8_mlp_tensor_parallelism_eager_float8_linear(mesh: DeviceMesh, size=16): + _test_fp8_mlp_tensor_parallelism_base( + mesh, size, compile=False, use_float8_linear=True + ) + + def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): - test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True) + _test_fp8_mlp_tensor_parallelism_base( + mesh, size, compile=True, use_float8_linear=False + ) + + +def test_fp8_mlp_tensor_parallelism_compile_float8_linear(mesh: DeviceMesh, size=16): + _test_fp8_mlp_tensor_parallelism_base( + mesh, size, compile=True, use_float8_linear=True + ) if __name__ == "__main__": @@ -285,8 +320,10 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): test_fp8_redistribute, test_dtensor_cast_to_fp8, test_dtensor_fp8_autograd, - test_fp8_mlp_tensor_parallelism_base, + test_fp8_mlp_tensor_parallelism_eager, + test_fp8_mlp_tensor_parallelism_eager_float8_linear, test_fp8_mlp_tensor_parallelism_compile, + test_fp8_mlp_tensor_parallelism_compile_float8_linear, ] for test in tqdm(tests, desc="Running tests"): From a447f7c45de7c0d4450c80ff29f03f63de80f11d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 1 Jul 2024 14:36:28 -0700 Subject: [PATCH 2/3] Update on "[4/x] add tests for DTensor TP/SP + Float8Linear" Summary: Makes the DTensor TP/SP tests also test `Float8Linear` with all scaling types configured to be dynamic. We can add support for delayed scaling with float8 all-gather for `x` and `dL_dY` in a future PR, as needed. Test Plan: ``` ./test/test_dtensor.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/float8_tensor_parallel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index f5cc61d..b778c53 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -20,11 +20,12 @@ # here is that in input/output handling we do casting after # creating the DTensor. 
-# NOTE: This only works and tested with the DynamicLinear +# NOTE: This only works and tested with the dynamic scaling +# (Float8DynamicLinear and Float8Linear with dynamic scaling for all tensors) def _float8_linear_supports_float8_allgather(m): - # TODO(future PR): add support for delayed scaling for activations + # TODO(future): add support for delayed scaling for activations # and gradients return ( m.scaling_type_x == TensorScalingType.DYNAMIC From 024225f1206caa294d9ce2cba9364f3ffcdd45d3 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 2 Jul 2024 13:32:30 -0700 Subject: [PATCH 3/3] Update on "[4/x] add tests for DTensor TP/SP + Float8Linear" Summary: Makes the DTensor TP/SP tests also test `Float8Linear` with all scaling types configured to be dynamic. We can add support for delayed scaling with float8 all-gather for `x` and `dL_dY` in a future PR, as needed. Test Plan: ``` ./test/test_dtensor.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/float8_tensor_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index b778c53..fac0201 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -20,7 +20,7 @@ # here is that in input/output handling we do casting after # creating the DTensor. -# NOTE: This only works and tested with the dynamic scaling +# NOTE: This only works and tested with the dynamic scaling # (Float8DynamicLinear and Float8Linear with dynamic scaling for all tensors)
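Usage sketch (not part of the patches above): the snippet below mirrors what the new tests exercise, namely swapping `nn.Linear` for `Float8Linear` with all-dynamic scaling and then applying the float8-aware TP styles from `float8_tensor_parallel.py`. The `ToyMLP` module, its `w1`/`w2` FQNs, and the `build_tp_float8_mlp` helper are made up for illustration; the `swap_linear_with_float8_linear` keyword arguments follow the call in the new `_test_fp8_mlp_tensor_parallelism_base`, `Float8RowwiseParallel` is assumed to be the row-wise counterpart exported by the same file, and the example assumes a distributed launch (e.g. `torchrun`) with one GPU per rank.

```python
# Hypothetical helper: ToyMLP and build_tp_float8_mlp are illustrative only.
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)


class ToyMLP(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, dim, bias=False)
        self.w2 = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


def build_tp_float8_mlp(dim: int = 16) -> nn.Module:
    # One GPU per rank; WORLD_SIZE is set by the launcher (e.g. torchrun).
    world_size = int(os.environ["WORLD_SIZE"])
    mesh = init_device_mesh("cuda", (world_size,))

    model = ToyMLP(dim).to("cuda")
    # Swap nn.Linear -> Float8Linear with dynamic scaling for x, w and dL_dY,
    # a configuration that passes the _float8_linear_supports_float8_allgather
    # check added in this patch series (x and dL_dY must be dynamically scaled).
    model = swap_linear_with_float8_linear(
        model,
        Float8Linear,
        emulate=True,  # emulation lets this run without float8-capable hardware
        scaling_type_x=TensorScalingType.DYNAMIC,
        scaling_type_w=TensorScalingType.DYNAMIC,
        scaling_type_dL_dY=TensorScalingType.DYNAMIC,
    )
    # Column-parallel first linear, row-parallel second linear; the float8-aware
    # styles handle the float8 casts around the DTensor ops so that the TP/SP
    # communication can run on float8 data.
    model = parallelize_module(
        model,
        mesh,
        {"w1": Float8ColwiseParallel(), "w2": Float8RowwiseParallel()},
    )
    return model
```

A forward/backward pass through the returned module is expected to follow the same code path the new `tp_model` assertions in `./test/test_dtensor.sh` exercise.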