This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[bc breaking] unify filtering functions #322

Closed
wants to merge 3 commits
13 changes: 12 additions & 1 deletion README.md
@@ -43,8 +43,19 @@ from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_f
# create model
m = Model(...)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(fqn: str, mod: torch.nn.Module):
# don't convert the output module
if fqn == "output":
return False
# don't convert linear modules with weight dimensions not divisible by 16
if isinstance(mod, torch.nn.Linear):
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
return False
return True

# convert all `torch.nn.Linear` modules to `Float8Linear`
swap_linear_with_float8_linear(m)
swap_linear_with_float8_linear(m, module_filter_fn=module_filter_fn)

# optional: use FSDP
model = FSDP(model, use_orig_params=True)
74 changes: 28 additions & 46 deletions float8_experimental/float8_linear_utils.py
@@ -59,26 +59,11 @@ def _update_history_stack(
amax_history_stack.copy_(new_amax_history_stack)


def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear], bool]:
"""
Returns a callable that filters out small (dimensions less than the given `size_limit`)
and unaligned (dimensions not divisible by 16) layers.
It can be passed as the `linear_layer_filter` argument to `swap_linear_with_float8_linear`.
"""
return (
lambda linear_layer: linear_layer.in_features >= size_limit
and linear_layer.out_features >= size_limit
and linear_layer.in_features % 16 == 0
and linear_layer.out_features % 16 == 0
)


def swap_linear_layers(
module: nn.Module,
from_float_func: Callable[[nn.Linear], nn.Linear],
*,
skip_fqn_list: Optional[List[str]] = None,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
) -> Optional[nn.Module]:
"""
Generic function to swap linear layers in a module with a new type of linear layer.
@@ -90,18 +75,15 @@ def swap_linear_layers(
Args:
module: Module to modify.
from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
skip_fqn_list: If specified, a list of module FQNs to skip.
linear_layer_filter: If specified, only the linear layers
that pass the filter function will be swapped.
from_float_kwargs: Additional keyword arguments for from_float_func.
module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
pass the filter function will be swapped. The inputs to the
filter function are the FQN and module instance.

Returns:
nn.Module: The modified module with swapped linear layers.
"""
module_names_to_skip = set(skip_fqn_list or [])

if isinstance(module, nn.Linear) and (
linear_layer_filter is None or linear_layer_filter(module)
module_filter_fn is None or module_filter_fn("", module)
):
if len(list(module.children())) > 0:
raise AssertionError(
@@ -112,43 +94,44 @@
)

root_module = module
visited_modules = {root_module}

for module_name, module in root_module.named_modules():
if module_name in module_names_to_skip:
visited_modules.add(module)

def post_order_traversal(
module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
module: nn.Module,
cur_fqn: Optional[str] = None,
parent_module: Optional[nn.Module] = None,
):
nonlocal visited_modules
if cur_fqn is None:
cur_fqn = ""

for child_module_name, child_module in module.named_children():
if child_module not in visited_modules:
visited_modules.add(child_module)
post_order_traversal(child_module, child_module_name, module)
if cur_fqn == "":
new_fqn = child_module_name
else:
new_fqn = f"{cur_fqn}.{child_module_name}"

post_order_traversal(child_module, new_fqn, module)

if isinstance(module, nn.Linear) and (
linear_layer_filter is None or linear_layer_filter(module)
# linear_layer_filter is None or linear_layer_filter(module)
module_filter_fn is None
or module_filter_fn(cur_fqn, module)
):
assert (
parent_module is not None
), f"Linear root module should return early: {module}"
new_linear_module = from_float_func(module)
setattr(parent_module, module_name, new_linear_module)
cur_module_name = cur_fqn.split(".")[-1]
setattr(parent_module, cur_module_name, new_linear_module)

post_order_traversal(root_module, "", None)
# Without this explicit `del`, this set only gets deleted upon an explicit
# garbage collection (not from when its refcount hits zero)
del visited_modules
post_order_traversal(root_module)
return root_module
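For orientation: the FQN strings that the new post-order traversal hands to `module_filter_fn` match what `torch.nn.Module.named_modules()` reports. A minimal sketch of that naming (the model layout is illustrative, mirroring the MLP used in the tests below):

import torch.nn as nn

# Illustrative model, not part of this diff: Sequential children are named by
# index ("0", "1", ...) and nested attributes are joined with dots.
class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.lin1 = nn.Linear(dim, 4 * dim)
        self.lin2 = nn.Linear(4 * dim, dim)

model = nn.Sequential(MLP(8), nn.Linear(32, 32))
for fqn, mod in model.named_modules():
    if isinstance(mod, nn.Linear):
        print(fqn)  # prints "0.lin1", "0.lin2", "1"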


def swap_linear_with_float8_linear(
module: nn.Module,
*,
skip_fqn_list: Optional[List[str]] = None,
emulate: bool = False,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
@@ -158,10 +141,10 @@ def swap_linear_with_float8_linear(

Args:
module: Module to modify.
skip_fqn_list: If specified, a list of module FQNs to skip.
emulate: If True, emulation is used instead of hardware accelerated gemm
linear_layer_filter: If specified, only the linear layers
that pass the filter function will be swapped.
module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
pass the filter function will be swapped. The inputs to the
filter function are the FQN and module instance.
scaling_type_x (TensorScalingType): scaling type for `x`
scaling_type_w (TensorScalingType): scaling type for `w`
scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
@@ -179,8 +162,7 @@
return swap_linear_layers(
module,
from_float,
skip_fqn_list=skip_fqn_list,
linear_layer_filter=linear_layer_filter,
module_filter_fn=module_filter_fn,
)
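Note for callers migrating off the removed `skip_fqn_list` argument: the same effect is expressed through `module_filter_fn`. One caveat: the old argument skipped an entire subtree when given a parent module's FQN (see the updated test below, where skipping "2" is replaced by skipping "2.lin1"), so subtree skips now need an explicit prefix check or an expanded FQN list. A minimal sketch, with placeholder FQNs and `m` standing in for an existing model as in the README example:

import torch.nn as nn
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Sketch only: the FQNs below are placeholders, not from this diff.
skip_exact = {"output", "0.lin2"}   # skip individual linear layers by FQN
skip_prefixes = ("layers.2.",)      # skip everything under a submodule

def module_filter_fn(fqn: str, mod: nn.Module) -> bool:
    return fqn not in skip_exact and not fqn.startswith(skip_prefixes)

m = swap_linear_with_float8_linear(m, module_filter_fn=module_filter_fn)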


10 changes: 6 additions & 4 deletions float8_experimental/inference.py
@@ -10,7 +10,7 @@
from dataclasses import dataclass

from enum import auto, Enum
from typing import List, Optional
from typing import Callable, List, Optional

import float8_experimental.config as config

@@ -209,7 +209,7 @@ def quantize_to_float8(
module: nn.Module,
quant_config: QuantConfig,
*,
skip_fqn_list: Optional[List[str]] = None,
module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
use_fast_accum: bool = True,
) -> Optional[nn.Module]:
"""
@@ -222,7 +222,9 @@
Args:
module (nn.Module): The module to modify.
quant_config (QuantConfig): Quantization configuration for Float8 conversion.
skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
pass the filter function will be swapped. The inputs to the
filter function are the FQN and module instance.
use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.

Returns:
@@ -234,5 +236,5 @@
return swap_linear_layers(
module,
lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
skip_fqn_list=skip_fqn_list,
module_filter_fn=module_filter_fn,
)
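The inference entry point takes the same filter. A hypothetical usage sketch; the toy model, the `QuantConfig(ActivationCasting.DYNAMIC)` construction, and the FQN rule are assumptions about the existing inference API, not part of this diff:

import torch
import torch.nn as nn
from float8_experimental.inference import ActivationCasting, QuantConfig, quantize_to_float8

# Assumed setup: a bf16 CUDA model with dynamic activation casting.
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)).to(torch.bfloat16).cuda().eval()
quant_config = QuantConfig(ActivationCasting.DYNAMIC)

quantize_to_float8(
    model,
    quant_config,
    # Placeholder rule: convert every linear except the one whose FQN is "1".
    module_filter_fn=lambda fqn, mod: fqn != "1",
)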
74 changes: 42 additions & 32 deletions test/test_base.py
@@ -18,7 +18,6 @@

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
filter_out_small_unaligned_layers,
linear_requires_sync,
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
@@ -631,24 +630,34 @@ def __init__(self, dim: int):
self.lin1 = nn.Linear(dim, 4 * dim)
self.lin2 = nn.Linear(4 * dim, 4 * dim)

for emulate in [True, False]:
model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40))
# filter out the linear layers whose shape is smaller than 32 or non-divisible by 16.
model = swap_linear_with_float8_linear(
model,
emulate=emulate,
linear_layer_filter=filter_out_small_unaligned_layers(32),
model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40))
# filter out the linear layers whose dimensions are smaller than 32 or not divisible by 16.

size_limit = 32

def module_filter_fn(fqn, mod):
return (
mod.in_features >= size_limit
and mod.out_features >= size_limit
and mod.in_features % 16 == 0
and mod.out_features % 16 == 0
)
# in_features=8, out_features=32, 8 is less than 32.
self.assertNotIsInstance(model[0].lin1, Float8Linear)
# in_features=32, out_features=32,
self.assertIsInstance(model[0].lin2, Float8Linear)
# in_features=32, out_features=32,
self.assertIsInstance(model[1], Float8Linear)
# in_features=40, out_features=160, 40 is not divisible by 16.
self.assertNotIsInstance(model[2].lin1, Float8Linear)
# in_features=160, out_features=160,
self.assertIsInstance(model[2].lin2, Float8Linear)

model = swap_linear_with_float8_linear(
model,
emulate=True,
module_filter_fn=module_filter_fn,
)
# in_features=8, out_features=32, 8 is less than 32.
self.assertNotIsInstance(model[0].lin1, Float8Linear)
# in_features=32, out_features=32,
self.assertIsInstance(model[0].lin2, Float8Linear)
# in_features=32, out_features=32,
self.assertIsInstance(model[1], Float8Linear)
# in_features=40, out_features=160, 40 is not divisible by 16.
self.assertNotIsInstance(model[2].lin1, Float8Linear)
# in_features=160, out_features=160,
self.assertIsInstance(model[2].lin2, Float8Linear)

def test_swap_submodule_linears_with_skip(self):
class MLP(nn.Module):
@@ -657,20 +666,21 @@ def __init__(self, dim: int):
self.lin1 = nn.Linear(dim, 4 * dim)
self.lin2 = nn.Linear(4 * dim, dim)

for emulate in [True, False]:
model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
skip_fqn_list = ["2", "0.lin2"]
model = swap_linear_with_float8_linear(
model, emulate=emulate, skip_fqn_list=skip_fqn_list
)
self.assertIsInstance(model[0].lin1, Float8Linear)
self.assertNotIsInstance(model[0].lin2, Float8Linear)
self.assertIsInstance(model[0].lin2, nn.Linear)
self.assertIsInstance(model[1], Float8Linear)
self.assertNotIsInstance(model[2].lin2, Float8Linear)
self.assertNotIsInstance(model[2].lin2, Float8Linear)
self.assertIsInstance(model[2].lin1, nn.Linear)
self.assertIsInstance(model[2].lin2, nn.Linear)
model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
module_filter_fn = lambda fqn, mod: fqn not in [
"0.lin2",
"2.lin1",
]
model = swap_linear_with_float8_linear(
model,
emulate=True,
module_filter_fn=module_filter_fn,
)
self.assertTrue(type(model[0].lin1) is Float8Linear)
self.assertTrue(type(model[0].lin2) is nn.Linear)
self.assertTrue(type(model[1]) is Float8Linear)
self.assertTrue(type(model[2].lin1) is nn.Linear)
self.assertTrue(type(model[2].lin2) is Float8Linear)

def test_fp8_tensor_statistics(self):
hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)