Skip to content

Convert DenseGeneral to NNX #1604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Convert DenseGeneral to NNX #1604

wants to merge 1 commit into from

Conversation

copybara-service[bot]
Copy link

@copybara-service copybara-service bot commented Apr 17, 2025

Convert DenseGeneral to NNX

Description

This commit converts DenseGeneral to NNX and creates a dense_general to interface with it through a Linen wrapper. dense_general contains all the same arguments as the Linen version but adds two additional ones:

  • input_shape: the expected shape of the input.
  • in_features: an int or tuple representing the input features.

Only one of them can be set at a time.

Tests

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@copybara-service copybara-service bot force-pushed the test_748311465 branch 3 times, most recently from a24f6e9 to 3fcbe6b Compare April 23, 2025 20:57
@copybara-service copybara-service bot changed the title convert DenseGeneral to NNX Convert DenseGeneral to NNX Apr 23, 2025
@copybara-service copybara-service bot force-pushed the test_748311465 branch 7 times, most recently from 97a92b7 to da43223 Compare April 25, 2025 16:33
@copybara-service copybara-service bot force-pushed the test_748311465 branch 3 times, most recently from 65a7f45 to f83299d Compare April 25, 2025 23:04
# Description
This commit converts `DenseGeneral` to NNX and creates a `dense_general` to interface with it through a Linen wrapper. `dense_general` contains all the same arguments as the Linen version but adds two additional ones:

* `input_shape`: the expected shape of the input.
* `in_features`: an int or tuple representing the input features.

Only one of them can be set at a time.

# Tests

# Checklist

Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed.

PiperOrigin-RevId: 748311465
Comment on lines +242 to +257
module = nnx.bridge.to_linen(
DenseGeneral,
in_features=in_features,
out_features=features,
axis=axis,
weight_dtype=weight_dtype,
dtype=dtype,
kernel_init=kernel_init,
kernel_axes=kernel_axes,
quant=quant,
use_bias=use_bias,
matmul_precision=matmul_precision,
name=name,
metadata_fn=variable_to_logically_partitioned,
)
return module
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm my understanding - this will not always be a bridge to linen, right? This is just a starting point for the migration and we will eventually have everything just be nnx

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants