
Conversation

noobsiecoder
Contributor

Description

Fixes #2585
Added PyTorch nn.RMSNorm support to the CoreML conversion by decomposing it into primitive MIL operations.

Implementation

Decompose the RMSNorm formula into composable mb (MIL builder) functions, as sketched in the example after this list:

  1. square
  2. reduce_mean
  3. add
  4. sqrt
  5. real_div
  6. mul
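
A minimal sketch of that decomposition using the coremltools MIL builder; the input shape, epsilon value, reduction axis, and weight constant below are illustrative assumptions, not the exact code merged in this PR:

```python
import numpy as np
from coremltools.converters.mil import Builder as mb

# RMSNorm over the last axis: y = x / sqrt(mean(x^2) + eps) * weight
@mb.program(input_specs=[mb.TensorSpec(shape=(2, 4, 8))])
def rms_norm_prog(x):
    eps = 1e-6                                # assumed epsilon
    weight = np.ones((8,), dtype=np.float32)  # placeholder learnable scale
    squared = mb.square(x=x)                                        # 1. square
    mean_sq = mb.reduce_mean(x=squared, axes=[-1], keep_dims=True)  # 2. reduce_mean
    shifted = mb.add(x=mean_sq, y=eps)                              # 3. add
    rms = mb.sqrt(x=shifted)                                        # 4. sqrt
    normed = mb.real_div(x=x, y=rms)                                # 5. real_div
    return mb.mul(x=normed, y=weight)                               # 6. mul
```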


@junpeiz
Collaborator

junpeiz commented Sep 9, 2025

Great! Will merge after getting a green CI: https://gitlab.com/coremltools1/coremltools/-/pipelines/2030304960

@junpeiz junpeiz self-requested a review September 10, 2025 02:28
@noobsiecoder
Contributor Author

Hey @junpeiz, can I merge my changes to the main branch?

@junpeiz junpeiz merged commit 02450cf into apple:main Sep 10, 2025
@junpeiz
Collaborator

junpeiz commented Sep 10, 2025

> Hey @junpeiz, can I merge my changes to the main branch?

No worries! I already merged it. Thank you for your contributions!

noobsiecoder added a commit to noobsiecoder/coremltools that referenced this pull request Sep 10, 2025
…2585) (apple#2592)

* Add RMSNorm operator support for PyTorch to CoreML conversion (apple#2585)

* formatted code

* handles FP16 overflow for RMSNorm operation
@FL33TW00D
Contributor

@smpanaro @0seba

Relevant to your discussion about OpenELM: huggingface/swift-transformers#95

@noobsiecoder noobsiecoder deleted the add-torch-nn-rmsnorm-fused-opr branch September 11, 2025 12:57
junpeiz pushed a commit that referenced this pull request Sep 11, 2025
* Fix Issue #2583: Dynamic padding in torch.nn.functional.pad

Modified _array_construct to handle dynamic padding values: creates proper Var objects using mb.concat instead of Python lists, and fixes an AttributeError when converting models that use x.size(-1) as a padding value (a hedged sketch follows this commit list).

* limit torch to older than 2.8 for now (#2591)

Co-authored-by: yifan_shen3 <[email protected]>

* Add RMSNorm operator support for PyTorch to CoreML conversion (#2585) (#2592)

* Add RMSNorm operator support for PyTorch to CoreML conversion (#2585)

* formatted code

* handles FP16 overflow for RMSNorm operation

* handle dynamic padding w/o breaking legacy code

---------

Co-authored-by: Yifan Shen <[email protected]>
Co-authored-by: yifan_shen3 <[email protected]>
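
As a hedged illustration of the dynamic-padding change above (the helper name and element handling are assumptions, not the actual diff): when a padding element is a dynamic Var such as the one produced by x.size(-1), the elements can be packed into a single Var with mb.concat rather than returned as a Python list.

```python
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.mil import Var

def _array_construct_dynamic(elems):
    # Hypothetical helper: pack a mix of Python ints and dynamic MIL Vars
    # (e.g. from x.size(-1)) into one rank-1 Var via mb.concat, instead of
    # returning a plain Python list that downstream ops cannot consume.
    parts = []
    for e in elems:
        if isinstance(e, Var):
            # Promote scalar Vars to rank-1 so concat can join them.
            parts.append(e if e.rank == 1 else mb.expand_dims(x=e, axes=[0]))
        else:
            parts.append(mb.const(val=[int(e)]))
    return mb.concat(values=parts, axis=0)
```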
@0seba

0seba commented Sep 18, 2025

Thanks for the tag @FL33TW00D. I'd switch to using the ANEMLL RMSNorm implementation; it gives me the impression that it relies on ANE functionality that is not publicly accessible.
Implementation in MIL ops:
https://github.com/0seba/coreml-models/blob/7660302e2ea5f6c493ded09b3730d06710a3922d/src/layers/normalization.py#L10
