From d91d34f7ac80702e3460fa02cbcb9d826d480646 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Tue, 16 Jan 2024 18:28:34 -0800 Subject: [PATCH 1/4] add context based profiling, control via cmd line args --- .gitignore | 1 + ...train_llama_local.sh => run_llama_train.sh | 2 +- train.py | 97 ++++++++++++++----- 3 files changed, 75 insertions(+), 25 deletions(-) rename scripts/train_llama_local.sh => run_llama_train.sh (85%) diff --git a/.gitignore b/.gitignore index 185245ae7..24b08a290 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ data out wandb *.model +*.json diff --git a/scripts/train_llama_local.sh b/run_llama_train.sh similarity index 85% rename from scripts/train_llama_local.sh rename to run_llama_train.sh index 7a5b6e282..cf7bc4ffa 100755 --- a/scripts/train_llama_local.sh +++ b/run_llama_train.sh @@ -9,4 +9,4 @@ NGPU=8 MP=4 torchrun --nproc_per_node=${NGPU} \ -${TRAINER_DIR}/train.py +train.py --steps 10 diff --git a/train.py b/train.py index 3b931a7ec..dd8c29694 100644 --- a/train.py +++ b/train.py @@ -4,17 +4,23 @@ from typing import List import logging from logging import getLogger - +import sys # for logging # torch imports import torch import torch.nn.functional as F from torch.distributed.device_mesh import init_device_mesh from torch.utils.data import DataLoader +from contextlib import contextmanager +import contextlib + + # torchtrain related from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer from torchtrain.datasets import create_tokenizer, dataset_cls_map, pad_batch_to_longest_seq +def lprint(msg=""): + print(f"Debug ++> {sys._getframe().f_back.f_lineno}: {msg}") logger = getLogger() @@ -85,37 +91,79 @@ def main(args): # TODO: apply parallelisms, e.g. fsdp/tp # TODO: add profiler + @contextlib.contextmanager + def maybe_run_profiler(args, *pos_args, **kwargs): + use_profiler: bool = args.run_profiler + + trace_dir = args.profile_folder + rank = torch.distributed.get_rank() + + def trace_handler(prof): + rank0_log(f"exporting profile traces to {trace_dir}") + prof.export_chrome_trace( + f"{trace_dir}/rank{rank}_trace.json" + ) + + if use_profiler: + if not os.path.exists(trace_dir): + os.makedirs(trace_dir) + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1), + on_trace_ready=trace_handler, + profile_memory=True, + with_stack=False, + record_shapes=True, + ) as torch_profiler: + yield torch_profiler + else: + torch_profiler = contextlib.nullcontext() + yield None + + if args.run_profiler: + rank0_log(f"Profiling active. 
Traces will be saved at {args.profile_folder}") + + # TODO: add metrics train_state = TrainState() # train loop model.train() - while train_state.step < args.steps or args.steps == -1: - train_state.step += 1 - # get batch - batch = next(iter(data_loader)) - input_ids, labels = batch - input_ids = input_ids.to(device_type) - labels = labels.to(device_type) + with maybe_run_profiler(args) as torch_profiler: + while train_state.step < args.steps or args.steps == -1: + train_state.step += 1 + # get batch + batch = next(iter(data_loader)) + input_ids, labels = batch + input_ids = input_ids.to(device_type) + labels = labels.to(device_type) + + # forward + pred = model(input_ids) + tok_loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1), reduction="none") + loss = tok_loss.mean() - # forward - pred = model(input_ids) - tok_loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1), reduction="none") - loss = tok_loss.mean() + # backward + loss.backward() + # TODO: add grad scaler - # backward - loss.backward() - # TODO: add grad scaler + # optimizer step + optimizer.step() + optimizer.zero_grad() - # optimizer step - optimizer.step() - optimizer.zero_grad() + # if profiler is active + if torch_profiler: + torch_profiler.step() - train_state.current_loss = loss.item() - train_state.losses.append(train_state.current_loss) + train_state.current_loss = loss.item() + train_state.losses.append(train_state.current_loss) - rank0_log(f"current loss: {train_state.current_loss}") + rank0_log(f"current loss: {train_state.current_loss}") if __name__ == "__main__": @@ -125,13 +173,14 @@ def main(args): parser.add_argument('--model', type=str, default="llama", help="which model to train") parser.add_argument('--model_conf', type=str, default="debugmodel", help="which model config to train") parser.add_argument('--dataset', type=str, default="alpaca", help="dataset to use") - parser.add_argument('--tokenizer_path', type=str, default="torchtrain/datasets/tokenizer.model", help="tokenizer path") + parser.add_argument('--tokenizer_path', type=str, default="./torchtrain/datasets/tokenizer/tokenizer.model", help="tokenizer path") parser.add_argument('--batch_size', type=int, default=8, help="batch size") parser.add_argument('--optimizer', type=str, default="AdamW", help="optimizer to use") - parser.add_argument('--lr', type=float, default=2e-5, help="optimizer to use") + parser.add_argument('--lr', type=float, default=2e-5, help="learning rate to use") parser.add_argument('--steps', type=int, default=-1, help="how many train steps to run") parser.add_argument('--tp_degree', type=int, default=LOCAL_WORLD_SIZE, help="Tensor/Sequence Parallelism degree") parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - + parser.add_argument('--run_profiler', action='store_true', help='Whether to run the profiler.') + parser.add_argument('--profile_folder', type=str, default="./torchtrain/profiler", help='Folder to save profile traces to.') args = parser.parse_args() main(args) From 006e85444c381bedacc1bf910119730870d2ce30 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Tue, 16 Jan 2024 18:32:03 -0800 Subject: [PATCH 2/4] ruff formatting --- train.py | 93 +++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 69 insertions(+), 24 deletions(-) diff --git a/train.py b/train.py index dd8c29694..a761a47f9 100644 --- a/train.py +++ b/train.py @@ -5,22 +5,28 @@ import logging from logging import getLogger import sys # for logging + # torch imports import torch 
import torch.nn.functional as F from torch.distributed.device_mesh import init_device_mesh from torch.utils.data import DataLoader -from contextlib import contextmanager import contextlib # torchtrain related from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer -from torchtrain.datasets import create_tokenizer, dataset_cls_map, pad_batch_to_longest_seq +from torchtrain.datasets import ( + create_tokenizer, + dataset_cls_map, + pad_batch_to_longest_seq, +) + def lprint(msg=""): - print(f"Debug ++> {sys._getframe().f_back.f_lineno}: {msg}") + print(f"Debug ++> {sys._getframe().f_back.f_lineno}: {msg}") + logger = getLogger() @@ -36,14 +42,18 @@ def rank0_log(msg): if torch.distributed.get_rank() == 0: logger.info(msg) + def init_logger(): logger.setLevel(logging.INFO) ch = logging.StreamHandler() ch.setLevel(logging.INFO) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) ch.setFormatter(formatter) logger.addHandler(ch) + def build_optimizer(model, args): # build optimizer if args.optimizer == "Adam": @@ -64,7 +74,9 @@ def main(args): # distributed init world_size = int(os.environ["WORLD_SIZE"]) dp_degree = world_size // args.tp_degree - world_mesh = init_device_mesh(device_type, (dp_degree, args.tp_degree), mesh_dim_names=("dp", "tp")) + world_mesh = init_device_mesh( + device_type, (dp_degree, args.tp_degree), mesh_dim_names=("dp", "tp") + ) model_name = args.model # build tokenizer @@ -73,7 +85,11 @@ def main(args): # build dataloader dataset_cls = dataset_cls_map[args.dataset] - data_loader = DataLoader(dataset_cls(tokenizer), batch_size=args.batch_size, collate_fn=pad_batch_to_longest_seq) + data_loader = DataLoader( + dataset_cls(tokenizer), + batch_size=args.batch_size, + collate_fn=pad_batch_to_longest_seq, + ) # build model # TODO: add meta initialization @@ -100,9 +116,7 @@ def maybe_run_profiler(args, *pos_args, **kwargs): def trace_handler(prof): rank0_log(f"exporting profile traces to {trace_dir}") - prof.export_chrome_trace( - f"{trace_dir}/rank{rank}_trace.json" - ) + prof.export_chrome_trace(f"{trace_dir}/rank{rank}_trace.json") if use_profiler: if not os.path.exists(trace_dir): @@ -127,7 +141,6 @@ def trace_handler(prof): if args.run_profiler: rank0_log(f"Profiling active. 
Traces will be saved at {args.profile_folder}") - # TODO: add metrics train_state = TrainState() @@ -145,7 +158,9 @@ def trace_handler(prof): # forward pred = model(input_ids) - tok_loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1), reduction="none") + tok_loss = F.cross_entropy( + pred.flatten(0, 1), labels.flatten(0, 1), reduction="none" + ) loss = tok_loss.mean() # backward @@ -167,20 +182,50 @@ def trace_handler(prof): if __name__ == "__main__": - parser = argparse.ArgumentParser(description='TorchTrain arg parser.') + parser = argparse.ArgumentParser(description="TorchTrain arg parser.") LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"]) - parser.add_argument('--model', type=str, default="llama", help="which model to train") - parser.add_argument('--model_conf', type=str, default="debugmodel", help="which model config to train") - parser.add_argument('--dataset', type=str, default="alpaca", help="dataset to use") - parser.add_argument('--tokenizer_path', type=str, default="./torchtrain/datasets/tokenizer/tokenizer.model", help="tokenizer path") - parser.add_argument('--batch_size', type=int, default=8, help="batch size") - parser.add_argument('--optimizer', type=str, default="AdamW", help="optimizer to use") - parser.add_argument('--lr', type=float, default=2e-5, help="learning rate to use") - parser.add_argument('--steps', type=int, default=-1, help="how many train steps to run") - parser.add_argument('--tp_degree', type=int, default=LOCAL_WORLD_SIZE, help="Tensor/Sequence Parallelism degree") - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--run_profiler', action='store_true', help='Whether to run the profiler.') - parser.add_argument('--profile_folder', type=str, default="./torchtrain/profiler", help='Folder to save profile traces to.') + parser.add_argument( + "--model", type=str, default="llama", help="which model to train" + ) + parser.add_argument( + "--model_conf", + type=str, + default="debugmodel", + help="which model config to train", + ) + parser.add_argument("--dataset", type=str, default="alpaca", help="dataset to use") + parser.add_argument( + "--tokenizer_path", + type=str, + default="./torchtrain/datasets/tokenizer/tokenizer.model", + help="tokenizer path", + ) + parser.add_argument("--batch_size", type=int, default=8, help="batch size") + parser.add_argument( + "--optimizer", type=str, default="AdamW", help="optimizer to use" + ) + parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use") + parser.add_argument( + "--steps", type=int, default=-1, help="how many train steps to run" + ) + parser.add_argument( + "--tp_degree", + type=int, + default=LOCAL_WORLD_SIZE, + help="Tensor/Sequence Parallelism degree", + ) + parser.add_argument( + "--compile", action="store_true", help="Whether to compile the model." + ) + parser.add_argument( + "--run_profiler", action="store_true", help="Whether to run the profiler." 
+ ) + parser.add_argument( + "--profile_folder", + type=str, + default="./torchtrain/profiler", + help="Folder to save profile traces to.", + ) args = parser.parse_args() main(args) From 0707d237fa7044b1fb39daa23c919364a4de5e1b Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 17 Jan 2024 20:06:38 -0800 Subject: [PATCH 3/4] address feedback - adds user config control for profiling, seperate profiling.py file, global dumps folder, logging_utils.py --- .gitignore | 1 + requirements.txt | 1 + torchtrain/logging_utils.py | 20 ++++++++++ torchtrain/profiling.py | 73 ++++++++++++++++++++++++++++++++++++ torchtrain/train_config.toml | 9 +++++ train.py | 73 ++---------------------------------- 6 files changed, 108 insertions(+), 69 deletions(-) create mode 100644 torchtrain/logging_utils.py create mode 100644 torchtrain/profiling.py create mode 100644 torchtrain/train_config.toml diff --git a/.gitignore b/.gitignore index 24b08a290..5e057e7ca 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__ .DS_Store *.egg-info build +outputs # data data diff --git a/requirements.txt b/requirements.txt index 2b0df57a9..9aa6e395f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ torch sentencepiece datasets +tomli >= 1.1.0 ; python_version < "3.11" diff --git a/torchtrain/logging_utils.py b/torchtrain/logging_utils.py new file mode 100644 index 000000000..34d5a1ff6 --- /dev/null +++ b/torchtrain/logging_utils.py @@ -0,0 +1,20 @@ +import torch +import logging + +logger = logging.getLogger() + + +def rank0_log(msg): + if torch.distributed.get_rank() == 0: + logger.info(msg) + + +def init_logger(): + logger.setLevel(logging.INFO) + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ch.setFormatter(formatter) + logger.addHandler(ch) diff --git a/torchtrain/profiling.py b/torchtrain/profiling.py new file mode 100644 index 000000000..3c447cd7f --- /dev/null +++ b/torchtrain/profiling.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +import contextlib +import os +import torch + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib + +from torchtrain.logging_utils import rank0_log + +_config_file = "./torchtrain/train_config.toml" + + +def get_config_from_toml(config_path: str = _config_file) -> dict: + """ + Reads a config file in TOML format and returns a dictionary. + """ + with open(config_path, "rb") as f: + config = tomllib.load(f) + return config + + +@contextlib.contextmanager +def maybe_run_profiler(*pos_args, **kwargs): + config = get_config_from_toml() + + # get user defined profiler settings + dump_dir = config["global"]["dump_folder"] + run_profiler = config["profiling"].get("run_profiler", False) + save_trace_dir = config["profiling"]["save_traces_folder"] + trace_dir = os.path.join(dump_dir, save_trace_dir) + + num_iters_to_profile = config["profiling"]["num_iters_to_profile"] + iter_to_start_profiling = config["profiling"]["iter_to_start_profiling"] + # profiler wants a warmup, so we reduce when to start by 1 + iter_to_start_profiling -= 1 + + rank = torch.distributed.get_rank() + + def trace_handler(prof): + rank0_log(f"exporting profile traces to {trace_dir}") + prof.export_chrome_trace(f"{trace_dir}/rank{rank}_trace.json") + + if run_profiler: + rank0_log(f"Profiling active. 
Traces will be saved at {trace_dir}") + + if not os.path.exists(trace_dir): + os.makedirs(trace_dir) + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=iter_to_start_profiling, + warmup=1, + active=num_iters_to_profile, + repeat=0, + ), + on_trace_ready=trace_handler, + profile_memory=True, + with_stack=False, + record_shapes=True, + ) as torch_profiler: + yield torch_profiler + else: + torch_profiler = contextlib.nullcontext() + yield None diff --git a/torchtrain/train_config.toml b/torchtrain/train_config.toml new file mode 100644 index 000000000..29c9fe085 --- /dev/null +++ b/torchtrain/train_config.toml @@ -0,0 +1,9 @@ +# TorchTrain Config.toml +[global] +dump_folder = "./torchtrain/outputs" + +[profiling] +run_profiler = true +save_traces_folder = "profiling/traces" +num_iters_to_profile = 1 +iter_to_start_profiling = 9 diff --git a/train.py b/train.py index a761a47f9..1877252b7 100644 --- a/train.py +++ b/train.py @@ -12,7 +12,8 @@ from torch.distributed.device_mesh import init_device_mesh from torch.utils.data import DataLoader -import contextlib +from torchtrain.profiling import maybe_run_profiler +from torchtrain.logging_utils import init_logger, rank0_log # torchtrain related @@ -24,13 +25,6 @@ ) -def lprint(msg=""): - print(f"Debug ++> {sys._getframe().f_back.f_lineno}: {msg}") - - -logger = getLogger() - - @dataclass class TrainState: step: int = 0 @@ -38,22 +32,6 @@ class TrainState: losses: List[float] = field(default_factory=list) -def rank0_log(msg): - if torch.distributed.get_rank() == 0: - logger.info(msg) - - -def init_logger(): - logger.setLevel(logging.INFO) - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - ch.setFormatter(formatter) - logger.addHandler(ch) - - def build_optimizer(model, args): # build optimizer if args.optimizer == "Adam": @@ -106,48 +84,13 @@ def main(args): optimizer = build_optimizer(model, args) # TODO: apply parallelisms, e.g. fsdp/tp - # TODO: add profiler - @contextlib.contextmanager - def maybe_run_profiler(args, *pos_args, **kwargs): - use_profiler: bool = args.run_profiler - - trace_dir = args.profile_folder - rank = torch.distributed.get_rank() - - def trace_handler(prof): - rank0_log(f"exporting profile traces to {trace_dir}") - prof.export_chrome_trace(f"{trace_dir}/rank{rank}_trace.json") - - if use_profiler: - if not os.path.exists(trace_dir): - os.makedirs(trace_dir) - - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1), - on_trace_ready=trace_handler, - profile_memory=True, - with_stack=False, - record_shapes=True, - ) as torch_profiler: - yield torch_profiler - else: - torch_profiler = contextlib.nullcontext() - yield None - - if args.run_profiler: - rank0_log(f"Profiling active. Traces will be saved at {args.profile_folder}") - # TODO: add metrics train_state = TrainState() # train loop model.train() - with maybe_run_profiler(args) as torch_profiler: + with maybe_run_profiler() as torch_profiler: while train_state.step < args.steps or args.steps == -1: train_state.step += 1 # get batch @@ -218,14 +161,6 @@ def trace_handler(prof): parser.add_argument( "--compile", action="store_true", help="Whether to compile the model." 
) - parser.add_argument( - "--run_profiler", action="store_true", help="Whether to run the profiler." - ) - parser.add_argument( - "--profile_folder", - type=str, - default="./torchtrain/profiler", - help="Folder to save profile traces to.", - ) + args = parser.parse_args() main(args) From fa9306c97874bce7923850e948e95eee8b030e82 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Thu, 18 Jan 2024 11:01:02 -0800 Subject: [PATCH 4/4] profiling now uses profiling_frequency concept to enable repeated profiling, create custom named folders for traces --- torchtrain/profiling.py | 34 ++++++++++++++++++++-------------- torchtrain/train_config.toml | 4 ++-- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/torchtrain/profiling.py b/torchtrain/profiling.py index 3c447cd7f..c1044d458 100644 --- a/torchtrain/profiling.py +++ b/torchtrain/profiling.py @@ -29,23 +29,29 @@ def maybe_run_profiler(*pos_args, **kwargs): config = get_config_from_toml() # get user defined profiler settings - dump_dir = config["global"]["dump_folder"] run_profiler = config["profiling"].get("run_profiler", False) - save_trace_dir = config["profiling"]["save_traces_folder"] - trace_dir = os.path.join(dump_dir, save_trace_dir) - num_iters_to_profile = config["profiling"]["num_iters_to_profile"] - iter_to_start_profiling = config["profiling"]["iter_to_start_profiling"] - # profiler wants a warmup, so we reduce when to start by 1 - iter_to_start_profiling -= 1 + if run_profiler: + dump_dir = config["global"]["dump_folder"] + save_trace_dir = config["profiling"]["save_traces_folder"] + trace_dir = os.path.join(dump_dir, save_trace_dir) + iter_frequency = config["profiling"]["profile_every_x_iter"] - rank = torch.distributed.get_rank() + _global_iter_count = 0 - def trace_handler(prof): - rank0_log(f"exporting profile traces to {trace_dir}") - prof.export_chrome_trace(f"{trace_dir}/rank{rank}_trace.json") + rank = torch.distributed.get_rank() + + def trace_handler(prof): + nonlocal _global_iter_count + _global_iter_count += iter_frequency + curr_trace_dir_name = "iteration_" + str(_global_iter_count) + curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name) + if not os.path.exists(curr_trace_dir): + os.makedirs(curr_trace_dir) + rank0_log(f"exporting profile traces to {curr_trace_dir}") + + prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json") - if run_profiler: rank0_log(f"Profiling active. Traces will be saved at {trace_dir}") if not os.path.exists(trace_dir): @@ -57,9 +63,9 @@ def trace_handler(prof): torch.profiler.ProfilerActivity.CUDA, ], schedule=torch.profiler.schedule( - wait=iter_to_start_profiling, + wait=iter_frequency - 2, warmup=1, - active=num_iters_to_profile, + active=1, repeat=0, ), on_trace_ready=trace_handler, diff --git a/torchtrain/train_config.toml b/torchtrain/train_config.toml index 29c9fe085..a3b02917e 100644 --- a/torchtrain/train_config.toml +++ b/torchtrain/train_config.toml @@ -5,5 +5,5 @@ dump_folder = "./torchtrain/outputs" [profiling] run_profiler = true save_traces_folder = "profiling/traces" -num_iters_to_profile = 1 -iter_to_start_profiling = 9 +# profiling frequency - example: 10 means every 10th iter will be profiled +profile_every_x_iter = 10