Convert DenseGeneral to NNX #1743
Open
Description
Finishing out PR 1604 for @cgarciae
Note: I updated the logits in `golden_data_grpo_default.jsonl` to the NNX values. The NNX logits don't match the Linen ones because the RNG keys are set up differently.

Previous PR description:
This commit converts `DenseGeneral` to NNX and adds `dense_general`, which interfaces with it through a Linen wrapper. `dense_general` takes all the same arguments as the Linen version, plus two new ones:

- `input_shape`: the expected shape of the input.
- `in_features`: an int or tuple representing the input features.

Only one of them can be set at a time.
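To illustrate the "only one of them" rule, here is a minimal sketch of how the two arguments could be reconciled into a single feature tuple. It is not the code from this PR: `resolve_in_features` is a hypothetical helper, the `axis` parameter is assumed to behave like the Linen `DenseGeneral` `axis` argument, and the sketch assumes that leaving both arguments unset is also an error.

```python
from typing import Optional, Sequence, Tuple, Union


def resolve_in_features(
    input_shape: Optional[Sequence[int]] = None,
    in_features: Optional[Union[int, Sequence[int]]] = None,
    axis: Union[int, Sequence[int]] = -1,
) -> Tuple[int, ...]:
  """Hypothetical helper: derive the input feature dims from one of two arguments."""
  # Assumption: exactly one of the two must be given (both or neither is rejected).
  if (input_shape is None) == (in_features is None):
    raise ValueError("Set exactly one of input_shape or in_features.")

  if in_features is not None:
    # An int means a single contracted dimension; a sequence is used as given.
    if isinstance(in_features, int):
      return (in_features,)
    return tuple(in_features)

  # Otherwise read the feature dims off the input shape at the contracted axes.
  axes = (axis,) if isinstance(axis, int) else tuple(axis)
  return tuple(input_shape[a] for a in axes)
```

Under these assumptions, `resolve_in_features(input_shape=(8, 16, 512))` and `resolve_in_features(in_features=512)` both return `(512,)`, while passing both arguments raises `ValueError`.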
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):