Add abstract base class for attention mechanisms with unified interface #8039


Merged

1 commit merged into main on Feb 1, 2025

Conversation

@iseeyuan (Contributor) commented Jan 29, 2025

Summary

Add an abstract base class for attention mechanisms with a unified interface.
It creates an interface for multiple attention definitions, such as NPU-friendly attentions or other attention types like Multi-Head Latent Attention (MLA) used in DeepSeek.
A simple registry is provided to easily add and register a new attention class.

Moved the current attention implementation to attention.py and renamed it to AttentionMHA.
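
For illustration, here is a minimal sketch of the shape such an interface and registry could take. This is not a copy of the PR's code: the registry names, the decorator, and the exact forward signature (beyond the keyword arguments quoted later in this thread) are assumptions.

```python
# Minimal sketch of the idea described above, not the PR's actual code.
# ATTENTION_REGISTRY, register_attention, and the rope-frequency args are assumptions.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, Type

import torch
import torch.nn as nn


class Attention(nn.Module, ABC):
    """Unified interface that every concrete attention implementation follows."""

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """Return the attention output and an optional cache-state update."""


# Simple string-keyed registry so new attention variants (e.g. an NPU-friendly
# attention or MLA) can be registered and selected by name.
ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {}


def register_attention(name: str):
    def decorator(cls: Type[Attention]) -> Type[Attention]:
        ATTENTION_REGISTRY[name.lower()] = cls
        return cls

    return decorator


@register_attention("mha")
class AttentionMHA(Attention):
    """Stand-in for the existing multi-head attention moved into attention.py."""

    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.wo = nn.Linear(dim, dim)  # placeholder for the real q/k/v/o projections

    def forward(self, x, freqs_cos, freqs_sin, mask=None, input_pos=None,
                in_cache_state=None, out_cache_state=None):
        return self.wo(x), None
```

Under this shape, a new attention type only needs to subclass Attention and register itself under a new name.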

Test plan

CI

pytorch-bot bot commented Jan 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8039

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 2 Unrelated Failures

As of commit 9ccf542 with merge base c0676fe:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jan 29, 2025
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
in_cache_state: Optional[Any] = None,
out_cache_state: Optional[Any] = None,
@iseeyuan (Contributor Author): Replace them with kwargs.

@iseeyuan (Contributor Author): Directly using kwargs may break type safety. Keep them as is and consider using TypedDict and Unpack for kwargs type checking later.
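
For reference, a hedged sketch of what the TypedDict + Unpack approach mentioned here could look like (PEP 692). `ForwardOptions` and `SomeAttention` are illustrative names, not the PR's definitions.

```python
# Sketch of kwargs type checking with TypedDict + Unpack (PEP 692).
# `ForwardOptions` and `SomeAttention` are invented names for illustration.
from typing import Any, Optional, TypedDict

import torch
from typing_extensions import Unpack  # typing.Unpack on Python >= 3.12


class ForwardOptions(TypedDict, total=False):
    mask: Optional[torch.Tensor]
    input_pos: Optional[torch.Tensor]
    in_cache_state: Optional[Any]
    out_cache_state: Optional[Any]


class SomeAttention(torch.nn.Module):
    def forward(self, x: torch.Tensor, **kwargs: Unpack[ForwardOptions]) -> torch.Tensor:
        # Type checkers validate keyword names and value types at call sites,
        # e.g. attn(x, mask=mask, input_pos=pos), while the runtime sees plain kwargs.
        mask = kwargs.get("mask")
        return x if mask is None else x + mask  # placeholder computation
```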

@CypherpunkSamurai

What is the progress on this PR?

I'm currently trying to convert a distilled DeepSeek R1 to PTE using the example scripts.

Let me know if I can help test this out.

@iseeyuan (Contributor Author)

@CypherpunkSamurai I'm trying to complete the refactor and land it this week. Let me create an issue for adding MLA to this interface.

@iseeyuan iseeyuan force-pushed the attention branch 3 times, most recently from ce1b50c to 00ec564 Compare January 30, 2025 16:16
@iseeyuan iseeyuan changed the title [WIP] Add abstract base class for attention mechanisms with unified interface Add abstract base class for attention mechanisms with unified interface Jan 30, 2025
@iseeyuan (Contributor Author)

@pytorchbot label "topic: not user facing"

@jackzhxng (Contributor) left a comment

Is this ready for review? Also, it would probably be good to tag this with an actual release-note label, since I assume this would be good to highlight in the 0.6 release notes.

from executorch.examples.models.llama.rope import Rope


class Attention(nn.Module, ABC):
@sxu (Contributor) commented

So far a specialized implementation is only used during lowering and on device, and it needs to be able to accept a checkpoint from whatever definition was used during training. What do you see as the usage pattern going forward? Is the AttentionMHA below the standard definition that specializations of this class need to support?

@iseeyuan (Contributor Author) commented Jan 31, 2025

@sxu

> Is the AttentionMHA below the standard definition that specializations of this class need to support?

Not necessarily. The attention type is added to the model args. Usually the model args and checkpoint are saved in the same place. We use the model args to build the model and load the checkpoint as a state_dict. If the checkpoint does not match the model architecture, there will be an error. We don't break the standard PyTorch process.
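
As a hedged illustration of this flow (reusing the registry from the sketch in the summary; `ModelArgs.attention_type`, `TinyTransformer`, and `build_model` are made-up names, not the PR's code):

```python
# Illustrative sketch only: pick the attention class from the model args and
# load the checkpoint as a plain state_dict, following standard PyTorch behavior.
from dataclasses import dataclass
from typing import Type

import torch
import torch.nn as nn


@dataclass
class ModelArgs:
    dim: int = 64
    n_heads: int = 4
    attention_type: str = "mha"  # e.g. "mha", an NPU-friendly variant, or "mla"


class TinyTransformer(nn.Module):
    """Toy stand-in for the real transformer definition."""

    def __init__(self, args: ModelArgs, attention_cls: Type[nn.Module]):
        super().__init__()
        self.attn = attention_cls(args.dim, args.n_heads)


def build_model(args: ModelArgs, checkpoint_path: str) -> nn.Module:
    # ATTENTION_REGISTRY is the registry from the earlier sketch.
    attn_cls = ATTENTION_REGISTRY[args.attention_type]
    model = TinyTransformer(args, attention_cls=attn_cls)
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    # A checkpoint saved for a different attention definition fails here with
    # missing/unexpected keys, which is the error mentioned above.
    model.load_state_dict(state_dict)
    return model
```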

@sxu (Contributor) commented

I see, we don't usually expect the training to be done on a specialized NPU implementation, but I guess we can tweak the state-dict loading on a case-by-case basis.
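
For example, one hypothetical way to tweak state-dict loading is a small key-remapping step before load_state_dict; every key name below is invented for illustration.

```python
# Hypothetical key remapping for loading an MHA-style checkpoint into a
# specialized attention implementation; all key names here are invented.
from typing import Dict

import torch

KEY_MAP: Dict[str, str] = {
    "attention.wq.weight": "attn.q_proj.weight",
    "attention.wk.weight": "attn.k_proj.weight",
    "attention.wv.weight": "attn.v_proj.weight",
    "attention.wo.weight": "attn.o_proj.weight",
}


def remap_checkpoint(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Rename known keys; leave everything else untouched.
    return {KEY_MAP.get(key, key): value for key, value in state_dict.items()}


# Usage: model.load_state_dict(remap_checkpoint(torch.load(path, map_location="cpu")))
```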

@iseeyuan (Contributor Author)

> Is this ready for review? Also, it would probably be good to tag this with an actual release-note label, since I assume this would be good to highlight in the 0.6 release notes.

@dvorjackz It's currently in the example model. I'm fine with promoting this to extension/llm, renaming it to llm_transformer or simply transformer, and marking it user-facing.

@iseeyuan iseeyuan force-pushed the attention branch 2 times, most recently from e29e337 to 433fabb Compare January 31, 2025 14:17
@facebook-github-bot (Contributor)

@iseeyuan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Jan 31, 2025

Add abstract base class for attention mechanisms with unified interface (#8039)

Summary:
Add an abstract base class for attention mechanisms with a unified interface.
It creates an interface for multiple attention definitions, such as NPU-friendly attentions or other attention types like Multi-Head Latent Attention (MLA) used in DeepSeek.
A simple registry is provided to easily add and register a new attention class.

Moved the current attention implementation to attention.py and renamed it to AttentionMHA.

Test Plan: CI

Differential Revision: D68956201

Pulled By: iseeyuan
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D68956201

@sxu (Contributor) left a comment

Looks reasonable to me. The cache, mask, and rope frequencies still need to be passed down from the top-level transformer, and the updates from each layer need to be returned, in a follow-up PR.


@iseeyuan (Contributor Author)

> Looks reasonable to me. The cache, mask, and rope frequencies still need to be passed down from the top-level transformer, and the updates from each layer need to be returned, in a follow-up PR.

Thanks @sxu! I'll land this one when all CI checks pass, and please feel free to add further PRs to improve it.
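
A hedged sketch of what that follow-up might look like, with the top-level transformer threading the shared inputs through each layer and collecting the per-layer cache updates; all names here are illustrative, not the PR's code.

```python
# Illustrative only: thread mask/rope/cache state through each layer and
# collect the cache update each layer returns.
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn


class LayeredTransformer(nn.Module):
    def __init__(self, layers: List[nn.Module]):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, List[Optional[Any]]]:
        cache_updates: List[Optional[Any]] = []
        for layer in self.layers:
            # Each layer gets the shared mask, rope frequencies, positions, and
            # incoming cache state, and returns its output plus a cache update.
            x, update = layer(
                x,
                freqs_cos,
                freqs_sin,
                mask=mask,
                input_pos=input_pos,
                in_cache_state=in_cache_state,
            )
            cache_updates.append(update)
        return x, cache_updates
```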

facebook-github-bot pushed a commit that referenced this pull request Jan 31, 2025

Add abstract base class for attention mechanisms with unified interface (#8039)

Test Plan: CI

Reviewed By: tarun292

Differential Revision: D68956201

Pulled By: iseeyuan


jackzhxng added the "release notes: examples" label (changes to any of our example LLM integrations, such as Llama3 and Llava) and removed the "topic: not user facing" label on Feb 1, 2025
facebook-github-bot merged commit a972e73 into main on Feb 1, 2025
41 of 47 checks passed
facebook-github-bot deleted the attention branch on February 1, 2025 01:55
Labels: CLA Signed, fb-exported, release notes: examples

6 participants