@@ -96,16 +96,11 @@ def init_distributed(job_config):
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"


-def get_num_params(model: torch.nn.Module, only_trainable: bool = False) -> int:
-    """
-    Get the total model params
-    Args: only_trainable: whether to only count trainable params
-    """
-    param_list = list(model.parameters())
-    if only_trainable:
-        param_list = [p for p in param_list if p.requires_grad]
-    # unique_params = {p.data_ptr(): p for p in param_list}.values()
-    return sum(p.numel() for p in param_list)
+def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+    num_params = sum(p.numel() for p in model.parameters())
+    if exclude_embedding:
+        num_params -= model.tok_embeddings.weight.numel()
+    return num_params


 def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
@@ -115,7 +110,14 @@ def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
         model_config.dim // model_config.n_heads,
         seq_len,
     )
+    # Reasoning behind the factor of 12 for the self-attention part of the formula:
+    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
+    # 2. the flash attention does 1 more matmul recomputation in the backward
+    #    but recomputation should not be counted in calculating MFU (+0)
+    # 3. each matmul performs 1 multiplication and 1 addition (*2)
+    # 4. we follow the convention and do not account for sparsity in causal attention
     flop_per_token = 6 * num_params + 12 * l * h * q * t
+
     return flop_per_token

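For context, here is a minimal sketch of how the two helpers changed in this diff can be combined into an MFU estimate. The toy `ToyConfig`/`ToyModel` classes and the `tokens_per_second`/`peak_flops` numbers below are illustrative assumptions and not part of the diff; only `get_num_params` and `get_num_flop_per_token` come from the code above.

```python
import torch.nn as nn

# Hypothetical toy config/model just to exercise the helpers above;
# real model configs live elsewhere in the repo.
class ToyConfig:
    n_layers, n_heads, dim = 2, 4, 64

class ToyModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_embeddings = nn.Embedding(1000, cfg.dim)
        self.layers = nn.ModuleList(nn.Linear(cfg.dim, cfg.dim) for _ in range(cfg.n_layers))

cfg = ToyConfig()
model = ToyModel(cfg)
seq_len = 2048

# Non-embedding params feed the 6 * num_params term of the formula.
num_params = get_num_params(model, exclude_embedding=True)
flop_per_token = get_num_flop_per_token(num_params, cfg, seq_len)

# MFU = achieved FLOPS / peak FLOPS. tokens_per_second would be measured during
# training and peak_flops taken from the accelerator spec (both assumed here).
tokens_per_second = 1000.0
peak_flops = 312e12  # e.g. A100 BF16 peak, purely illustrative
mfu = flop_per_token * tokens_per_second / peak_flops
print(f"flop_per_token={flop_per_token:,}, estimated MFU={mfu:.2%}")
```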