diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 013f45b1..e29de691 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -160,9 +160,9 @@ def sync_float8_amax_and_scale_history( # 1. in distributed contexts, syncs amax values across workers # if dist.is_initialized(): - child.fp8_amax_x = fp8_amax_x_tensor[idx] - child.fp8_amax_w = fp8_amax_w_tensor[idx] - child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx] + child.fp8_amax_x = fp8_amax_x_tensor[idx].clone() + child.fp8_amax_w = fp8_amax_w_tensor[idx].clone() + child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx].clone() # # 2. adds the `amax` values to history