1686 Logic matching refactor #1687
base: main
@@ -14,6 +14,7 @@
 from loguru import logger
 from pydantic import ConfigDict, PrivateAttr, model_validator
 from torch.nn import Module
+from operator import attrgetter
 from tqdm import tqdm

 from llmcompressor.core import Event, EventType, State
@@ -29,7 +30,7 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers
+from compressed_tensors import match_named_modules

 __all__ = ["AWQModifier"]

@@ -304,8 +305,8 @@ def _set_resolved_mappings(self, model: Module) -> None:
         """
         resolved_mappings: list[ResolvedMapping] = []
         for mapping_idx, mapping in enumerate(self.mappings):
-            smooth_layers = get_layers(
-                mapping.smooth_layer, model, exclude_internal_modules=True
+            smooth_layers = match_named_modules(
+                model, [mapping.smooth_layer]
             )
             smooth_names = [
                 smooth_name
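For readers following the refactor: the old `get_layers` helper returned a name-to-module dict, which is why the surrounding code can index `smooth_layers[smooth_name]`. Below is a minimal sketch of how matches from `compressed_tensors.match_named_modules` could be materialized into the same shape, using a hypothetical toy model and target; the exact return type of `match_named_modules` may differ across versions, so treat this as an assumption rather than the final PR behavior.

```python
import torch
from compressed_tensors import match_named_modules

# Hypothetical toy model standing in for the calibrated model.
model = torch.nn.ModuleDict({
    "input_layernorm": torch.nn.LayerNorm(8),
    "q_proj": torch.nn.Linear(8, 8),
})

# Targets follow the compressed-tensors convention used by AWQ mappings
# (e.g. "re:"-prefixed regexes). Wrapping the (name, module) matches in
# dict() preserves the old get_layers-style name -> module lookup.
smooth_layers = dict(match_named_modules(model, ["re:.*layernorm$"]))
for smooth_name, smooth_layer in smooth_layers.items():
    print(smooth_name, type(smooth_layer).__name__)
```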
@@ -323,12 +324,12 @@ def _set_resolved_mappings(self, model: Module) -> None:
             smooth_layer = smooth_layers[smooth_name]

             smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
-            smooth_parent = get_layer_by_name(smooth_parent_name, model)
+            smooth_parent = attrgetter(smooth_parent_name)(model) if smooth_parent_name else model

             balance_layers, balance_names = [], []
             for balance_regex in mapping.balance_layers:
                 # find the submodules that match the activation layer
-                for balance_suffix, balance_layer in get_layers(
+                for balance_suffix, balance_layer in match_named_modules(
                     balance_regex,
                     smooth_parent,
                     exclude_internal_modules=True,
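One thing worth flagging in this hunk (and possibly what the hidden suggested change on lines +332 to 335 addressed): the call still passes the old `get_layers`-style arguments (`balance_regex, smooth_parent, exclude_internal_modules=True`), while the earlier hunk uses the `match_named_modules(module, [target])` order, and `exclude_internal_modules` is a `get_layers`-era keyword. A hedged, in-context sketch of the aligned call, using the names from the hunk above and offered only as a guess at the intent:

```python
# Hedged sketch only: mirror the (module, targets) argument order used for
# the smooth_layers call above. exclude_internal_modules was a get_layers
# keyword; whether an equivalent is needed here is left open.
for balance_suffix, balance_layer in match_named_modules(
    smooth_parent, [balance_regex]
):
    ...
```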
@@ -765,7 +766,7 @@ def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
     while True:
         if parent_name == "":
             return "", module
-        parent = get_layer_by_name(parent_name, module)
+        parent = attrgetter(parent_name)(module)
         if not isinstance(parent, torch.nn.ModuleList):
             return parent_name, parent
         parent_name = ".".join(parent_name.split(".")[:-1])
@@ -11,7 +11,7 @@
 )
 from llmcompressor.utils.fsdp.context import summon_full_params_context
 from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped, set_wrapped_model
-from llmcompressor.utils.pytorch.module import get_layers, set_layer
+from compressed_tensors import match_named_modules

 __all__ = ["OutputDistillationModifier"]

@@ -61,8 +61,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             else:
                 model_target, teacher_target = target, target

-            model_layers = get_layers(model_target, state.model)
-            teacher_layers = get_layers(teacher_target, state.teacher_model)
+            model_layers = match_named_modules(model_target, state.model)
+            teacher_layers = match_named_modules(teacher_target, state.teacher_model)

             if len(model_layers) < 1:
                 raise ValueError(f"no model layers found for target {target}")
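A small point on the length check this hunk preserves: `get_layers` returned a dict, so `len(model_layers)` was always valid; if `match_named_modules` yields its matches lazily, the result would need to be materialized first. A hedged, in-context sketch using the names from the hunk above, not necessarily what the final PR does; it also follows the `(module, targets)` argument order used elsewhere in this PR rather than the `(target, model)` order shown here.

```python
# Hedged sketch: materialize matches so len() and later iteration both work.
model_layers = dict(match_named_modules(state.model, [model_target]))
teacher_layers = dict(match_named_modules(state.teacher_model, [teacher_target]))

if len(model_layers) < 1:
    raise ValueError(f"no model layers found for target {target}")
```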
@@ -85,8 +85,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:

         with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
-                set_layer(key, student_wrapper, state.model)
-                set_layer(key, teacher_wrapper, state.teacher_model)
+                Module.set_submodule(key, student_wrapper, state.model)
+                Module.set_submodule(key, teacher_wrapper, state.teacher_model)

         self.wrapped_kd_model_ = self._create_model_wrapper(
             student_model=maybe_get_wrapped(state.model),
@@ -109,8 +109,8 @@ def on_finalize(self, state: State, **kwargs) -> bool:

         with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
-                set_layer(key, student_wrapper.layer, state.model)
-                set_layer(key, teacher_wrapper.layer, state.teacher_model)
+                Module.set_submodule(key, student_wrapper.layer, state.model)
Review comment on the line above: are we sure we want to call a class method on `Module` here?
+                Module.set_submodule(key, teacher_wrapper.layer, state.teacher_model)
                 del student_wrapper
                 del teacher_wrapper