diff --git a/.gitignore b/.gitignore
index 185245ae7..5e057e7ca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,9 +3,11 @@ __pycache__
 .DS_Store
 *.egg-info
 build
+outputs
 
 # data
 data
 out
 wandb
 *.model
+*.json
diff --git a/requirements.txt b/requirements.txt
index 2b0df57a9..9aa6e395f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 torch
 sentencepiece
 datasets
+tomli >= 1.1.0 ; python_version < "3.11"
diff --git a/scripts/train_llama_local.sh b/run_llama_train.sh
similarity index 85%
rename from scripts/train_llama_local.sh
rename to run_llama_train.sh
index 7a5b6e282..cf7bc4ffa 100755
--- a/scripts/train_llama_local.sh
+++ b/run_llama_train.sh
@@ -9,4 +9,4 @@
 NGPU=8
 MP=4
 torchrun --nproc_per_node=${NGPU} \
-${TRAINER_DIR}/train.py
+train.py --steps 10
diff --git a/torchtrain/logging_utils.py b/torchtrain/logging_utils.py
new file mode 100644
index 000000000..34d5a1ff6
--- /dev/null
+++ b/torchtrain/logging_utils.py
@@ -0,0 +1,20 @@
+import torch
+import logging
+
+logger = logging.getLogger()
+
+
+def rank0_log(msg):
+    if torch.distributed.get_rank() == 0:
+        logger.info(msg)
+
+
+def init_logger():
+    logger.setLevel(logging.INFO)
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.INFO)
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
diff --git a/torchtrain/profiling.py b/torchtrain/profiling.py
new file mode 100644
index 000000000..c1044d458
--- /dev/null
+++ b/torchtrain/profiling.py
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+import contextlib
+import os
+import torch
+
+try:
+    import tomllib
+except ModuleNotFoundError:
+    import tomli as tomllib
+
+from torchtrain.logging_utils import rank0_log
+
+_config_file = "./torchtrain/train_config.toml"
+
+
+def get_config_from_toml(config_path: str = _config_file) -> dict:
+    """
+    Reads a config file in TOML format and returns a dictionary.
+    """
+    with open(config_path, "rb") as f:
+        config = tomllib.load(f)
+    return config
+
+
+@contextlib.contextmanager
+def maybe_run_profiler(*pos_args, **kwargs):
+    config = get_config_from_toml()
+
+    # get user defined profiler settings
+    run_profiler = config["profiling"].get("run_profiler", False)
+
+    if run_profiler:
+        dump_dir = config["global"]["dump_folder"]
+        save_trace_dir = config["profiling"]["save_traces_folder"]
+        trace_dir = os.path.join(dump_dir, save_trace_dir)
+        iter_frequency = config["profiling"]["profile_every_x_iter"]
+
+        _global_iter_count = 0
+
+        rank = torch.distributed.get_rank()
+
+        def trace_handler(prof):
+            nonlocal _global_iter_count
+            _global_iter_count += iter_frequency
+            curr_trace_dir_name = "iteration_" + str(_global_iter_count)
+            curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
+            if not os.path.exists(curr_trace_dir):
+                os.makedirs(curr_trace_dir)
+            rank0_log(f"exporting profile traces to {curr_trace_dir}")
+
+            prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
+
+        rank0_log(f"Profiling active. Traces will be saved at {trace_dir}")
+
+        if not os.path.exists(trace_dir):
+            os.makedirs(trace_dir)
+
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            schedule=torch.profiler.schedule(
+                wait=iter_frequency - 2,
+                warmup=1,
+                active=1,
+                repeat=0,
+            ),
+            on_trace_ready=trace_handler,
+            profile_memory=True,
+            with_stack=False,
+            record_shapes=True,
+        ) as torch_profiler:
+            yield torch_profiler
+    else:
+        torch_profiler = contextlib.nullcontext()
+        yield None
diff --git a/torchtrain/train_config.toml b/torchtrain/train_config.toml
new file mode 100644
index 000000000..a3b02917e
--- /dev/null
+++ b/torchtrain/train_config.toml
@@ -0,0 +1,9 @@
+# TorchTrain Config.toml
+[global]
+dump_folder = "./torchtrain/outputs"
+
+[profiling]
+run_profiler = true
+save_traces_folder = "profiling/traces"
+# profiling frequency - example: 10 means every 10th iter will be profiled
+profile_every_x_iter = 10
diff --git a/train.py b/train.py
index 3b931a7ec..1877252b7 100644
--- a/train.py
+++ b/train.py
@@ -4,6 +4,7 @@
 from typing import List
 import logging
 from logging import getLogger
+import sys # for logging
 
 # torch imports
 import torch
@@ -11,12 +12,17 @@
 from torch.distributed.device_mesh import init_device_mesh
 from torch.utils.data import DataLoader
 
-# torchtrain related
-from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
-from torchtrain.datasets import create_tokenizer, dataset_cls_map, pad_batch_to_longest_seq
+from torchtrain.profiling import maybe_run_profiler
+from torchtrain.logging_utils import init_logger, rank0_log
 
-logger = getLogger()
+# torchtrain related
+from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
+from torchtrain.datasets import (
+    create_tokenizer,
+    dataset_cls_map,
+    pad_batch_to_longest_seq,
+)
 
 
 @dataclass
@@ -26,18 +32,6 @@ class TrainState:
     losses: List[float] = field(default_factory=list)
 
 
