Commit d240be0

[Qwen3] Fix model arg computation for MoE model (#1704)

as titled

1 parent bd3850b commit d240be0

3 files changed: +46 -9 lines changed

torchtitan/components/metrics.py

Lines changed: 4 additions & 1 deletion
@@ -163,10 +163,12 @@ def close(self) -> None:
         if self.wandb.run is not None:
             self.wandb.finish()
 
+
 class LoggerContainer(BaseLogger):
     """Container to call all loggers enabled in the job config."""
+
     def __init__(self) -> None:
-        self._loggers : list[BaseLogger] = []
+        self._loggers: list[BaseLogger] = []
 
     def add_logger(self, logger_instance: BaseLogger) -> None:
         self._loggers.append(logger_instance)
@@ -183,6 +185,7 @@ def close(self) -> None:
         for logger_instance in self._loggers:
             logger_instance.close()
 
+
 def ensure_pp_loss_visible(
     parallel_dims: ParallelDims, job_config: JobConfig, color: Color
 ) -> None:

torchtitan/experiments/qwen3/model/args.py

Lines changed: 40 additions & 7 deletions
@@ -55,11 +55,36 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         self.max_seq_len = seq_len
 
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
-        nparams = sum(p.numel() for p in model.parameters())
-        nparams_embedding = sum(
-            sum(p.numel() for p in m.parameters())
-            for m in model.children()
-            if isinstance(m, nn.Embedding)
+        nparams_embedding = 0
+        nparams_moe_router = 0
+        nparams_shared_experts = 0
+        nparams_experts = 0
+        nparams_dense = 0
+
+        for name, p in model.named_parameters():
+            if "embedding" in name:
+                nparams_embedding += p.numel()
+                nparams_dense += p.numel()
+            elif "moe.shared_experts" in name:
+                nparams_shared_experts += p.numel()
+            elif "moe.router" in name:
+                nparams_moe_router += p.numel()
+            elif "moe.experts" in name:
+                nparams_experts += p.numel()
+            else:
+                nparams_dense += p.numel()
+
+        nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
+        nparams = nparams_dense + nparams_sparse
+        nparams_sparse_active = (
+            nparams_moe_router
+            + nparams_shared_experts
+            + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
+        )
+
+        logger.info(
+            f"Total parameter count: dense {nparams_dense:,}, "
+            f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
         )
 
         l, h, q, t = (
@@ -68,10 +93,18 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
             self.dim // self.n_heads,
             seq_len,
         )
-        num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
+        # Reasoning behind the factor of 12 for the self-attention part of the formula:
+        # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
+        # 2. the flash attention does 1 more matmul recomputation in the backward
+        #    but recomputation should not be counted in calculating MFU (+0)
+        # 3. each matmul performs 1 multiplication and 1 addition (*2)
+        # 4. we follow the convention and do not account for sparsity in causal attention
+        num_flops_per_token = (
+            6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
+            + 12 * l * h * q * t
+        )
 
         if self.enable_weight_tying:
-            # exclude model.token_embedding parameters from nparams
             nparams = nparams - nparams_embedding
 
         return nparams, num_flops_per_token
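To see the new accounting end to end, here is a minimal, self-contained sketch of the same logic. Everything in it is illustrative: ToyModel, its submodule names, and the top_k/num_experts values are made up for this example (they merely mirror the substring checks and the self.moe_args fields in the diff above); it is not torchtitan code.

import torch.nn as nn

# Hypothetical toy model: attribute names are chosen so that the strings
# "embedding", "moe.router", "moe.shared_experts", and "moe.experts" show up
# in named_parameters(), matching the substring checks in the diff.
class ToyMoE(nn.Module):
    def __init__(self, dim: int, num_experts: int):
        super().__init__()
        self.router = nn.Linear(dim, num_experts, bias=False)
        self.shared_experts = nn.Linear(dim, dim, bias=False)
        self.experts = nn.ModuleList(
            nn.Linear(dim, dim, bias=False) for _ in range(num_experts)
        )

class ToyModel(nn.Module):
    def __init__(self, vocab: int = 1000, dim: int = 64, num_experts: int = 8):
        super().__init__()
        self.tok_embedding = nn.Embedding(vocab, dim)  # contains "embedding"
        self.attn = nn.Linear(dim, dim, bias=False)    # lands in the dense bucket
        self.moe = ToyMoE(dim, num_experts)

top_k, num_experts = 2, 8  # stand-ins for self.moe_args.top_k / .num_experts
model = ToyModel(num_experts=num_experts)

nparams_embedding = nparams_moe_router = 0
nparams_shared_experts = nparams_experts = nparams_dense = 0
for name, p in model.named_parameters():
    if "embedding" in name:
        # Embeddings are dense but also tracked separately, so they can be
        # excluded from the FLOPs estimate (a lookup, not a matmul).
        nparams_embedding += p.numel()
        nparams_dense += p.numel()
    elif "moe.shared_experts" in name:
        nparams_shared_experts += p.numel()
    elif "moe.router" in name:
        nparams_moe_router += p.numel()
    elif "moe.experts" in name:
        nparams_experts += p.numel()
    else:
        nparams_dense += p.numel()

nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
nparams = nparams_dense + nparams_sparse
# Per token, only top_k of the num_experts routed experts fire; the router
# and the shared experts always run, so they count as fully active.
nparams_sparse_active = (
    nparams_moe_router
    + nparams_shared_experts
    + nparams_experts * top_k // num_experts
)
print(
    f"Total parameter count: dense {nparams_dense:,}, "
    f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
)

The FLOPs estimate then charges only the active parameters, 6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + 12 * l * h * q * t, instead of the old 6 * (nparams - nparams_embedding). For an MoE model the old formula counted every routed expert on every token, overstating per-token FLOPs and therefore the reported MFU.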

torchtitan/experiments/qwen3/model/model.py

Lines changed: 2 additions & 1 deletion
@@ -132,6 +132,7 @@ def __init__(self, model_args: Qwen3ModelArgs):
         )
         self.n_rep = self.n_heads // self.n_kv_heads
         self.head_dim = model_args.head_dim
+        self.scaling = self.head_dim**-0.5
 
         # RMSNorm is added here to include the q-k norm
         # This is one of the main differences between Llama3 and Qwen3
@@ -209,7 +210,7 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
 
-        output = self.sdpa(xq, xk, xv)
+        output = self.sdpa(xq, xk, xv, scale=self.scaling)
 
         output = output.transpose(
             1, 2
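As a point of reference for the scale argument, here is a hedged, standalone sketch using stock F.scaled_dot_product_attention (self.sdpa above is torchtitan's attention wrapper, so this is an analogy, not the same call path). Stock SDPA defaults to scaling attention logits by 1 / sqrt(q.size(-1)); passing scale=head_dim**-0.5 makes the model's configured head_dim the single source of truth rather than whatever the backend infers from tensor shapes.

import torch
import torch.nn.functional as F

# Illustrative shapes only: (bs, n_heads, seqlen, head_dim)
bs, n_heads, seqlen, head_dim = 2, 4, 16, 128
q = torch.randn(bs, n_heads, seqlen, head_dim)
k = torch.randn(bs, n_heads, seqlen, head_dim)
v = torch.randn(bs, n_heads, seqlen, head_dim)

# Explicit scale, analogous to what the diff now passes through self.sdpa.
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=head_dim**-0.5)

# Default scale: stock SDPA uses 1 / sqrt(q.size(-1)), which coincides with
# head_dim**-0.5 here, so the two calls agree for this plain invocation.
out_default = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out_explicit, out_default, atol=1e-6)

Being explicit only changes behavior when a wrapper or backend would otherwise derive a different default, but it makes the Qwen3 attention scaling independent of that inference.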
