Skip to content

Commit 377eab2

Browse files
committed
Update on "add TensorBoard logging with loss and wps"
Each rank builds its own TensorBoard writer. The global loss is communicated among all ranks before logging. To visualize using SSH tunneling: run `ssh -L 6006:127.0.0.1:6006 your_user_name@my_server_ip`, then in the torchtrain repo run `tensorboard --logdir=./torchtrain/outputs/tb`, then in a web browser go to http://localhost:6006/ <img width="722" alt="Screenshot 2024-02-12 at 6 39 28 PM" src="https://github.com/pytorch-labs/torchtrain/assets/150487191/6304103c-fa89-4f1c-a8a2-57c887b07cd3"> [ghstack-poisoned]
2 parents e1abc87 + 794785a commit 377eab2

File tree

9 files changed

+344
-59
lines changed

9 files changed

+344
-59
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,4 @@ outputs
99
data
1010
out
1111
wandb
12-
*.model
1312
*.json

README.md

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,37 @@ torchtrain contains PyTorch native parallelisms, tools and utilities to train la
66

77
# Installation
88

9-
install PyTorch from source or install the latest pytorch nightly, then install requirements by
9+
Install PyTorch from source or install the latest pytorch nightly, then install requirements by
1010

1111
```python
1212
pip install -r requirements.txt
1313
```
1414

15-
download tokenizer from HF
16-
This part is needed first time if there's no tokenizer locally by run:
17-
15+
Install additional dev requirements if you want to contribute to the repo:
1816
```
19-
python torchtrain/datasets/download_tokenizer.py --hf_token your_token
17+
pip install -r dev-requirements.txt
2018
```
2119

2220
Run the llama debug model locally to verify the setup is correct:
2321

2422
```
2523
./run_llama_train.sh
2624
```
25+
26+
# TensorBoard
27+
28+
To visualize training metrics on TensorBoard:
29+
30+
1. set `enable_tensorboard = true` in `torchtrain/train_configs/train_config.toml` (this is already the default)
31+
32+
2. set up SSH tunneling
33+
```
34+
ssh -L 6006:127.0.0.1:6006 [username]@[hostname]
35+
```
36+
37+
3. then in the torchtrain repo
38+
```
39+
tensorboard --logdir=./torchtrain/outputs/tb
40+
```
41+
42+
4. go to the URL it provides OR to http://localhost:6006/

run_llama_train.sh

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
88
# e.g.
99
# LOG_RANK=0,1 NGPU=4 SP=2 ./run_llama_train.sh
1010

11-
MODEL=${MODEL:-"debugmodel"}
11+
MODEL=${MODEL:-"llama"}
12+
MODEL_CONF=${MODEL_CONF:-"debugmodel"}
1213
NGPU=${NGPU:-"8"}
1314
PP=${PP:-"1"}
1415
SP=${SP:-"1"}
@@ -22,8 +23,10 @@ CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-""}
2223
# Please adjust this to a longer interval period. The unit of measurement is in steps.
2324
CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
2425

25-
torchrun --nproc_per_node=${NGPU} \
26+
torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
2627
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
27-
train.py --steps 10 --compile \
28+
train.py --steps 10 \
29+
--model ${MODEL} --model_conf ${MODEL_CONF} \
2830
--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
31+
--compile \
2932
--checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
488 KB
Binary file not shown.

torchtrain/metrics.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,200 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
33

4+
# Copyright (c) Meta Platforms, Inc. and affiliates.
5+
# All rights reserved
6+
47
import os
8+
from collections import namedtuple
59
from datetime import datetime
610
from typing import Any, Dict, Optional
711

812
import torch
13+
import torch.nn as nn
914
from torch.utils.tensorboard import SummaryWriter
1015

1116
from torchtrain.logging_utils import rank0_log
1217
from torchtrain.profiling import get_config_from_toml
1318

19+
_gb_in_bytes = 1024 * 1024 * 1024
20+
_mb_in_bytes = 1024 * 1024
21+
22+
23+
def format_to_gb(item, precision=4):
    """
    Convert a byte count to gigabytes.

    Args:
        item: number of bytes.
        precision: number of decimal digits to round to (default 4).

    Returns:
        ``item`` divided by ``_gb_in_bytes``, rounded to ``precision`` digits.
    """
    return round(item / _gb_in_bytes, ndigits=precision)
