Commit fca144b

Convert DenseGeneral to NNX

Committed by bvandermoon and Cristian Garcia
Co-authored-by: Branden Vandermoon <[email protected]>
Co-authored-by: Cristian Garcia <[email protected]>

1 parent 5c7df89

11 files changed: +278 -98 lines

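Why the converted call sites need the input size up front: Flax Linen modules such as DenseGeneral infer their kernel shape lazily from the first input they receive, whereas NNX modules create their parameters eagerly at construction time. That is consistent with every call site in this commit gaining either inputs_shape=<activation>.shape (where an activation is in hand) or in_features=<int> (in setup(), where only the config is available). Below is a minimal, standalone sketch of that Linen-vs-NNX difference using Flax's own linen.Dense and nnx.Linear; it illustrates the pattern only and is not MaxText's dense_general.

# Standalone illustration of lazy (Linen) vs. eager (NNX) parameter creation.
# Assumes a Flax release that ships the nnx API (flax >= 0.8).
import jax
import jax.numpy as jnp
from flax import linen, nnx

x = jnp.ones((2, 4))

# Linen: the kernel shape (4, 8) is only inferred when init() sees an input.
linen_dense = linen.Dense(features=8)
params = linen_dense.init(jax.random.PRNGKey(0), x)
y_linen = linen_dense.apply(params, x)

# NNX: the kernel is created inside the constructor, so in_features is required.
nnx_dense = nnx.Linear(in_features=4, out_features=8, rngs=nnx.Rngs(0))
y_nnx = nnx_dense(x)

print(y_linen.shape, y_nnx.shape)  # (2, 8) (2, 8)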
MaxText/layers/attentions.py
Lines changed: 24 additions & 11 deletions
@@ -64,7 +64,7 @@ class AttentionType(enum.Enum):
 DType = common_types.DType
 Mesh = common_types.Mesh
 PRNGKey = common_types.PRNGKey
-DenseGeneral = linears.DenseGeneral
+dense_general = linears.dense_general
 RMSNorm = linears.RMSNorm
 RotaryEmbedding = embeddings.RotaryEmbedding
 YarnRotaryEmbedding = embeddings.YarnRotaryEmbedding
@@ -1330,7 +1330,8 @@ def query_init(*args):
     kernel_axes = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("embed", "q_heads", "kv")
     )
-    query_proj = DenseGeneral(
+    query_proj = dense_general(
+        inputs_shape=inputs_q.shape,
         features=(self.num_query_heads, self.head_dim),
         axis=-1,
         kernel_init=query_init,
@@ -1366,7 +1367,8 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
         else ("embed", "kv_heads", "kv_head_dim")
     )

-    kv_proj = DenseGeneral(
+    kv_proj = dense_general(
+        inputs_shape=inputs_kv.shape,
         features=(self.num_kv_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
@@ -1382,7 +1384,8 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
   def qkv_projection(self, inputs: Array, proj_name: str):
     """Fused QKV projection"""

-    qkv_proj = DenseGeneral(
+    qkv_proj = dense_general(
+        inputs_shape=inputs.shape,
         features=(3, self.num_query_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
@@ -1402,7 +1405,8 @@ def out_projection(self, output_dim: int, out: Array) -> Array:
     out_kernel_axis = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("heads", "kv", "embed")
     )
-    out_proj = DenseGeneral(
+    out_proj = dense_general(
+        inputs_shape=out.shape,
         features=output_dim,
         axis=(-2, -1),
         kernel_init=self.kernel_init,
@@ -1660,7 +1664,8 @@ def setup(self):

     if self.q_lora_rank == 0:
       # Standard Q projection (without LoRA).
-      self.query_proj = DenseGeneral(
+      self.query_proj = dense_general(
+          in_features=self.config.emb_dim,
           features=(self.num_query_heads, self.qk_head_dim),
           axis=-1,
           kernel_init=self.kernel_init,
@@ -1673,7 +1678,8 @@ def setup(self):
       )
     else:
       # LoRA path for Q.
-      self.wq_a = DenseGeneral(
+      self.wq_a = dense_general(
+          in_features=self.config.emb_dim,
           features=self.q_lora_rank,
           axis=-1,
           kernel_init=self.kernel_init,
@@ -1691,7 +1697,8 @@ def setup(self):
           epsilon=self.config.normalization_layer_epsilon,
           kernel_axes=("norm",),
       )
-      self.wq_b = DenseGeneral(
+      self.wq_b = dense_general(
+          in_features=self.q_lora_rank,
           features=(self.num_query_heads, self.qk_head_dim),
           axis=-1,
           kernel_init=self.kernel_init,
@@ -1704,7 +1711,8 @@ def setup(self):
       )

     # KV LoRA path.
-    self.wkv_a = DenseGeneral(
+    self.wkv_a = dense_general(
+        in_features=self.config.emb_dim,
         features=self.kv_lora_rank + self.qk_rope_head_dim,
         axis=-1,
         kernel_init=self.kernel_init,
@@ -1722,8 +1730,12 @@ def setup(self):
         epsilon=self.config.normalization_layer_epsilon,
         kernel_axes=("norm",),
     )
-    self.wkv_b = DenseGeneral(
-        features=(self.num_query_heads, (self.qk_nope_head_dim + self.v_head_dim)),
+    self.wkv_b = dense_general(
+        in_features=self.kv_lora_rank,
+        features=(
+            self.num_query_heads,
+            (self.qk_nope_head_dim + self.v_head_dim),
+        ),
         axis=-1,
         kernel_init=self.kernel_init,
         kernel_axes=("kv_lora", "kv_heads", "kv_head_dim"),
@@ -1933,3 +1945,4 @@ def __hash__(self):
             self.q_sequence.tobytes() if self.q_sequence is not None else None,
         )
     )
+

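Taken together, the attentions.py hunks show two flavors of the same mechanical edit: call sites that already hold the activation pass inputs_shape=<array>.shape, while the setup() paths (the q_lora / kv_lora branches) pass an explicit in_features integer. A condensed before/after sketch of the out_projection call site, restating only the arguments visible in the hunk above (how the returned projection is applied lies outside the hunk and is assumed unchanged):

# Before: Linen module; input features were inferred at first call.
out_proj = DenseGeneral(
    features=output_dim,
    axis=(-2, -1),
    kernel_init=self.kernel_init,
    ...
)

# After: NNX-backed helper; the input shape is stated explicitly.
out_proj = dense_general(
    inputs_shape=out.shape,
    features=output_dim,
    axis=(-2, -1),
    kernel_init=self.kernel_init,
    ...
)

# In setup(), where no activation exists yet, the integer form is used instead,
# e.g. self.wq_a = dense_general(in_features=self.config.emb_dim, ...).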
MaxText/layers/gpt3.py
Lines changed: 7 additions & 4 deletions
@@ -49,7 +49,7 @@
 D_KV = common_types.D_KV
 EMBED = common_types.EMBED

-DenseGeneral = linears.DenseGeneral
+dense_general = linears.dense_general
 NdInitializer = initializers.NdInitializer
 Initializer = initializers.Initializer
 nd_dense_init = initializers.nd_dense_init
@@ -163,7 +163,8 @@ class Gpt3MultiHeadAttention(nn.Module):
   def qkv_projection(self, inputs: Array, proj_name: str):
     """Fused QKV projection"""

-    qkv_proj = DenseGeneral(
+    qkv_proj = dense_general(
+        inputs_shape=inputs.shape,
         features=(3, self.num_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
@@ -181,7 +182,8 @@ def qkv_projection(self, inputs: Array, proj_name: str):

   def projection(self, inputs: Array, proj_name: str) -> Array:
     """individual projection for one of q, k and v."""
-    proj = DenseGeneral(
+    proj = dense_general(
+        inputs_shape=inputs.shape,
         features=(self.num_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
@@ -197,7 +199,8 @@ def projection(self, inputs: Array, proj_name: str) -> Array:

   def out_projection(self, output_dim: int, out: Array) -> Array:
     """output projection"""
-    out_proj = DenseGeneral(
+    out_proj = dense_general(
+        inputs_shape=out.shape,
         features=output_dim,
         axis=(-2, -1),
         kernel_init=self.kernel_init,
