Commit 76f35d4

add model num params display, gpu memory metrics (#56)
This PR is the start of adding perf-related metrics.

1 - Adds a function for logging the total number of unique model params, with an option to count only trainable params (for future PEFT/QLoRA-type work).
2 - Logs the count with comma formatting alongside the model name (see screenshots in the PR). This also helps demystify the size of our debug model.

**Additional updates** - added GPU memory tracking. We want to show the user peak memory stats, as well as monitor and warn on any CUDA caching allocator retries, which are a perf hindrance. Thus, added class GPUMemoryMonitor. Usage:

1 - instantiate
2 - start of training = start_monitoring()
3 - end of training = stop_monitoring()
4 - show results = get_peak_stats_str() and rank0_log it.

[ghstack-poisoned]
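A minimal sketch of that workflow, using only the API this commit actually adds in torchtrain/metrics.py (GPUMemoryMonitor, start_monitoring(), get_current_stats(), get_num_params()); the toy model and training loop below are placeholders, not code from the PR:

import torch
import torch.nn as nn
from torchtrain.metrics import GPUMemoryMonitor, get_num_params

# placeholder model standing in for the real training setup
model = nn.Linear(1024, 1024).to("cuda")
print(f"Model size: {get_num_params(model):,} total parameters")

gpu_metrics = GPUMemoryMonitor("cuda:0")  # 1 - instantiate
gpu_metrics.start_monitoring()            # 2 - start of training (resets peak stats)

for _ in range(10):                       # placeholder training loop
    out = model(torch.randn(64, 1024, device="cuda"))
    out.sum().backward()

# 3/4 - end of training: get_current_stats() returns a string with current + peak stats
print(gpu_metrics.get_current_stats())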
1 parent 9e975f3 commit 76f35d4

File tree: 3 files changed (+202, -2 lines)


run_llama_train.sh

Lines changed: 2 additions & 2 deletions
@@ -23,10 +23,10 @@ CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-""}
 # Please adjust this to a longer interval period. The unit of measurement is in steps.
 CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
 
-torchrun --nproc_per_node=${NGPU} \
+torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --steps 10 \
 --model ${MODEL} --model_conf ${MODEL_CONF} \
 --pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
---compile
+--compile \
 --checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}

torchtrain/metrics.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved

from collections import namedtuple

import torch
import torch.nn as nn

_gb_in_bytes = 1024 * 1024 * 1024
_mb_in_bytes = 1024 * 1024


def format_to_gb(item, precision=4):
    """quick function to format numbers to gigabytes and round to (default) 4 digit precision"""
    metric_num = item / _gb_in_bytes
    metric_num = round(metric_num, ndigits=precision)
    return metric_num


def convert_to_gpu_pct(value, total_gpu_memory):
    return round(100 * (value / total_gpu_memory), 2)


# named tuple for passing memory stats (as % of device capacity) for TensorBoard logging
GPUMemStats = namedtuple(
    "GPUMemStats",
    [
        "allocated_curr",
        "allocated_peak",
        "reserved_curr",
        "reserved_peak",
        "active_curr",
        "active_peak",
        "num_retries",
    ],
)


class GPUMemoryMonitor:
    """
    Class to monitor GPU memory usage
    """

    def __init__(self, device: str = "cuda:0"):
        self.device = torch.device(device)  # device object
        self.device_name = torch.cuda.get_device_name(self.device)
        self.device_index = torch.cuda.current_device()
        self.device_capacity = torch.cuda.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gb = format_to_gb(self.device_capacity)
        self.num_retries = 0
        self.num_ooms = 0
        self.peak_active_memory = 0
        self.peak_allocated_memory = 0
        self.peak_reserved_memory = 0
        self.curr_reserved_memory = 0

        self.device_reserved_memory_usage = 0
        self.device_reserved_memory_gb = 0
        self.device_reserved_memory_pct = 0

        self.device_active_memory_usage = 0
        self.device_active_memory_gb = 0
        self.device_active_memory_pct = 0

        # current stats
        self.device_alloc_memory_usage = torch.cuda.memory_allocated(self.device)
        self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage)
        self.device_alloc_memory_pct = convert_to_gpu_pct(
            self.device_alloc_memory_usage, self.device_capacity
        )

        # reset stats, clear cache
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    def get_pct_memory(self, memory_num):
        pct_memory = memory_num / self.device_capacity
        pct_memory = round(100 * (pct_memory), 2)
        return pct_memory

    def get_gb_memory(self, memory_num):
        gb_memory = memory_num / _gb_in_bytes
        gb_memory = round(gb_memory, 2)
        return gb_memory

    def get_current_stats(self, return_data: bool = False):
        """
        get the CUDACachingAllocator stats for the current device

        return_data: bool, if True, return the data as a named tuple
        """
        curr_mem = torch.cuda.memory_stats(self.device)

        self.device_alloc_memory_usage = curr_mem["allocated_bytes.all.current"]
        self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage)
        self.device_alloc_memory_pct = convert_to_gpu_pct(
            self.device_alloc_memory_usage, self.device_capacity
        )

        self.device_reserved_memory_usage = curr_mem["reserved_bytes.all.current"]
        self.device_reserved_memory_gb = format_to_gb(self.device_reserved_memory_usage)
        self.device_reserved_memory_pct = convert_to_gpu_pct(
            self.device_reserved_memory_usage, self.device_capacity
        )

        self.device_active_memory_usage = curr_mem["active_bytes.all.current"]
        self.device_active_memory_gb = format_to_gb(self.device_active_memory_usage)
        self.device_active_memory_pct = convert_to_gpu_pct(
            self.device_active_memory_usage, self.device_capacity
        )

        display_str = ""
        display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%, "
        display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n"

        self.get_peak_stats(curr_mem)

        peak_active_pct = self.get_pct_memory(self.peak_active_memory)
        peak_allocated_pct = self.get_pct_memory(self.peak_allocated_memory)
        peak_reserved_pct = self.get_pct_memory(self.peak_reserved_memory)
        display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"

        display_str += f"num retries: {self.num_retries}, num ooms: {self.num_ooms}"
        if self.num_retries > 0:
            display_str += f"\nWARNING: {self.num_retries} retries -- recommend lowering batch size for max performance\n"

        if not return_data:
            return display_str

        # return named tuple
        curr_mem_stats = GPUMemStats(
            self.device_alloc_memory_pct,
            peak_allocated_pct,
            self.device_reserved_memory_pct,
            peak_reserved_pct,
            self.device_active_memory_pct,
            peak_active_pct,
            self.num_retries,
        )
        return curr_mem_stats

    def start_monitoring(self):
        """reset all monitoring stats"""
        self.reset_peak_stats()

    def get_peak_stats(self, cuda_info=None):
        """capture current peak memory stats"""
        if not cuda_info:
            cuda_info = torch.cuda.memory_stats()

        self.peak_active_memory = cuda_info.get("active_bytes.all.peak", 0)
        self.peak_allocated_memory = cuda_info.get("allocated_bytes.all.peak", 0)
        self.peak_reserved_memory = cuda_info.get("reserved_bytes.all.peak", 0)

        self.num_retries = cuda_info.get("num_alloc_retries", 0)
        self.num_ooms = cuda_info.get("num_ooms", 0)

    def reset_peak_stats(self):
        """reset peak memory stats"""
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        self.num_retries = 0
        self.num_ooms = 0
        self.active_peak_memory_utilization_str = ""
        self.peak_memory_utilization_str = ""
        self.peak_reserved_memory_utilization_str = ""

    def __str__(self):
        _ = self.get_current_stats()
        display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gb} GB capacity, "
        display_str += f"{self.device_alloc_memory_gb} GB in-use, {self.device_alloc_memory_pct}% in-use"
        return f"{display_str}"


def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
    """
    Get the total number of unique model params.
    Args: only_trainable: whether to only count trainable params
    """
    param_list = list(model.parameters())
    if only_trainable:
        param_list = [p for p in param_list if p.requires_grad]
    unique_params = {p.data_ptr(): p for p in param_list}.values()
    return sum(p.numel() for p in unique_params)
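As a quick illustration of the only_trainable flag (a hypothetical partial-freeze setup, not code from this commit), freezing part of a model changes the trainable count but not the total:

import torch.nn as nn
from torchtrain.metrics import get_num_params

# hypothetical two-layer model; freeze the first layer as a stand-in for PEFT-style training
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
for p in model[0].parameters():
    p.requires_grad = False

total = get_num_params(model)                            # all unique params
trainable = get_num_params(model, only_trainable=True)   # only params with requires_grad=True
print(f"{total:,} total / {trainable:,} trainable")      # 35,594 total / 2,570 trainable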

train.py

Lines changed: 11 additions & 0 deletions
@@ -18,6 +18,7 @@
 from torchtrain.datasets import create_tokenizer, dataloader_fn
 from torchtrain.logging_utils import init_logger, rank0_log
 from torchtrain.lr_scheduling import get_lr_scheduler
+from torchtrain.metrics import get_num_params, GPUMemoryMonitor
 
 from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtrain.parallelisms import models_parallelize_fns, ParallelDims

@@ -105,6 +106,14 @@ def main(args):
 
     model = model_cls.from_model_args(model_config)
 
+    # log model size
+    model_param_count = get_num_params(model)
+    rank0_log(
+        f"Model {model_name} {args.model_conf} size: {model_param_count:,} total parameters"
+    )
+    gpu_metrics = GPUMemoryMonitor("cuda")
+    rank0_log(f"GPU memory usage: {gpu_metrics}")
+
     # apply PTD parallelisms + AC
     model = models_parallelize_fns[model_name](model, world_mesh, parallel_dims, args)
 

@@ -193,6 +202,8 @@ def main(args):
 
     checkpoint.save(train_state.step, force=(train_state.step == args.steps))
 
+    rank0_log(f"{gpu_metrics.get_current_stats()}")
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
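Beyond the rank0_log string above, get_current_stats(return_data=True) returns the GPUMemStats named tuple, whose per-field percentages could be fed into TensorBoard as the comment above GPUMemStats suggests. A hedged sketch (the SummaryWriter usage and log directory are assumptions, not part of this PR):

from torch.utils.tensorboard import SummaryWriter
from torchtrain.metrics import GPUMemoryMonitor

gpu_metrics = GPUMemoryMonitor("cuda:0")
writer = SummaryWriter(log_dir="./tb_gpu_mem")  # hypothetical log dir

# ... one or more training steps would run here ...

stats = gpu_metrics.get_current_stats(return_data=True)  # GPUMemStats, values as % of capacity
for name, value in zip(stats._fields, stats):
    writer.add_scalar(f"gpu_mem/{name}", value, global_step=0)
writer.close()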
