Add profiler #1

Merged 4 commits, Jan 19, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,9 +3,11 @@ __pycache__
.DS_Store
*.egg-info
build
outputs

# data
data
out
wandb
*.model
*.json
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
torch
sentencepiece
datasets
tomli >= 1.1.0 ; python_version < "3.11"
2 changes: 1 addition & 1 deletion scripts/train_llama_local.sh → run_llama_train.sh
@@ -9,4 +9,4 @@ NGPU=8
MP=4

torchrun --nproc_per_node=${NGPU} \
${TRAINER_DIR}/train.py
train.py --steps 10
20 changes: 20 additions & 0 deletions torchtrain/logging_utils.py
@@ -0,0 +1,20 @@
import torch
import logging

logger = logging.getLogger()


def rank0_log(msg):
    if torch.distributed.get_rank() == 0:
        logger.info(msg)


def init_logger():
    logger.setLevel(logging.INFO)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    ch.setFormatter(formatter)
    logger.addHandler(ch)
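
A minimal usage sketch for the two helpers above (not part of the diff; it assumes the process group is already initialized, e.g. under torchrun, and that nccl is the backend in use):

import torch.distributed as dist
from torchtrain.logging_utils import init_logger, rank0_log

dist.init_process_group(backend="nccl")  # torchrun supplies the rank/world-size env vars
init_logger()
rank0_log("training starting")  # emitted once, on rank 0 only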
79 changes: 79 additions & 0 deletions torchtrain/profiling.py
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import contextlib
import os
import torch

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib

from torchtrain.logging_utils import rank0_log

_config_file = "./torchtrain/train_config.toml"


def get_config_from_toml(config_path: str = _config_file) -> dict:
    """
    Reads a config file in TOML format and returns a dictionary.
    """
    with open(config_path, "rb") as f:
        config = tomllib.load(f)
    return config


@contextlib.contextmanager
def maybe_run_profiler(*pos_args, **kwargs):
    config = get_config_from_toml()

    # get user defined profiler settings
    run_profiler = config["profiling"].get("run_profiler", False)

    if run_profiler:
        dump_dir = config["global"]["dump_folder"]
        save_trace_dir = config["profiling"]["save_traces_folder"]
        trace_dir = os.path.join(dump_dir, save_trace_dir)
        iter_frequency = config["profiling"]["profile_every_x_iter"]

        _global_iter_count = 0

        rank = torch.distributed.get_rank()

        def trace_handler(prof):
            nonlocal _global_iter_count
            _global_iter_count += iter_frequency
            curr_trace_dir_name = "iteration_" + str(_global_iter_count)
            curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
            if not os.path.exists(curr_trace_dir):
                os.makedirs(curr_trace_dir)
            rank0_log(f"exporting profile traces to {curr_trace_dir}")

            prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")

        rank0_log(f"Profiling active. Traces will be saved at {trace_dir}")

        if not os.path.exists(trace_dir):
            os.makedirs(trace_dir)

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(
                wait=iter_frequency - 2,
                warmup=1,
                active=1,
                repeat=0,
            ),
            on_trace_ready=trace_handler,
            profile_memory=True,
            with_stack=False,
            record_shapes=True,
        ) as torch_profiler:
            yield torch_profiler
    else:
        yield None
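
One note on trace_handler above: every rank exports its own Chrome trace, with the rank embedded in the filename, into a folder named after the global iteration count. With the default train_config.toml below, a multi-GPU run would produce a layout like this (illustrative paths, derived from dump_folder and save_traces_folder):

torchtrain/outputs/profiling/traces/iteration_10/rank0_trace.json
torchtrain/outputs/profiling/traces/iteration_10/rank1_trace.json
torchtrain/outputs/profiling/traces/iteration_20/rank0_trace.json

Each trace file can be opened in chrome://tracing or Perfetto.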
9 changes: 9 additions & 0 deletions torchtrain/train_config.toml
@@ -0,0 +1,9 @@
# TorchTrain Config.toml
[global]
dump_folder = "./torchtrain/outputs"

[profiling]
run_profiler = true
save_traces_folder = "profiling/traces"
# profiling frequency - example: 10 means every 10th iter will be profiled
profile_every_x_iter = 10
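
The schedule that profiling.py builds from this value works out to a fixed-length cycle: (profile_every_x_iter - 2) wait steps plus 1 warmup step plus 1 active step is exactly profile_every_x_iter iterations, so one trace is saved on every Nth step. A sketch of that arithmetic using the public torch.profiler.schedule API, for the default of 10:

from torch.profiler import schedule

profile_every_x_iter = 10  # value from the [profiling] table above
# (10 - 2) wait + 1 warmup + 1 active = a 10-step cycle; repeat=0 means the
# cycle repeats for the whole run, so traces land on steps 10, 20, 30, ...
sched = schedule(wait=profile_every_x_iter - 2, warmup=1, active=1, repeat=0)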
129 changes: 79 additions & 50 deletions train.py
@@ -4,19 +4,25 @@
from typing import List
import logging
from logging import getLogger
import sys # for logging

# torch imports
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.utils.data import DataLoader

# torchtrain related
from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
from torchtrain.datasets import create_tokenizer, dataset_cls_map, pad_batch_to_longest_seq
from torchtrain.profiling import maybe_run_profiler
from torchtrain.logging_utils import init_logger, rank0_log


logger = getLogger()
# torchtrain related
from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
from torchtrain.datasets import (
    create_tokenizer,
    dataset_cls_map,
    pad_batch_to_longest_seq,
)


@dataclass
@@ -26,18 +26,6 @@ class TrainState:
    losses: List[float] = field(default_factory=list)


def rank0_log(msg):
    if torch.distributed.get_rank() == 0:
        logger.info(msg)

def init_logger():
    logger.setLevel(logging.INFO)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

def build_optimizer(model, args):
    # build optimizer
    if args.optimizer == "Adam":
@@ -58,7 +52,9 @@ def main(args):
    # distributed init
    world_size = int(os.environ["WORLD_SIZE"])
    dp_degree = world_size // args.tp_degree
    world_mesh = init_device_mesh(device_type, (dp_degree, args.tp_degree), mesh_dim_names=("dp", "tp"))
    world_mesh = init_device_mesh(
        device_type, (dp_degree, args.tp_degree), mesh_dim_names=("dp", "tp")
    )

    model_name = args.model
    # build tokenizer
@@ -67,7 +63,11 @@

    # build dataloader
    dataset_cls = dataset_cls_map[args.dataset]
    data_loader = DataLoader(dataset_cls(tokenizer), batch_size=args.batch_size, collate_fn=pad_batch_to_longest_seq)
    data_loader = DataLoader(
        dataset_cls(tokenizer),
        batch_size=args.batch_size,
        collate_fn=pad_batch_to_longest_seq,
    )

    # build model
    # TODO: add meta initialization
@@ -84,54 +84,83 @@
    optimizer = build_optimizer(model, args)

    # TODO: apply parallelisms, e.g. fsdp/tp
    # TODO: add profiler
    # TODO: add metrics
    train_state = TrainState()

    # train loop
    model.train()

    while train_state.step < args.steps or args.steps == -1:
        train_state.step += 1
        # get batch
        batch = next(iter(data_loader))
        input_ids, labels = batch
        input_ids = input_ids.to(device_type)
        labels = labels.to(device_type)
    with maybe_run_profiler() as torch_profiler:
        while train_state.step < args.steps or args.steps == -1:
            train_state.step += 1
            # get batch
            batch = next(iter(data_loader))
            input_ids, labels = batch
            input_ids = input_ids.to(device_type)
            labels = labels.to(device_type)

            # forward
            pred = model(input_ids)
            tok_loss = F.cross_entropy(
                pred.flatten(0, 1), labels.flatten(0, 1), reduction="none"
            )
            loss = tok_loss.mean()

        # forward
        pred = model(input_ids)
        tok_loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1), reduction="none")
        loss = tok_loss.mean()
        # backward
        loss.backward()
        # TODO: add grad scaler

            # backward
            loss.backward()
            # TODO: add grad scaler
        # optimizer step
        optimizer.step()
        optimizer.zero_grad()

            # optimizer step
            optimizer.step()
            optimizer.zero_grad()
            # if profiler is active
            if torch_profiler:
                torch_profiler.step()

        train_state.current_loss = loss.item()
        train_state.losses.append(train_state.current_loss)
            train_state.current_loss = loss.item()
            train_state.losses.append(train_state.current_loss)

        rank0_log(f"current loss: {train_state.current_loss}")
            rank0_log(f"current loss: {train_state.current_loss}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='TorchTrain arg parser.')
    parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
    LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"])

    parser.add_argument('--model', type=str, default="llama", help="which model to train")
    parser.add_argument('--model_conf', type=str, default="debugmodel", help="which model config to train")
    parser.add_argument('--dataset', type=str, default="alpaca", help="dataset to use")
    parser.add_argument('--tokenizer_path', type=str, default="torchtrain/datasets/tokenizer.model", help="tokenizer path")
    parser.add_argument('--batch_size', type=int, default=8, help="batch size")
    parser.add_argument('--optimizer', type=str, default="AdamW", help="optimizer to use")
    parser.add_argument('--lr', type=float, default=2e-5, help="optimizer to use")
    parser.add_argument('--steps', type=int, default=-1, help="how many train steps to run")
    parser.add_argument('--tp_degree', type=int, default=LOCAL_WORLD_SIZE, help="Tensor/Sequence Parallelism degree")
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument(
        "--model", type=str, default="llama", help="which model to train"
    )
    parser.add_argument(
        "--model_conf",
        type=str,
        default="debugmodel",
        help="which model config to train",
    )
    parser.add_argument("--dataset", type=str, default="alpaca", help="dataset to use")
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="./torchtrain/datasets/tokenizer/tokenizer.model",
        help="tokenizer path",
    )
    parser.add_argument("--batch_size", type=int, default=8, help="batch size")
    parser.add_argument(
        "--optimizer", type=str, default="AdamW", help="optimizer to use"
    )
    parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
    parser.add_argument(
        "--steps", type=int, default=-1, help="how many train steps to run"
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=LOCAL_WORLD_SIZE,
        help="Tensor/Sequence Parallelism degree",
    )
    parser.add_argument(
        "--compile", action="store_true", help="Whether to compile the model."
    )

    args = parser.parse_args()
    main(args)