28+
29+
30+
def convert_to_gpu_pct(value, total_gpu_memory):
    """
    Express ``value`` as a percentage of ``total_gpu_memory``.

    Args:
        value: amount of memory in use (same unit as ``total_gpu_memory``).
        total_gpu_memory: total capacity to compare against.

    Returns:
        The percentage, rounded to 2 decimal digits.
    """
    fraction = value / total_gpu_memory
    return round(100 * fraction, 2)
32+
33+
34+
# Record of memory stats (each expressed as % of device capacity, plus the
# allocator retry count) handed to TensorBoard logging.
# Field order matters: instances may be built positionally.
GPUMemStats = namedtuple(
    "GPUMemStats",
    "allocated_curr allocated_peak "
    "reserved_curr reserved_peak "
    "active_curr active_peak "
    "num_retries",
)
47+
48+
49+
class GPUMemoryMonitor:
    """
    Class to monitor GPU memory usage.

    Reads the CUDA caching allocator statistics for one device and caches the
    current / peak byte counts on the instance, together with their GB and
    percent-of-capacity conversions.
    """

    def __init__(self, device: str = "cuda:0"):
        """
        Args:
            device: device string passed to ``torch.device``.
        """
        self.device = torch.device(device)  # device object
        self.device_name = torch.cuda.get_device_name(self.device)
        # Report the index of the device being monitored; only fall back to the
        # current device when the device string carries no index (e.g. "cuda").
        self.device_index = (
            self.device.index
            if self.device.index is not None
            else torch.cuda.current_device()
        )
        self.device_capacity = torch.cuda.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gb = format_to_gb(self.device_capacity)
        self.num_retries = 0
        self.num_ooms = 0
        self.peak_active_memory = 0
        self.peak_allocated_memory = 0
        self.peak_reserved_memory = 0
        self.curr_reserved_memory = 0

        self.device_reserved_memory_usage = 0
        self.device_reserved_memory_gb = 0
        self.device_reserved_memory_pct = 0

        self.device_active_memory_usage = 0
        self.device_active_memory_gb = 0
        self.device_active_memory_pct = 0

        # current stats
        self.device_alloc_memory_usage = torch.cuda.memory_allocated(self.device)
        self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage)
        self.device_alloc_memory_pct = convert_to_gpu_pct(
            self.device_alloc_memory_usage, self.device_capacity
        )

        # reset stats, clear cache
        torch.cuda.reset_peak_memory_stats(self.device)
        torch.cuda.empty_cache()

    def get_pct_memory(self, memory_num):
        """Return ``memory_num`` bytes as a percentage of device capacity (2 digits)."""
        pct_memory = memory_num / self.device_capacity
        return round(100 * pct_memory, 2)

    def get_gb_memory(self, memory_num):
        """Return ``memory_num`` bytes converted to GB, rounded to 2 digits."""
        gb_memory = memory_num / _gb_in_bytes
        return round(gb_memory, 2)

    def get_current_stats(self, return_data: bool = False):
        """
        Get the CUDA caching allocator stats for the monitored device.

        Refreshes the cached current and peak values on the instance.

        Args:
            return_data: if True, return the data as a ``GPUMemStats`` named
                tuple; otherwise return a formatted display string.

        Returns:
            A display string, or a ``GPUMemStats`` of percentages (plus the
            retry count) when ``return_data`` is True.
        """
        curr_mem = torch.cuda.memory_stats(self.device)

        # .get(..., 0) mirrors get_peak_stats so a missing key does not raise.
        self.device_alloc_memory_usage = curr_mem.get("allocated_bytes.all.current", 0)
        self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage)
        self.device_alloc_memory_pct = convert_to_gpu_pct(
            self.device_alloc_memory_usage, self.device_capacity
        )

        self.device_reserved_memory_usage = curr_mem.get(
            "reserved_bytes.all.current", 0
        )
        self.device_reserved_memory_gb = format_to_gb(self.device_reserved_memory_usage)
        self.device_reserved_memory_pct = convert_to_gpu_pct(
            self.device_reserved_memory_usage, self.device_capacity
        )

        self.device_active_memory_usage = curr_mem.get("active_bytes.all.current", 0)
        self.device_active_memory_gb = format_to_gb(self.device_active_memory_usage)
        self.device_active_memory_pct = convert_to_gpu_pct(
            self.device_active_memory_usage, self.device_capacity
        )

        self.get_peak_stats(curr_mem)

        peak_active_pct = self.get_pct_memory(self.peak_active_memory)
        peak_allocated_pct = self.get_pct_memory(self.peak_allocated_memory)
        peak_reserved_pct = self.get_pct_memory(self.peak_reserved_memory)

        if not return_data:
            display_str = (
                f"Current Memory: {self.device_name} ({self.device_index}): "
                f"Reserved: {self.device_reserved_memory_pct}%, "
                f"Alloc {self.device_alloc_memory_pct}%, "
                f"Active: {self.device_active_memory_pct}%\n"
                f"Peak Memory: Reserved {peak_reserved_pct}%, "
                f"Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"
                f"num retries: {self.num_retries}, num ooms: {self.num_ooms}"
            )
            if self.num_retries > 0:
                display_str += (
                    f"\nWARNING: {self.num_retries} retries -- "
                    "recommend lowering batch size for max performance\n"
                )
            return display_str

        # Keyword arguments keep every percentage paired with its own field;
        # allocated_peak must carry the allocated peak, not the active peak.
        return GPUMemStats(
            allocated_curr=self.device_alloc_memory_pct,
            allocated_peak=peak_allocated_pct,
            reserved_curr=self.device_reserved_memory_pct,
            reserved_peak=peak_reserved_pct,
            active_curr=self.device_active_memory_pct,
            active_peak=peak_active_pct,
            num_retries=self.num_retries,
        )

    def start_monitoring(self):
        """Reset all monitoring stats."""
        self.reset_peak_stats()

    def get_peak_stats(self, cuda_info=None):
        """
        Capture the current peak memory stats.

        Args:
            cuda_info: mapping of allocator stats as returned by
                ``torch.cuda.memory_stats``; queried for the monitored device
                when None. Missing keys count as 0.
        """
        # Only query CUDA when nothing was passed; an empty mapping is a valid
        # (all-zero) input and must not trigger a second query.
        if cuda_info is None:
            cuda_info = torch.cuda.memory_stats(self.device)

        self.peak_active_memory = cuda_info.get("active_bytes.all.peak", 0)
        self.peak_allocated_memory = cuda_info.get("allocated_bytes.all.peak", 0)
        self.peak_reserved_memory = cuda_info.get("reserved_bytes.all.peak", 0)

        self.num_retries = cuda_info.get("num_alloc_retries", 0)
        self.num_ooms = cuda_info.get("num_ooms", 0)

    def reset_peak_stats(self):
        """Reset peak memory stats for the monitored device and clear the cache."""
        torch.cuda.reset_peak_memory_stats(self.device)
        torch.cuda.empty_cache()
        self.num_retries = 0
        self.num_ooms = 0
        # Drop the cached peaks too so they do not outlive the allocator reset.
        self.peak_active_memory = 0
        self.peak_allocated_memory = 0
        self.peak_reserved_memory = 0
        self.active_peak_memory_utilization_str = ""
        self.peak_memory_utilization_str = ""
        self.peak_reserved_memory_utilization_str = ""

    def __str__(self):
        # Refresh the cached values before formatting them.
        _ = self.get_current_stats()
        display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gb} GB capacity, "
        display_str += f"{self.device_alloc_memory_gb} GB in-use, {self.device_alloc_memory_pct}% in-use"
        return f"{display_str}"
185+
186+
187+
def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
    """
    Count the parameters of ``model``.

    Parameters are de-duplicated by ``data_ptr()``, so parameters that report
    the same data pointer are counted once.

    Args:
        model: module whose parameters are counted.
        only_trainable: when True, count only parameters with
            ``requires_grad`` set.

    Returns:
        Total number of elements across the unique parameters.
    """
    by_ptr = {}
    for param in model.parameters():
        if only_trainable and not param.requires_grad:
            continue
        # Later parameters overwrite earlier ones that share a data pointer.
        by_ptr[param.data_ptr()] = param
    return sum(param.numel() for param in by_ptr.values())
197+
14198

15199
class MetricLogger:
16200
def __init__(self, log_dir, tag, enable_tb):

torchtrain/models/llama/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
__all__ = ["Transformer"]
77

88
llama_configs = {
9-
"debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16),
9+
"debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
10+
"1B": ModelArgs(dim=1024, n_layers=16, n_heads=8),
1011
"7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
1112
"13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
13+
"40B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
1214
"70B": ModelArgs(
1315
dim=8192,
1416
n_layers=80,

0 commit comments

Comments
 (0)