diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py
index 978ad18..b5ae234 100644
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2_eager.py
@@ -72,12 +72,6 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)
 
-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
-        return swap_linear_with_float8_linear(module, **kwargs)
-
 
 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
     @property
@@ -106,11 +100,11 @@ def _test_transformer_parity_dynamic(
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
-        module = self.init_transformer(weight_tying=weight_tying)
+        module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        swap_linear_with_float8_linear(ref_module)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            swap_linear_with_float8_linear(module)
         for submodule in module.modules():
             if isinstance(submodule, TransformerBlock):
                 fully_shard(submodule)
@@ -153,7 +147,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            swap_linear_with_float8_linear(model, emulate=True)
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -331,7 +325,8 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module_fp32 = swap_linear_with_float8_linear(module_fp32)
+            module = module_fp32
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -359,7 +354,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module = self.init_multi_module()
         ref_module = copy.deepcopy(module)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module)
+            module = swap_linear_with_float8_linear(module)
         for submodule in module:
             fully_shard(submodule)
         fully_shard(module)
@@ -383,10 +378,11 @@ def test_fp32_fp8_single_module_parity(self):
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = copy.deepcopy(module_fp32)
+            ref_module = swap_linear_with_float8_linear(ref_module)
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = swap_linear_with_float8_linear(module_fp32)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -407,11 +403,11 @@ def test_fp32_fp8_multi_module_parity(self):
         multiple modules/FSDP communication groups.
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
-            module = self.init_multi_module()
+            module = self.init_multi_module().cuda()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            swap_linear_with_float8_linear(ref_module)
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = swap_linear_with_float8_linear(module)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)
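
For context, a minimal sketch (not part of the patch) of the call-site change it makes. The import path, and the assumption that swap_linear_with_float8_linear now defaults to dynamic scaling for x, w, and dL_dY (which is what would make the removed swap_linear_with_dynamic wrapper redundant), are not confirmed by the diff itself; only the identifiers and kwargs referenced in the comments come from the patch.

# Sketch under the assumptions stated above; not an authoritative API example.
import torch.nn as nn

# Assumed import location for the utility used throughout the diff.
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Old pattern (removed helper): pin scaling_type_x / scaling_type_w /
# scaling_type_dL_dY to TensorScalingType.DYNAMIC, then delegate to
# swap_linear_with_float8_linear.
#
# New pattern: call the utility directly with no scaling kwargs. The diff
# uses both a bare call and an assignment, suggesting the swap happens in
# place and the (possibly replaced) root module is also returned.
model = swap_linear_with_float8_linear(model)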