@@ -55,11 +55,36 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         self.max_seq_len = seq_len
 
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
-        nparams = sum(p.numel() for p in model.parameters())
-        nparams_embedding = sum(
-            sum(p.numel() for p in m.parameters())
-            for m in model.children()
-            if isinstance(m, nn.Embedding)
+        nparams_embedding = 0
+        nparams_moe_router = 0
+        nparams_shared_experts = 0
+        nparams_experts = 0
+        nparams_dense = 0
+
+        for name, p in model.named_parameters():
+            if "embedding" in name:
+                nparams_embedding += p.numel()
+                nparams_dense += p.numel()
+            elif "moe.shared_experts" in name:
+                nparams_shared_experts += p.numel()
+            elif "moe.router" in name:
+                nparams_moe_router += p.numel()
+            elif "moe.experts" in name:
+                nparams_experts += p.numel()
+            else:
+                nparams_dense += p.numel()
+
+        nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
+        nparams = nparams_dense + nparams_sparse
+        nparams_sparse_active = (
+            nparams_moe_router
+            + nparams_shared_experts
+            + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
+        )
+
+        logger.info(
+            f"Total parameter count: dense {nparams_dense:,}, "
+            f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
         )
 
         l, h, q, t = (
@@ -68,10 +93,18 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
             self.dim // self.n_heads,
             seq_len,
         )
-        num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
+        # Reasoning behind the factor of 12 for the self-attention part of the formula:
+        # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
+        # 2. the flash attention does 1 more matmul recomputation in the backward
+        #    but recomputation should not be counted in calculating MFU (+0)
+        # 3. each matmul performs 1 multiplication and 1 addition (*2)
+        # 4. we follow the convention and do not account for sparsity in causal attention
+        num_flops_per_token = (
+            6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
+            + 12 * l * h * q * t
+        )
 
         if self.enable_weight_tying:
-            # exclude model.token_embedding parameters from nparams
             nparams = nparams - nparams_embedding
 
         return nparams, num_flops_per_token
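
For intuition, here is a standalone sketch (not part of the patch) of the
active-parameter accounting that the loop above produces. Only top_k of
num_experts routed experts run per token, so the routed-expert parameters are
scaled by top_k / num_experts, while the router and the shared experts always
run. Every config value and parameter count below is hypothetical, chosen only
to make the arithmetic visible.

    # Hypothetical MoE sizes; none of these numbers come from a real model.
    num_experts = 8
    top_k = 2

    nparams_experts = 800_000_000        # all routed experts combined
    nparams_shared_experts = 50_000_000  # shared experts, always active
    nparams_moe_router = 1_000_000       # routing gate, always active
    nparams_dense = 1_000_000_000        # embeddings, attention, norms, etc.

    nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
    nparams_sparse_active = (
        nparams_moe_router
        + nparams_shared_experts
        + nparams_experts * top_k // num_experts  # 800M * 2 // 8 = 200M
    )

    print(f"total:  {nparams_dense + nparams_sparse:,}")         # 1,851,000,000
    print(f"active: {nparams_dense + nparams_sparse_active:,}")  # 1,251,000,000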
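
The FLOPs formula can be sanity-checked the same way: 6 FLOPs per active
non-embedding parameter per token (a multiply and an add for each of the one
forward matmul and the two backward matmuls), plus 12 * l * h * q * t for
self-attention as the in-code comments derive. The shapes and counts below are
again illustrative only, reusing the hypothetical numbers from the sketch
above.

    # Assumed shapes: layers, heads, head dim, sequence length.
    l, h, q, t = 24, 16, 128, 4096

    nparams_dense = 1_000_000_000        # hypothetical, as above
    nparams_embedding = 100_000_000      # hypothetical
    nparams_sparse_active = 251_000_000  # hypothetical, from the sketch above

    # 6 = 2 FLOPs (multiply + add) * 3 matmul passes (1 forward + 2 backward);
    # 12 = 2 FLOPs * 6 attention matmuls (2 forward + 4 backward).
    num_flops_per_token = (
        6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
        + 12 * l * h * q * t
    )
    print(f"{num_flops_per_token:,}")  # 9,321,919,104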