@@ -64,7 +64,7 @@ class AttentionType(enum.Enum):
 DType = common_types.DType
 Mesh = common_types.Mesh
 PRNGKey = common_types.PRNGKey
-DenseGeneral = linears.DenseGeneral
+dense_general = linears.dense_general
 RMSNorm = linears.RMSNorm
 RotaryEmbedding = embeddings.RotaryEmbedding
 YarnRotaryEmbedding = embeddings.YarnRotaryEmbedding
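Note for reviewers (not part of the diff): this first hunk only swaps the module-level alias; every later hunk follows the same pattern, replacing the DenseGeneral module with the dense_general helper and passing the input shape (inputs_shape=... at call sites, in_features=... in setup()) at construction time. The snippet below is a self-contained sketch of what that extra argument conveys; it is not MaxText's linears.dense_general, just a hypothetical helper showing how a kernel shape can be sized eagerly once the input shape, the contracted axes, and the output features are all known up front. Shapes used are arbitrary examples.

# Hypothetical helper, for illustration only: derive a DenseGeneral-style
# kernel shape from the arguments the new call sites pass explicitly.
def kernel_shape_from_inputs(inputs_shape, features, axis=-1):
  axes = (axis,) if isinstance(axis, int) else tuple(axis)
  features = (features,) if isinstance(features, int) else tuple(features)
  contracted = tuple(inputs_shape[a] for a in axes)  # dims that get contracted away
  return contracted + features

# Query projection hunk: contract the embedding dim into (num_heads, head_dim).
print(kernel_shape_from_inputs((4, 128, 2048), (16, 128), axis=-1))      # (2048, 16, 128)
# Out projection hunk: contract (heads, head_dim) back into the embedding dim.
print(kernel_shape_from_inputs((4, 128, 16, 128), 2048, axis=(-2, -1)))  # (16, 128, 2048)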
@@ -1330,7 +1330,8 @@ def query_init(*args):
     kernel_axes = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("embed", "q_heads", "kv")
     )
-    query_proj = DenseGeneral(
+    query_proj = dense_general(
+        inputs_shape=inputs_q.shape,
         features=(self.num_query_heads, self.head_dim),
         axis=-1,
         kernel_init=query_init,
@@ -1366,7 +1367,8 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
         else ("embed", "kv_heads", "kv_head_dim")
     )

-    kv_proj = DenseGeneral(
+    kv_proj = dense_general(
+        inputs_shape=inputs_kv.shape,
         features=(self.num_kv_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
@@ -1382,7 +1384,8 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
   def qkv_projection(self, inputs: Array, proj_name: str):
     """Fused QKV projection"""

-    qkv_proj = DenseGeneral(
+    qkv_proj = dense_general(
+        inputs_shape=inputs.shape,
         features=(3, self.num_query_heads, self.head_dim),
         axis=-1,
         kernel_init=self.kernel_init,
@@ -1402,7 +1405,8 @@ def out_projection(self, output_dim: int, out: Array) -> Array:
     out_kernel_axis = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("heads", "kv", "embed")
     )
-    out_proj = DenseGeneral(
+    out_proj = dense_general(
+        inputs_shape=out.shape,
         features=output_dim,
         axis=(-2, -1),
         kernel_init=self.kernel_init,
@@ -1660,7 +1664,8 @@ def setup(self):

     if self.q_lora_rank == 0:
       # Standard Q projection (without LoRA).
-      self.query_proj = DenseGeneral(
+      self.query_proj = dense_general(
+          in_features=self.config.emb_dim,
           features=(self.num_query_heads, self.qk_head_dim),
           axis=-1,
           kernel_init=self.kernel_init,
@@ -1673,7 +1678,8 @@ def setup(self):
       )
     else:
       # LoRA path for Q.
-      self.wq_a = DenseGeneral(
+      self.wq_a = dense_general(
+          in_features=self.config.emb_dim,
           features=self.q_lora_rank,
           axis=-1,
           kernel_init=self.kernel_init,
@@ -1691,7 +1697,8 @@ def setup(self):
           epsilon=self.config.normalization_layer_epsilon,
           kernel_axes=("norm",),
       )
-      self.wq_b = DenseGeneral(
+      self.wq_b = dense_general(
+          in_features=self.q_lora_rank,
           features=(self.num_query_heads, self.qk_head_dim),
           axis=-1,
           kernel_init=self.kernel_init,
@@ -1704,7 +1711,8 @@ def setup(self):
       )

     # KV LoRA path.
-    self.wkv_a = DenseGeneral(
+    self.wkv_a = dense_general(
+        in_features=self.config.emb_dim,
         features=self.kv_lora_rank + self.qk_rope_head_dim,
         axis=-1,
         kernel_init=self.kernel_init,
@@ -1722,8 +1730,12 @@ def setup(self):
         epsilon=self.config.normalization_layer_epsilon,
         kernel_axes=("norm",),
     )
-    self.wkv_b = DenseGeneral(
-        features=(self.num_query_heads, (self.qk_nope_head_dim + self.v_head_dim)),
+    self.wkv_b = dense_general(
+        in_features=self.kv_lora_rank,
+        features=(
+            self.num_query_heads,
+            (self.qk_nope_head_dim + self.v_head_dim),
+        ),
         axis=-1,
         kernel_init=self.kernel_init,
         kernel_axes=("kv_lora", "kv_heads", "kv_head_dim"),
@@ -1933,3 +1945,4 @@ def __hash__(self):
             self.q_sequence.tobytes() if self.q_sequence is not None else None,
         )
     )
+