add model num params display, gpu memory metrics #56

Merged: 11 commits, Feb 15, 2024
4 changes: 2 additions & 2 deletions run_llama_train.sh
@@ -23,10 +23,10 @@ CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-""}
# Please adjust this to a longer interval period. The unit of measurement is in steps.
CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}

- torchrun --nproc_per_node=${NGPU} \
+ torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --steps 10 \
--model ${MODEL} --model_conf ${MODEL_CONF} \
--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
- --compile
+ --compile \
--checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
189 changes: 189 additions & 0 deletions torchtrain/metrics.py
@@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved

from collections import namedtuple

import torch
import torch.nn as nn

_gb_in_bytes = 1024 * 1024 * 1024
_mb_in_bytes = 1024 * 1024


def format_to_gb(item, precision=4):
    """quick function to format numbers to gigabyte and round to (default) 4 digit precision"""
    metric_num = item / _gb_in_bytes
    metric_num = round(metric_num, ndigits=precision)
    return metric_num


def convert_to_gpu_pct(value, total_gpu_memory):
    return round(100 * (value / total_gpu_memory), 2)


# named tuple for passing memory stats (as % of device capacity) for TensorBoard logging
GPUMemStats = namedtuple(
    "GPUMemStats",
    [
        "allocated_curr",
        "allocated_peak",
        "reserved_curr",
        "reserved_peak",
        "active_curr",
        "active_peak",
        "num_retries",
    ],
)


class GPUMemoryMonitor:
    """
    Class to monitor GPU memory usage
    """

    def __init__(self, device: str = "cuda:0"):
        self.device = torch.device(device)  # device object
        self.device_name = torch.cuda.get_device_name(self.device)
        self.device_index = torch.cuda.current_device()
        self.device_capacity = torch.cuda.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gb = format_to_gb(self.device_capacity)
        self.num_retries = 0
        self.num_ooms = 0
        self.peak_active_memory = 0
        self.peak_allocated_memory = 0
        self.peak_reserved_memory = 0
        self.curr_reserved_memory = 0

        self.device_reserved_memory_usage = 0
        self.device_reserved_memory_gb = 0
        self.device_reserved_memory_pct = 0

        self.device_active_memory_usage = 0
        self.device_active_memory_gb = 0
        self.device_active_memory_pct = 0

        # current stats
        self.device_alloc_memory_usage = torch.cuda.memory_allocated(self.device)
        self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage)
        self.device_alloc_memory_pct = convert_to_gpu_pct(
            self.device_alloc_memory_usage, self.device_capacity
        )

        # reset stats, clear cache
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    def get_pct_memory(self, memory_num):
        pct_memory = memory_num / self.device_capacity
        pct_memory = round(100 * (pct_memory), 2)
        return pct_memory

    def get_gb_memory(self, memory_num):
        gb_memory = memory_num / _gb_in_bytes
        gb_memory = round(gb_memory, 2)
        return gb_memory

    def get_current_stats(self, return_data: bool = False):
        """
        Get the CUDACachingAllocator stats for the current device.

        return_data: bool, if True, return the data as a named tuple
        """
        curr_mem = torch.cuda.memory_stats(self.device)

        self.device_alloc_memory_usage = curr_mem["allocated_bytes.all.current"]
        self.device_alloc_memory_gb = format_to_gb(self.device_alloc_memory_usage)
        self.device_alloc_memory_pct = convert_to_gpu_pct(
            self.device_alloc_memory_usage, self.device_capacity
        )

        self.device_reserved_memory_usage = curr_mem["reserved_bytes.all.current"]
        self.device_reserved_memory_gb = format_to_gb(self.device_reserved_memory_usage)
        self.device_reserved_memory_pct = convert_to_gpu_pct(
            self.device_reserved_memory_usage, self.device_capacity
        )

        self.device_active_memory_usage = curr_mem["active_bytes.all.current"]
        self.device_active_memory_gb = format_to_gb(self.device_active_memory_usage)
        self.device_active_memory_pct = convert_to_gpu_pct(
            self.device_active_memory_usage, self.device_capacity
        )

        display_str = ""
        display_str += f"Current Memory: {self.device_name} ({self.device_index}): Reserved: {self.device_reserved_memory_pct}%, "
        display_str += f"Alloc {self.device_alloc_memory_pct}%, Active: {self.device_active_memory_pct}%\n"

        self.get_peak_stats(curr_mem)

        peak_active_pct = self.get_pct_memory(self.peak_active_memory)
        peak_allocated_pct = self.get_pct_memory(self.peak_allocated_memory)
        peak_reserved_pct = self.get_pct_memory(self.peak_reserved_memory)
        display_str += f"Peak Memory: Reserved {peak_reserved_pct}%, Alloc {peak_allocated_pct}%, Active: {peak_active_pct}%\n"

        display_str += f"num retries: {self.num_retries}, num ooms: {self.num_ooms}"
        if self.num_retries > 0:
            display_str += f"\nWARNING: {self.num_retries} retries -- recommend lowering batch size for max performance\n"

        if not return_data:
            return display_str

        # return named tuple (fields follow the GPUMemStats ordering)
        curr_mem_stats = GPUMemStats(
            self.device_alloc_memory_pct,
            peak_allocated_pct,
            self.device_reserved_memory_pct,
            peak_reserved_pct,
            self.device_active_memory_pct,
            peak_active_pct,
            self.num_retries,
        )
        return curr_mem_stats

    def start_monitoring(self):
        """reset all monitoring stats"""
        self.reset_peak_stats()

    def get_peak_stats(self, cuda_info=None):
        """capture current peak memory stats"""
        if not cuda_info:
            cuda_info = torch.cuda.memory_stats()

        self.peak_active_memory = cuda_info.get("active_bytes.all.peak", 0)
        self.peak_allocated_memory = cuda_info.get("allocated_bytes.all.peak", 0)
        self.peak_reserved_memory = cuda_info.get("reserved_bytes.all.peak", 0)

        self.num_retries = cuda_info.get("num_alloc_retries", 0)
        self.num_ooms = cuda_info.get("num_ooms", 0)

    def reset_peak_stats(self):
        """reset peak memory stats"""
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        self.num_retries = 0
        self.num_ooms = 0
        self.active_peak_memory_utilization_str = ""
        self.peak_memory_utilization_str = ""
        self.peak_reserved_memory_utilization_str = ""

    def __str__(self):
        _ = self.get_current_stats()
        display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gb} GB capacity, "
        display_str += f"{self.device_alloc_memory_gb} GB in-use, {self.device_alloc_memory_pct}% in-use"
        return f"{display_str}"


def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
    """
    Get the total number of model parameters (shared parameters are counted once).

    Args:
        only_trainable: whether to only count trainable params
    """
    param_list = list(model.parameters())
    if only_trainable:
        param_list = [p for p in param_list if p.requires_grad]
    # deduplicate parameters that share storage (e.g., tied weights) so each is counted once
    unique_params = {p.data_ptr(): p for p in param_list}.values()
    return sum(p.numel() for p in unique_params)
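
For reference, here is a minimal usage sketch of the two helpers above (illustrative only, not part of this diff). It assumes torchtrain is importable the same way train.py imports it below, and that a CUDA device is available for the memory monitor:

import torch
import torch.nn as nn

from torchtrain.metrics import get_num_params, GPUMemoryMonitor

# parameter counting, with and without frozen layers
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
print(f"total params: {get_num_params(model):,}")  # 35,594
model[0].requires_grad_(False)  # freeze the first layer
print(f"trainable params: {get_num_params(model, only_trainable=True):,}")  # 2,570

# GPU memory reporting (requires CUDA)
if torch.cuda.is_available():
    gpu_metrics = GPUMemoryMonitor("cuda")
    print(gpu_metrics)  # "<device name> (<index>): <N> GB capacity, ... in-use"
    print(gpu_metrics.get_current_stats())  # current + peak reserved/alloc/active percentages
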
11 changes: 11 additions & 0 deletions train.py
@@ -18,6 +18,7 @@
from torchtrain.datasets import create_tokenizer, dataloader_fn
from torchtrain.logging_utils import init_logger, rank0_log
from torchtrain.lr_scheduling import get_lr_scheduler
from torchtrain.metrics import get_num_params, GPUMemoryMonitor

from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtrain.parallelisms import models_parallelize_fns, ParallelDims
@@ -105,6 +106,14 @@ def main(args):

    model = model_cls.from_model_args(model_config)

    # log model size
    model_param_count = get_num_params(model)
    rank0_log(
        f"Model {model_name} {args.model_conf} size: {model_param_count:,} total parameters"
    )
    gpu_metrics = GPUMemoryMonitor("cuda")
    rank0_log(f"GPU memory usage: {gpu_metrics}")

    # apply PTD parallelisms + AC
    model = models_parallelize_fns[model_name](model, world_mesh, parallel_dims, args)

@@ -193,6 +202,8 @@ def main(args):

checkpoint.save(train_state.step, force=(train_state.step == args.steps))

rank0_log(f"{gpu_metrics.get_current_stats()}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
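
The GPUMemStats named tuple above is described as being intended for TensorBoard logging, but this PR only prints the stats via rank0_log. Below is a hedged sketch of how the named tuple could be consumed, assuming a CUDA device, the standard torch.utils.tensorboard.SummaryWriter, and a placeholder log directory and training loop:

import torch
from torch.utils.tensorboard import SummaryWriter

from torchtrain.metrics import GPUMemoryMonitor

assert torch.cuda.is_available(), "GPUMemoryMonitor expects a CUDA device"

writer = SummaryWriter(log_dir="./runs/gpu_mem")  # hypothetical log directory
gpu_metrics = GPUMemoryMonitor("cuda")

for step in range(1, 11):  # stand-in for the real training loop
    # ... forward / backward / optimizer step would run here ...
    stats = gpu_metrics.get_current_stats(return_data=True)  # GPUMemStats named tuple
    for name, value in zip(stats._fields, stats):
        writer.add_scalar(f"gpu_mem/{name}", value, global_step=step)

writer.close()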