diff --git a/README.md b/README.md
index 1db1009..fa093c3 100644
--- a/README.md
+++ b/README.md
@@ -27,21 +27,23 @@ pip install -e ".[dev]"
 
 # User API
 
-We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).
 
-## float8 linear with dynamic scaling
+## float8 linear with dynamic scaling for `x`, `w` and `dL_dY`
+
+This is the most accurate recipe as every tensor is scaled dynamically.
 
 ```python
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear
 
 # create model
 m = Model(...)
 
-# convert all `torch.nn.Linear` modules to `Float8DynamicLinear`
-swap_linear_with_float8_linear(m, Float8DynamicLinear)
+# convert all `torch.nn.Linear` modules to `Float8Linear`
+swap_linear_with_float8_linear(m, Float8Linear)
 
 # optional: use FSDP
 model = FSDP(model, use_orig_params=True)
@@ -54,18 +56,27 @@ m = torch.compile(m)
 
 ## float8 linear with delayed scaling
 
+This is theoretically the most performant recipe as it minimizes memory reads.
+
 ```python
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 
 # create model
 m = Model(...)
 
-# convert all `torch.nn.Linear` modules to `Float8Linear`
-swap_linear_with_float8_linear(m, Float8Linear)
+# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
+# type
+swap_linear_with_float8_linear(
+    m,
+    Float8Linear,
+    scaling_type_x=TensorScalingType.DELAYED,
+    scaling_type_w=TensorScalingType.DELAYED,
+    scaling_type_dL_dY=TensorScalingType.DELAYED,
+)
 
 # optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
 # config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
@@ -93,9 +104,7 @@ for _ in range(N_ITER):
 # 🧭 Code Organization
 
 * `float8_experimental/float8_linear.py`
-  - `Float8Linear` (main user facing entry point for delayed scaling)
-* `float8_experimental/float8_dynamic_linear.py`
-  - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
+  - `Float8Linear` (main user facing entry point for Float8Linear)
 * `float8_experimental/float8_tensor.py`
   - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
   - `ScaledMMConfig` defines the semantics for matmul in the forward and backwards pass
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index cbf992e..b1a17e4 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -191,9 +191,9 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
-    scaling_type_x: TensorScalingType = TensorScalingType.DELAYED,
-    scaling_type_w: TensorScalingType = TensorScalingType.DELAYED,
-    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
+    scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
+    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
 ) -> Optional[nn.Module]:
     """
     Swaps `torch.nn.Linear` in `module` with `Float8Linear` or `Float8DynamicLinear`.
diff --git a/test/test_compile.py b/test/test_compile.py
index 5d90487..834d126 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -299,7 +299,13 @@ def test_sync_amax_func():
     module = torch.nn.Sequential(
         nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
     )
-    float8_mod = swap_linear_with_float8_linear(module, Float8Linear)
+    float8_mod = swap_linear_with_float8_linear(
+        module,
+        Float8Linear,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts)
     compiled_swap_func(float8_mod)
     assert cnts.frame_count == 1, "Compiled graph should have 1 frame!"
@@ -329,7 +335,13 @@ def test_sync_amax_func_cuda_graph_success():
     my_module = nn.Sequential(
         nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
     ).to("cuda")
-    swap_linear_with_float8_linear(my_module, Float8Linear)
+    swap_linear_with_float8_linear(
+        my_module,
+        Float8Linear,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     inpt = torch.randn(
         16, 16, device="cuda", dtype=torch.float32, requires_grad=True
     )
diff --git a/test/test_fsdp.py b/test/test_fsdp.py
index ff31ca3..031b40d 100644
--- a/test/test_fsdp.py
+++ b/test/test_fsdp.py
@@ -23,6 +23,8 @@ import torch.nn as nn
 
 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
+    linear_requires_sync,
+    LinearType,
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
@@ -130,7 +132,12 @@ def forward_backward(model, optim, is_fp8, i):
     optim.zero_grad()
     y_local = model(ref_input_local[i])
     y_local.backward(ref_grad_local[i])
-    if is_fp8:
+    if is_fp8 and linear_requires_sync(
+        LinearType.DELAYED,
+        TensorScalingType.DYNAMIC,
+        scaling_type_w,
+        TensorScalingType.DYNAMIC,
+    ):
         sync_float8_func(model)
     optim.step()
     return y_local
diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py
index 389b756..cc44934 100644
--- a/test/test_fsdp_compile.py
+++ b/test/test_fsdp_compile.py
@@ -18,7 +18,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from float8_experimental import config
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
@@ -49,7 +56,14 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
         nn.Linear(K, N, dtype=base_dtype),
         nn.ReLU(),
     )
-    swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
+    swap_linear_with_float8_linear(
+        m,
+        Float8Linear,
+        emulate=emulate,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     return m
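
For reference, a minimal end-to-end sketch of the two recipes this patch touches: dynamic scaling via the new defaults, and delayed scaling via explicit `scaling_type_*` arguments plus the per-iteration sync call. This is illustrative only, assuming the `float8_experimental` branch from this diff is installed and an FP8-capable GPU is available (otherwise `emulate=True` can be passed to `swap_linear_with_float8_linear`); the toy model and tensor shapes are borrowed from the tests above.

```python
import torch
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)


def make_model():
    # toy model; any stack of torch.nn.Linear layers works
    return nn.Sequential(
        nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
    ).to("cuda")


# Recipe 1: dynamic scaling for x, w and dL_dY. With the new defaults in
# swap_linear_with_float8_linear, no scaling_type_* arguments are needed and
# no amax/scale sync is required.
m_dynamic = make_model()
swap_linear_with_float8_linear(m_dynamic, Float8Linear)

# Recipe 2: delayed scaling for x, w and dL_dY. The scaling types are passed
# explicitly, and the amax/scale history must be synced once per iteration
# (after the backward pass, before the optimizer step).
m_delayed = make_model()
swap_linear_with_float8_linear(
    m_delayed,
    Float8Linear,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)

x = torch.randn(16, 16, device="cuda")
for m in (m_dynamic, m_delayed):
    y = m(x)
    y.sum().backward()

# only the delayed-scaling model needs the sync step
sync_float8_amax_and_scale_history(m_delayed)
```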