Add abstract base class for attention mechanisms with unified interface #8039
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8039
Note: Links to docs will display an error until the docs builds have been completed.
⏳ 1 Pending, 2 Unrelated Failures as of commit 9ccf542 with merge base c0676fe (BROKEN TRUNK: the following jobs failed but were also present on the merge base). 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
in_cache_state: Optional[Any] = None,
out_cache_state: Optional[Any] = None,
```
replace them with kwargs
Directly using kwargs may break type safety. Keep them as is, and consider using TypedDict and Unpack for kwarg type checking later.
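For reference, a minimal sketch of what the TypedDict + Unpack approach mentioned above could look like; `AttentionForwardKwargs` and `ExampleAttention` are hypothetical names, not part of this PR:

```python
# Hypothetical sketch of the TypedDict + Unpack idea; names are illustrative only.
from typing import Any, Optional, TypedDict

import torch
import torch.nn as nn
from typing_extensions import Unpack


class AttentionForwardKwargs(TypedDict, total=False):
    mask: Optional[torch.Tensor]
    input_pos: Optional[torch.Tensor]
    in_cache_state: Optional[Any]
    out_cache_state: Optional[Any]


class ExampleAttention(nn.Module):
    def forward(
        self, x: torch.Tensor, **kwargs: Unpack[AttentionForwardKwargs]
    ) -> torch.Tensor:
        # A type checker can verify the keyword names and types (PEP 692),
        # while the runtime signature stays a plain **kwargs.
        mask = kwargs.get("mask")  # e.g., a causal attention mask
        # ... the actual attention computation is elided in this sketch ...
        return x
```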
What is the progress on this PR? I'm currently trying to convert a distilled DeepSeek R1 to PTE using the example scripts. Let me know if I can help test this out.
@CypherpunkSamurai I'm trying to complete the refactor and land it this week. Let me create an issue for adding MLA to this interface.
Force-pushed from ce1b50c to 00ec564.
@pytorchbot label "topic: not user facing"
Is this ready for review? Also, I would probably tag it with an actual release-notes label, since I assume this would be good to highlight in the 0.6 release notes.
```python
from executorch.examples.models.llama.rope import Rope


class Attention(nn.Module, ABC):
```
So far a specialized implementation is only used during lowering and on device, and it needs to be able to accept a checkpoint from whatever definition was used during training. What do you see as the usage pattern going forward? Is the AttentionMHA below the standard definition that a specialization of this class needs to support?
> Is the AttentionMHA below the standard definition that a specialization of this class needs to support?
Not necessarily. The attention type is added to the model args. Usually the model args and checkpoint are saved in one place. We use the model args to build the model and load the checkpoint as a state_dict. If the checkpoint does not match the model architecture, there will be an error. We don't break the standard PyTorch process.
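For illustration, a simplified sketch of that flow; `ModelArgs` and `build_model` here are stand-ins, not the exact ExecuTorch API:

```python
# Simplified sketch: the attention type lives in the model args, the args drive
# model construction, and the checkpoint is loaded as an ordinary state_dict.
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class ModelArgs:
    dim: int = 64
    attention_type: str = "mha"  # hypothetical field selecting the Attention class


def build_model(args: ModelArgs) -> nn.Module:
    # In the real code, a registry would map args.attention_type to an
    # Attention subclass; a single Linear stands in for the model here.
    return nn.Linear(args.dim, args.dim)


args = ModelArgs(dim=64, attention_type="mha")
model = build_model(args)

# load_state_dict with strict=True (the default) raises a RuntimeError if the
# checkpoint keys do not match the architecture built from the model args.
checkpoint = {k: v.clone() for k, v in model.state_dict().items()}
model.load_state_dict(checkpoint)
```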
I see. We don't usually expect the training to be done with a specialized NPU implementation, but I guess we can tweak the state_dict loading on a case-by-case basis.
@dvorjackz It's currently in the example model. I'm fine with promoting this to extension/llm, renaming it to llm_transformer (or simply transformer), and marking it user-facing.
Force-pushed from e29e337 to 433fabb.
@iseeyuan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from 433fabb to 9c19d3c.
This pull request was exported from Phabricator. Differential Revision: D68956201
Looks reasonable to me. The cache, mask, and rope frequencies still need to be passed down from the top-level transformer, and the updates from each layer need to be returned, in a follow-up PR.
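A rough, hypothetical sketch of what that follow-up threading could look like at the top level (names and signatures are assumptions, not this PR's API):

```python
# Hypothetical sketch: the top-level transformer forwards mask, rope frequencies,
# and per-layer cache state into each layer and collects the updated cache state.
from typing import Any, List, Optional, Tuple

import torch


def transformer_forward(
    layers,                      # iterable of attention/transformer layers
    x: torch.Tensor,
    mask: Optional[torch.Tensor],
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
    in_cache_state: List[Any],
) -> Tuple[torch.Tensor, List[Any]]:
    out_cache_state: List[Any] = []
    for layer, layer_cache in zip(layers, in_cache_state):
        x, updated_cache = layer(
            x, freqs_cos, freqs_sin, mask=mask, in_cache_state=layer_cache
        )
        out_cache_state.append(updated_cache)
    return x, out_cache_state
```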
> I see. We don't usually expect the training to be done with a specialized NPU implementation, but I guess we can tweak the state_dict loading on a case-by-case basis.
Thanks @sxu! I'll land this one once all CI passes, and please feel free to add further PRs to improve it.
Add abstract base class for attention mechanisms with unified interface (#8039)

Summary: Add an abstract base class for attention mechanisms with a unified interface. It creates a common interface for multiple Attention definitions, such as NPU-friendly attention implementations or other attention types like Multi-Head Latent Attention (MLA) used in DeepSeek. A simple registry is provided to easily add and register a new attention class. The current attention implementation is moved to attention.py and renamed AttentionMHA.

Test Plan: CI
Reviewed By: tarun292
Differential Revision: D68956201
Pulled By: iseeyuan
Force-pushed from 9c19d3c to 013ca61.
This pull request was exported from Phabricator. Differential Revision: D68956201
Force-pushed from 013ca61 to c3167c3.
This pull request was exported from Phabricator. Differential Revision: D68956201
Force-pushed from c3167c3 to 9ccf542.
This pull request was exported from Phabricator. Differential Revision: D68956201
Summary
Add an abstract base class for attention mechanisms with a unified interface.
It creates a common interface for multiple Attention definitions, such as NPU-friendly attention implementations or other attention types like Multi-Head Latent Attention (MLA) used in DeepSeek.
A simple registry is provided to easily add and register a new attention class.
The current attention implementation is moved to attention.py and renamed AttentionMHA.
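For reference, a minimal sketch of the pattern described above; the exact names and forward signature in attention.py may differ:

```python
# Sketch of an abstract Attention base class plus a simple string-keyed registry.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, Type

import torch
import torch.nn as nn

ATTENTION_REGISTRY: Dict[str, Type["Attention"]] = {}


def register_attention(name: str):
    """Register an Attention subclass under a string key, e.g. "mha"."""

    def decorator(cls: Type["Attention"]) -> Type["Attention"]:
        ATTENTION_REGISTRY[name.lower()] = cls
        return cls

    return decorator


class Attention(nn.Module, ABC):
    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """Compute attention; concrete subclasses (e.g. AttentionMHA) implement this."""
        ...


@register_attention("mha")
class AttentionMHA(Attention):
    # The real multi-head attention implementation lives in attention.py;
    # this stub only illustrates registration.
    def forward(self, x, freqs_cos, freqs_sin, mask=None, input_pos=None,
                in_cache_state=None, out_cache_state=None):
        return x, out_cache_state
```

A model builder can then look up `ATTENTION_REGISTRY[model_args.attention_type]` (a hypothetical field name) to construct the configured attention module.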
Test plan
CI