Commit f9ca7b2

Cristian Garcia authored and maxtext authors committed
Convert DenseGeneral to NNX
# Description

This commit converts `DenseGeneral` to NNX and creates a `dense_general` function to interface with it through a Linen wrapper. `dense_general` accepts all the same arguments as the Linen version plus two additional ones:

* `inputs_shape`: the expected shape of the input.
* `in_features`: an int or tuple representing the input features.

Only one of the two may be set at a time.

# Tests

# Checklist

Before submitting this PR, please make sure (put X in square brackets):

- [x] I have performed a self-review of my code.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed.

PiperOrigin-RevId: 748311465
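For context, here is a minimal, hypothetical sketch of the NNX-behind-Linen pattern this commit adopts, using Flax's `nnx.bridge` helpers. `TinyDense` and all argument values are illustrative, not the real module; MaxText's actual wrapper lives in `linears.py` and forwards many more arguments (`kernel_init`, `kernel_axes`, and so on).

```python
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge


class TinyDense(nnx.Module):
  """Toy stand-in for the converted DenseGeneral (illustrative only)."""

  def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
    # NNX modules create their parameters eagerly from explicit feature sizes,
    # which is why the wrapper needs in_features (or a shape to derive it from).
    self.kernel = nnx.Param(
        nnx.initializers.lecun_normal()(rngs.params(), (in_features, out_features))
    )

  def __call__(self, x):
    return x @ self.kernel.value


# Expose the NNX module through a Linen interface, analogous to what
# dense_general does for DenseGeneral.
layer = bridge.to_linen(TinyDense, in_features=4, out_features=8)
variables = layer.init(jax.random.key(0), jnp.ones((2, 4)))
y = layer.apply(variables, jnp.ones((2, 4)))
```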
1 parent 2e4af93 · commit f9ca7b2

File tree

6 files changed: +205 -82 lines changed


MaxText/layers/attentions.py

Lines changed: 23 additions & 10 deletions

@@ -60,6 +60,7 @@ class AttentionType(enum.Enum):
 Mesh = common_types.Mesh
 PRNGKey = common_types.PRNGKey
 DenseGeneral = linears.DenseGeneral
+dense_general = linears.dense_general
 RMSNorm = linears.RMSNorm
 RotaryEmbedding = embeddings.RotaryEmbedding
 YarnRotaryEmbedding = embeddings.YarnRotaryEmbedding
@@ -1126,7 +1127,8 @@ def query_init(*args):
     kernel_axes = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("embed", "q_heads", "kv")
     )
-    query_proj = DenseGeneral(
+    query_proj = dense_general(
+        inputs_shape=inputs_q.shape,
         features=(self.num_query_heads, self.head_dim),
         axis=-1,
         kernel_init=query_init,
@@ -1162,7 +1164,8 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
         else ("embed", "kv_heads", "kv_head_dim")
     )

-    kv_proj = DenseGeneral(
+    kv_proj = dense_general(
+        inputs_shape=inputs_kv.shape,
         features=(self.num_kv_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
@@ -1178,7 +1181,8 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
   def qkv_projection(self, inputs: Array, proj_name: str):
     """Fused QKV projection"""

-    qkv_proj = DenseGeneral(
+    qkv_proj = dense_general(
+        inputs_shape=inputs.shape,
         features=(3, self.num_query_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
@@ -1197,7 +1201,8 @@ def out_projection(self, output_dim: int, out: Array) -> Array:
     out_kernel_axis = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("heads", "kv", "embed")
     )
-    out_proj = DenseGeneral(
+    out_proj = dense_general(
+        inputs_shape=out.shape,
         features=output_dim,
         axis=(-2, -1),
         kernel_init=self.kernel_init,
@@ -1455,7 +1460,8 @@ def setup(self):

     if self.q_lora_rank == 0:
       # Standard Q projection (without LoRA).
-      self.query_proj = DenseGeneral(
+      self.query_proj = dense_general(
+          in_features=self.config.emb_dim,
           features=(self.num_query_heads, self.qk_head_dim),
           axis=-1,
           kernel_init=self.kernel_init,
@@ -1468,7 +1474,8 @@ def setup(self):
       )
     else:
       # LoRA path for Q.
-      self.wq_a = DenseGeneral(
+      self.wq_a = dense_general(
+          in_features=self.config.emb_dim,
           features=self.q_lora_rank,
           axis=-1,
           kernel_init=self.kernel_init,
@@ -1486,7 +1493,8 @@ def setup(self):
           epsilon=self.config.normalization_layer_epsilon,
           kernel_axes=("norm",),
       )
-      self.wq_b = DenseGeneral(
+      self.wq_b = dense_general(
+          in_features=self.q_lora_rank,
           features=(self.num_query_heads, self.qk_head_dim),
           axis=-1,
           kernel_init=self.kernel_init,
@@ -1499,7 +1507,8 @@ def setup(self):
       )

     # KV LoRA path.
-    self.wkv_a = DenseGeneral(
+    self.wkv_a = dense_general(
+        in_features=self.config.emb_dim,
         features=self.kv_lora_rank + self.qk_rope_head_dim,
         axis=-1,
         kernel_init=self.kernel_init,
@@ -1517,8 +1526,12 @@ def setup(self):
         epsilon=self.config.normalization_layer_epsilon,
         kernel_axes=("norm",),
     )
-    self.wkv_b = DenseGeneral(
-        features=(self.num_query_heads, (self.qk_nope_head_dim + self.v_head_dim)),
+    self.wkv_b = dense_general(
+        in_features=self.kv_lora_rank,
+        features=(
+            self.num_query_heads,
+            (self.qk_nope_head_dim + self.v_head_dim),
+        ),
         axis=-1,
         kernel_init=self.kernel_init,
         kernel_axes=("kv_lora", "kv_heads", "kv_head_dim"),

MaxText/layers/gpt3.py

Lines changed: 7 additions & 3 deletions

@@ -50,6 +50,7 @@
 EMBED = common_types.EMBED

 DenseGeneral = linears.DenseGeneral
+dense_general = linears.dense_general
 NdInitializer = initializers.NdInitializer
 Initializer = initializers.Initializer
 nd_dense_init = initializers.nd_dense_init
@@ -158,7 +159,8 @@ class Gpt3MultiHeadAttention(nn.Module):
   def qkv_projection(self, inputs: Array, proj_name: str):
     """Fused QKV projection"""

-    qkv_proj = DenseGeneral(
+    qkv_proj = dense_general(
+        inputs_shape=inputs.shape,
         features=(3, self.num_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
@@ -176,7 +178,8 @@ def qkv_projection(self, inputs: Array, proj_name: str):

   def projection(self, inputs: Array, proj_name: str) -> Array:
     """individual projection for one of q, k and v."""
-    proj = DenseGeneral(
+    proj = dense_general(
+        inputs_shape=inputs.shape,
         features=(self.num_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
@@ -192,7 +195,8 @@ def projection(self, inputs: Array, proj_name: str) -> Array:

   def out_projection(self, output_dim: int, out: Array) -> Array:
     """output projection"""
-    out_proj = DenseGeneral(
+    out_proj = dense_general(
+        inputs_shape=out.shape,
         features=output_dim,
         axis=(-2, -1),
         kernel_init=self.kernel_init,
