diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py
index d8447e3..c5d7c44 100644
--- a/benchmarks/profile_linear_float8.py
+++ b/benchmarks/profile_linear_float8.py
@@ -19,7 +19,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     LinearType,
@@ -207,6 +207,9 @@ def main(
     profile_path_prefix: Path,
     compile: bool = True,
     linear_type: str = "dynamic",
+    scaling_type_x: str = "delayed",
+    scaling_type_w: str = "delayed",
+    scaling_type_dL_dY: str = "delayed",
     model_type: str = "linear",
     dtype_filter: str = "both",
 ):
@@ -250,9 +253,17 @@ def main(
     linear_cls = (
         Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
     )
+    extra_kwargs = {}
+    scaling_type_x = TensorScalingType(scaling_type_x)
+    scaling_type_w = TensorScalingType(scaling_type_w)
+    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+    if linear_type is LinearType.DELAYED:
+        extra_kwargs["scaling_type_x"] = scaling_type_x
+        extra_kwargs["scaling_type_w"] = scaling_type_w
+        extra_kwargs["scaling_type_dL_dY"] = scaling_type_dL_dY
 
     m_float8 = copy.deepcopy(m_ref)
-    swap_linear_with_float8_linear(m_float8, linear_cls)
+    swap_linear_with_float8_linear(m_float8, linear_cls, **extra_kwargs)
 
     def ref_forw_backward(x):
         out = m_ref(x)
@@ -270,7 +281,9 @@ def float8_forw_backward_wrapper(x):
         # inspection of the fw+bw torch.compile without the scale
         # syncing code
         # TODO(future): make this better
-        if linear_requires_sync(linear_type):
+        if linear_requires_sync(
+            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+        ):
             with record_function("scale_amax_and_scales"):
                 sync_amax_history(m_float8)
         out = float8_forw(x)
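
For reference, a minimal sketch of how the new per-tensor scaling knobs plumb through to the module swap. This assumes TensorScalingType is a standard string-valued enum (as its use above suggests) and that swap_linear_with_float8_linear lives in float8_linear_utils alongside the other imports; the toy model is illustrative, not the benchmark's actual model.

import copy

import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Parse the CLI-style strings into enum members, mirroring the diff above.
# Assuming a standard Enum, an unknown string would raise ValueError here,
# before any model is touched.
extra_kwargs = {
    "scaling_type_x": TensorScalingType("delayed"),
    "scaling_type_w": TensorScalingType("delayed"),
    "scaling_type_dL_dY": TensorScalingType("delayed"),
}

m_ref = nn.Sequential(nn.Linear(64, 64))  # hypothetical stand-in model
m_float8 = copy.deepcopy(m_ref)
# As in the diff, the scaling-type kwargs are only passed when swapping in
# the delayed-scaling Float8Linear; Float8DynamicLinear does not take them.
swap_linear_with_float8_linear(m_float8, Float8Linear, **extra_kwargs)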