From b41cd9fbd845ebd0312d486406b02f962be96153 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 28 Dec 2023 14:11:39 -0800
Subject: [PATCH] [wip] make Float8Linear amax init more FSDP+compile friendly

Summary:

We need to use functional collectives so that torch.compile can trace through
the distributed code
(https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py).

Numerics are currently off; debugging.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 float8_experimental/float8_utils.py |  8 +++++++-
 test/test_fsdp.py                   | 22 +++++++++++++++-------
 test/test_fsdp.sh                   | 11 ++++++-----
 3 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index 4e65ef99..9717a430 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -7,6 +7,7 @@
 import torch
 import torch.distributed as dist
+import torch.distributed._functional_collectives as _functional_collectives
 
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
 
@@ -60,7 +61,12 @@ def tensor_to_amax(x, distributed_reduction=False):
     # If the user did not ask for it, assume that it will
     # happen elsewhere.
     if distributed_reduction and dist.is_initialized():
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
+        # TODO(future): support process groups
+        ranks = list(range(dist.get_world_size()))
+        # print('ranks', ranks)
+        # print('old amax', amax)
+        amax = _functional_collectives.all_reduce(amax, "max", group=ranks)
+        # print('new amax', amax)
     return amax
 
 
diff --git a/test/test_fsdp.py b/test/test_fsdp.py
index 33e1653d..d4eb530e 100644
--- a/test/test_fsdp.py
+++ b/test/test_fsdp.py
@@ -104,10 +104,13 @@ def fsdp_main(rank, world_size, args):
     ref_input_local = ref_input_global[bsz_local_start:bsz_local_end].to(rank)
 
     sync_float8_func = sync_float8_amax_and_scale_history
-    if compile:
+    if compile and False:
         sync_float8_func = torch.compile(
             sync_float8_amax_and_scale_history, fullgraph=fullgraph
         )
+    model = torch.compile(model)
+
+    print(model)
 
     def forward_backward(model):
         optimizer.zero_grad()
@@ -118,14 +121,19 @@ def forward_backward(model):
         return y_local
 
     for iter in range(N_ITER):
-        # We first run one iteration without compile, as a workaround to compile float8 layer.
-        # In the first iter, float8 layers go to the branches of "self.is_amax_initialized == False"
-        # After that, float8 layers go the the branches of "self.is_amax_initialized == True"
-        # TODO: Need to fix compile to run wihtout this workaround.
-        if iter == 1 and compile:
-            model = torch.compile(model, fullgraph=fullgraph)
         y_local = forward_backward(model)
 
+        if compile and False:
+            base = model._orig_mod._fsdp_wrapped_module
+        else:
+            base = model._fsdp_wrapped_module
+        print('0_x', base[0].fp8_amax_history_x)
+        print('2_x', base[2].fp8_amax_history_x)
+        print('0_w', base[0].fp8_amax_history_w)
+        print('2_w', base[2].fp8_amax_history_w)
+        print('0_g', base[0].fp8_amax_history_dL_dY)
+        print('2_g', base[2].fp8_amax_history_dL_dY)
+
     # get global y
     y_global = [
         torch.zeros(*y_local.shape, dtype=base_dtype).to(rank)
diff --git a/test/test_fsdp.sh b/test/test_fsdp.sh
index 624b3969..4e4cdeca 100755
--- a/test/test_fsdp.sh
+++ b/test/test_fsdp.sh
@@ -7,12 +7,12 @@ launch() {
     echo "launching IS_FP8 $IS_FP8, compile_fsdp $COMPILE, fullgraph $FULLGRAPH"
 
     # generate the test data
-    python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
-    echo "Success: ✅"
+    # python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
+    # echo "Success: ✅"
 
     # generate single GPU model output and updated state dict
-    python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
-    echo "Success: ✅"
+    # python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
+    # echo "Success: ✅"
 
     # generate FSDP model output and updated state dict
     # the NCCL_DEBUG setting is to avoid log spew
@@ -30,7 +30,8 @@ launch() {
 }
 
 # IS_FP8, COMPILE, FULLGRAPH
-for i in False,False,False True,False,False True,True,False
+# for i in False,False,False True,False,False True,True,False
+for i in True,True,False
 do
     IFS=","; set -- $i; IS_FP8=$1; COMPILE=$2; FULLGRAPH=$3
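
Note: the core pattern in the float8_utils.py hunk, an in-place c10d all_reduce
replaced by an out-of-place functional collective, can be reproduced outside
the test harness. The sketch below is illustrative, not part of this patch:
the file name, the simplified tensor_to_amax body, the gloo backend, and the
2-process torchrun launch are all assumptions.

# amax_allreduce_sketch.py (hypothetical file name, not part of this repo)
# Launch with: torchrun --nproc_per_node 2 amax_allreduce_sketch.py
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as _functional_collectives


def tensor_to_amax(x: torch.Tensor, distributed_reduction: bool = False) -> torch.Tensor:
    # simplified version of float8_experimental/float8_utils.py:tensor_to_amax
    amax = torch.max(torch.abs(x))
    if distributed_reduction and dist.is_initialized():
        # the in-place form, dist.all_reduce(amax, op=dist.ReduceOp.MAX),
        # mutates its input; the functional form returns a new tensor
        # instead, which is what lets torch.compile trace through it
        ranks = list(range(dist.get_world_size()))
        amax = _functional_collectives.all_reduce(amax, "max", group=ranks)
    return amax


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    torch.manual_seed(dist.get_rank())  # different data on each rank
    x = torch.randn(16)
    local_amax = torch.max(torch.abs(x))
    global_amax = torch.compile(tensor_to_amax)(x, distributed_reduction=True)
    print(f"rank {dist.get_rank()}: local amax {local_amax.item():.4f}, "
          f"global amax {global_amax.item():.4f}")
    dist.destroy_process_group()

The design point is that the functional variant returns a fresh tensor rather
than mutating amax in place, so the reduction can appear as a pure op in the
traced graph; group=ranks mirrors the TODO in the hunk above, since process
groups are not plumbed through yet.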