-def rank0_log(msg):
-    if torch.distributed.get_rank() == 0:
-        logger.info(msg)
-
-def init_logger():
-    logger.setLevel(logging.INFO)
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.INFO)
-    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-    ch.setFormatter(formatter)
-    logger.addHandler(ch)
-
 def build_optimizer(model, args):
     # build optimizer
     if args.optimizer == "Adam":
@@ -58,7 +52,9 @@ def main(args):
     # distributed init
     world_size = int(os.environ["WORLD_SIZE"])
     dp_degree = world_size // args.tp_degree
-    world_mesh = init_device_mesh(device_type, (dp_degree, args.tp_degree), mesh_dim_names=("dp", "tp"))
+    world_mesh = init_device_mesh(
+        device_type, (dp_degree, args.tp_degree), mesh_dim_names=("dp", "tp")
+    )
 
     model_name = args.model
     # build tokenizer
@@ -67,7 +63,11 @@
     # build dataloader
     dataset_cls = dataset_cls_map[args.dataset]
-    data_loader = DataLoader(dataset_cls(tokenizer), batch_size=args.batch_size, collate_fn=pad_batch_to_longest_seq)
+    data_loader = DataLoader(
+        dataset_cls(tokenizer),
+        batch_size=args.batch_size,
+        collate_fn=pad_batch_to_longest_seq,
+    )
 
     # build model
     # TODO: add meta initialization
@@ -84,54 +84,83 @@
     optimizer = build_optimizer(model, args)
 
     # TODO: apply parallelisms, e.g. fsdp/tp
-    # TODO: add profiler
     # TODO: add metrics
 
     train_state = TrainState()
 
     # train loop
    model.train()
-    while train_state.step < args.steps or args.steps == -1:
-        train_state.step += 1
-        # get batch
-        batch = next(iter(data_loader))
-        input_ids, labels = batch
-        input_ids = input_ids.to(device_type)
-        labels = labels.to(device_type)
+    with maybe_run_profiler() as torch_profiler:
+        while train_state.step < args.steps or args.steps == -1:
+            train_state.step += 1
+            # get batch
+            batch = next(iter(data_loader))
+            input_ids, labels = batch
+            input_ids = input_ids.to(device_type)
+            labels = labels.to(device_type)
+
+            # forward
+            pred = model(input_ids)
+            tok_loss = F.cross_entropy(
+                pred.flatten(0, 1), labels.flatten(0, 1), reduction="none"
+            )
+            loss = tok_loss.mean()
 
-        # forward
-        pred = model(input_ids)
-        tok_loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1), reduction="none")
-        loss = tok_loss.mean()
+            # backward
+            loss.backward()
+            # TODO: add grad scaler
 
-        # backward
-        loss.backward()
-        # TODO: add grad scaler
+            # optimizer step
+            optimizer.step()
+            optimizer.zero_grad()
 
-        # optimizer step
-        optimizer.step()
-        optimizer.zero_grad()
+            # if profiler is active
+            if torch_profiler:
+                torch_profiler.step()
 
-        train_state.current_loss = loss.item()
-        train_state.losses.append(train_state.current_loss)
+            train_state.current_loss = loss.item()
+            train_state.losses.append(train_state.current_loss)
 
-        rank0_log(f"current loss: {train_state.current_loss}")
+            rank0_log(f"current loss: {train_state.current_loss}")
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='TorchTrain arg parser.')
+    parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
     LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"])
 
-    parser.add_argument('--model', type=str, default="llama", help="which model to train")
-    parser.add_argument('--model_conf', type=str, default="debugmodel", help="which model config to train")
-    parser.add_argument('--dataset', type=str, default="alpaca", help="dataset to use")
-    parser.add_argument('--tokenizer_path', type=str, default="torchtrain/datasets/tokenizer.model", help="tokenizer path")
-    parser.add_argument('--batch_size', type=int, default=8, help="batch size")
-    parser.add_argument('--optimizer', type=str, default="AdamW", help="optimizer to use")
-    parser.add_argument('--lr', type=float, default=2e-5, help="optimizer to use")
-    parser.add_argument('--steps', type=int, default=-1, help="how many train steps to run")
-    parser.add_argument('--tp_degree', type=int, default=LOCAL_WORLD_SIZE, help="Tensor/Sequence Parallelism degree")
-    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
+    parser.add_argument(
+        "--model", type=str, default="llama", help="which model to train"
+    )
+    parser.add_argument(
+        "--model_conf",
+        type=str,
+        default="debugmodel",
+        help="which model config to train",
+    )
+    parser.add_argument("--dataset", type=str, default="alpaca", help="dataset to use")
+    parser.add_argument(
+        "--tokenizer_path",
+        type=str,
+        default="./torchtrain/datasets/tokenizer/tokenizer.model",
+        help="tokenizer path",
+    )
+    parser.add_argument("--batch_size", type=int, default=8, help="batch size")
+    parser.add_argument(
+        "--optimizer", type=str, default="AdamW", help="optimizer to use"
+    )
+    parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
+    parser.add_argument(
+        "--steps", type=int, default=-1, help="how many train steps to run"
+    )
+    parser.add_argument(
+        "--tp_degree",
+        type=int,
+        default=LOCAL_WORLD_SIZE,
+        help="Tensor/Sequence Parallelism degree",
+    )
+    parser.add_argument(
+        "--compile", action="store_true", help="Whether to compile the model."
+    )
 
     args = parser.parse_args()
 
     main(args)