-
Notifications
You must be signed in to change notification settings - Fork 356
Convert DenseGeneral to NNX #1604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
a24f6e9
to
3fcbe6b
Compare
97a92b7
to
da43223
Compare
65a7f45
to
f83299d
Compare
# Description This commit converts `DenseGeneral` to NNX and creates a `dense_general` to interface with it through a Linen wrapper. `dense_general` contains all the same arguments as the Linen version but adds two additional ones: * `input_shape`: the expected shape of the input. * `in_features`: an int or tuple representing the input features. Only one of them can be set at a time. # Tests # Checklist Before submitting this PR, please make sure (put X in square brackets): - [x] I have performed a self-review of my code. - [x] I have necessary comments in my code, particularly in hard-to-understand areas. - [x] I have run end-to-end tests tests and provided workload links above if applicable. - [x] I have made or will make corresponding changes to the doc if needed. PiperOrigin-RevId: 748311465
f83299d
to
f9ca7b2
Compare
module = nnx.bridge.to_linen( | ||
DenseGeneral, | ||
in_features=in_features, | ||
out_features=features, | ||
axis=axis, | ||
weight_dtype=weight_dtype, | ||
dtype=dtype, | ||
kernel_init=kernel_init, | ||
kernel_axes=kernel_axes, | ||
quant=quant, | ||
use_bias=use_bias, | ||
matmul_precision=matmul_precision, | ||
name=name, | ||
metadata_fn=variable_to_logically_partitioned, | ||
) | ||
return module |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to confirm my understanding - this will not always be a bridge to linen, right? This is just a starting point for the migration and we will eventually have everything just be nnx
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
correct!
Convert DenseGeneral to NNX
Description
This commit converts
DenseGeneral
to NNX and creates adense_general
to interface with it through a Linen wrapper.dense_general
contains all the same arguments as the Linen version but adds two additional ones:input_shape
: the expected shape of the input.in_features
: an int or tuple representing the input features.Only one of them can be set at a time.
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):