
Commit 4d8c245

committed
exclude embedding in MFU computation
ghstack-source-id: 9daa990
Pull Request resolved: #280
1 parent 58b1169 commit 4d8c245
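
Context for the change (an explanatory note, not part of the commit message): the token embedding is applied as a table lookup in the forward pass rather than as a matmul, so counting its parameters in the 6 * num_params term inflates the estimated FLOPs. The rough illustration below uses assumed Llama-like sizes; the vocabulary size, model dimension, and total parameter count are hypothetical, not taken from this commit.

# Rough illustration with assumed sizes, not values from this commit
vocab_size, dim = 32_000, 4_096
embedding_params = vocab_size * dim        # ~131M params; used as a lookup, not a matmul
total_params = 6.74e9                      # hypothetical total parameter count
non_embedding_params = total_params - embedding_params
print(f"embedding share of parameters: {embedding_params / total_params:.1%}")  # ~1.9%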

File tree

2 files changed (+15, -11 lines)


torchtitan/utils.py

Lines changed: 12 additions & 10 deletions

@@ -96,16 +96,11 @@ def init_distributed(job_config):
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"


-def get_num_params(model: torch.nn.Module, only_trainable: bool = False) -> int:
-    """
-    Get the total model params
-    Args : only_trainable: whether to only count trainable params
-    """
-    param_list = list(model.parameters())
-    if only_trainable:
-        param_list = [p for p in param_list if p.requires_grad]
-    # unique_params = {p.data_ptr(): p for p in param_list}.values()
-    return sum(p.numel() for p in param_list)
+def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+    num_params = sum(p.numel() for p in model.parameters())
+    if exclude_embedding:
+        num_params -= model.tok_embeddings.weight.numel()
+    return num_params


 def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
@@ -115,7 +110,14 @@ def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
         model_config.dim // model_config.n_heads,
         seq_len,
     )
+    # Reasoning behind the factor of 12 for the self-attention part of the formula:
+    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
+    # 2. the flash attention does 1 more matmul recomputation in the backward
+    #    but recomputation should not be counted in calculating MFU (+0)
+    # 3. each matmul performs 1 multiplication and 1 addition (*2)
+    # 4. we follow the convention and do not account for sparsity in causal attention
     flop_per_token = 6 * num_params + 12 * l * h * q * t
+
     return flop_per_token
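
For a sense of the relative size of the two terms in flop_per_token, the short sketch below (not part of this commit) plugs hypothetical Llama-2-7B-like values into the same formula; the parameter count, layer/head sizes, and sequence length are assumed for illustration only.

# Illustrative only: hypothetical hyperparameters, not taken from this commit
n_layers, n_heads, dim, seq_len = 32, 32, 4096, 4096
head_dim = dim // n_heads                 # the "q" term in get_num_flop_per_token
num_params = 6_500_000_000                # assumed non-embedding parameter count

dense_flops = 6 * num_params              # forward matmuls give 2 FLOPs per param per token, backward gives 4
attn_flops = 12 * n_layers * n_heads * head_dim * seq_len   # self-attention term
flop_per_token = dense_flops + attn_flops

print(f"dense: {dense_flops:.2e}  attention: {attn_flops:.2e}  total: {flop_per_token:.2e}")
# at seq_len 4096 the attention term is roughly one sixth of the dense term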

train.py

Lines changed: 3 additions & 1 deletion

@@ -187,7 +187,9 @@ def loss_fn(pred, labels):
     # log model size
     model_param_count = get_num_params(model)
     num_flop_per_token = get_num_flop_per_token(
-        model_param_count, model_config, job_config.training.seq_len
+        get_num_params(model, exclude_embedding=True),
+        model_config,
+        job_config.training.seq_len,
     )
     logger.info(
         f"{color.blue}Model {model_name} {job_config.model.flavor} "

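Once num_flop_per_token is computed from the non-embedding parameter count, MFU is the ratio of achieved FLOPS to hardware peak FLOPS. The helper below is a hedged sketch rather than code from this commit; the name estimate_mfu and the default 312 TFLOPS (A100 BF16 dense peak) are assumptions.

# Hedged sketch, not code from this commit: how flop_per_token typically feeds an MFU estimate.
def estimate_mfu(num_flop_per_token: float,
                 tokens_per_second_per_gpu: float,
                 gpu_peak_flops: float = 312e12) -> float:
    # 312e12 assumes an A100 BF16 dense peak; substitute the peak of the hardware in use.
    achieved_flops = num_flop_per_token * tokens_per_second_per_gpu
    return achieved_flops / gpu_peak_flops

# e.g. ~4.5e10 FLOPs/token at 2,500 tokens/s per GPU comes out to about 0.36 MFU
print(f"{estimate_mfu(4.5e10, 2_500):.2f}")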