Unified config manager for toml and command line #76

Merged
1 commit merged on Feb 24, 2024
8 changes: 3 additions & 5 deletions run_llama_train.sh
@@ -23,10 +23,8 @@ CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-""}
 # Please adjust this to a longer interval period. The unit of measurement is in steps.
 CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
 
+CONFIG_FILE=${CONFIG_FILE:-"./torchtrain/train_configs/train_config.toml"}
+
 torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --steps 10 \
---model ${MODEL} --model_conf ${MODEL_CONF} \
---pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
---compile \
---checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
+train.py --job.config_file ${CONFIG_FILE}
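With this change the launch script reads all training settings from a toml file; because of the ${CONFIG_FILE:-...} default above, a different file can be supplied through the environment, for example CONFIG_FILE=./my_config.toml ./run_llama_train.sh.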
Empty file added test/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions test/test_job_config.py
@@ -0,0 +1,21 @@
import pytest
from torchtrain.config_manager import JobConfig


class TestJobConfig:
    def test_command_line_args(self):
        config = JobConfig()
        config.parse_args([])
        assert config.model.name == "llama"

    def test_job_config_file(self):
        config = JobConfig()
        config.parse_args(
            ["--job.config_file", "./torchtrain/train_configs/train_config.toml"]
        )
        assert config.model.name == "llama"

    def test_job_file_does_not_exist(self):
        with pytest.raises(FileNotFoundError):
            config = JobConfig()
            config.parse_args(["--job.config_file", "ohno.toml"])
1 change: 1 addition & 0 deletions test/test_test.py
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.


# delete me after adding real tests..
class Test:
    def test_test(self):
215 changes: 215 additions & 0 deletions torchtrain/config_manager.py
@@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import argparse
import sys
from collections import defaultdict
from typing import Union

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib


class JobConfig:
    """
    A helper class to manage the train configuration.
    Semantics:
    - Default config is loaded from a toml file. If no toml file is provided,
      then the default config is loaded from argparse defaults.
    """

    def parse_args(self, args_list: list = sys.argv[1:]):
        args = JobConfig.init_args_from_command_line(args_list)
        config_file = getattr(args, "job.config_file", None)
        if config_file is None:
            args_dict = self._args_to_two_level_dict(args)
        else:
            with open(config_file, "rb") as f:
                args_dict = tomllib.load(f)
Collaborator: For the case where the toml file does not have all the defaults, currently we would not populate the missing fields, iiuc?

Contributor Author: Yes, this assumes the toml file is complete and we don't want to mix in defaults. If we want to implicitly pull missing defaults, we can do that, but I was not sure if that's a good idea.

Collaborator: We can probably do a follow-up later if we find that's useful :)
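As a rough sketch of that possible follow-up (not part of this PR), the else branch above could overlay a partial toml file on top of the argparse defaults, reusing the PR's _args_to_two_level_dict helper:

    # Hypothetical merge of toml values over argparse defaults (not in this PR).
    args_dict = self._args_to_two_level_dict(args)  # start from argparse defaults
    with open(config_file, "rb") as f:
        for section, overrides in tomllib.load(f).items():
            args_dict[section].update(overrides)  # overlay the toml sections/keys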

        for k, v in args_dict.items():
            class_type = type(k.title(), (), v)
Collaborator: What does the title() call here mean? iiuc, k here is a dict and it does not have a title method? Wondering if this is toml-specific and whether this would error out if we don't use toml?

Contributor Author: title() just makes sure that the name of the generated class is title case. For example, the training section would be of type Training.

Collaborator: Ohh, got it.
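To make the mechanism concrete, a small standalone illustration (made-up values, not from the PR):

    # type() builds a class named after the section, with the section's
    # key/value pairs as class attributes.
    section = "training"
    values = {"steps": 10, "batch_size": 8}
    section_cls = type(section.title(), (), values)  # a class named "Training"
    cfg = section_cls()
    print(type(cfg).__name__)  # Training
    print(cfg.steps)           # 10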

            setattr(self, k, class_type())
        self._validate_config()

    def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
        args_dict = defaultdict(defaultdict)
        for k, v in vars(args).items():
            first_level_key, second_level_key = k.split(".", 1)
            args_dict[first_level_key][second_level_key] = v
        return args_dict

    def _validate_config(self):
        # TODO: Add more mandatory validations
        assert self.model.name and self.model.flavor and self.model.tokenizer_path
        return True

    @staticmethod
    def init_args_from_command_line(
        args_list: list = sys.argv[1:],
    ) -> argparse.Namespace:
        """
        Each argument name starts with the section name from the toml file,
        followed by a dot and the option name. For example, --model.name
        translates to:
            [model]
            name
        in the toml file.
        """
        parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
        parser.add_argument(
            "--job.config_file",
            type=str,
            default=None,
            help="job config file",
        )

        # misc configs
        parser.add_argument(
            "--job.dump_folder",
            type=str,
            default="./torchtrain/outputs",
            help="folder to dump job outputs",
        )

        # profiling configs
        parser.add_argument(
            "--profiling.run_profiler",
            action="store_true",
            help="enable pytorch profiler",
        )
        parser.add_argument(
            "--profiling.save_traces_folder",
            type=str,
            default="profiling/traces",
            help="trace file location",
        )
        parser.add_argument(
            "--profiling.profile_every_x_iter",
            type=int,
            default=10,
            help="collect profiler traces every x iterations",
        )
        # metrics configs
        parser.add_argument(
            "--metrics.log_freq",
            type=int,
            default=10,
            help="how often to log metrics to TensorBoard",
        )
        parser.add_argument(
            "--metrics.enable_tensorboard",
            action="store_true",
            help="whether to log metrics to TensorBoard",
        )
        parser.add_argument(
            "--metrics.save_tb_folder",
            type=str,
            default="tb",
            help="folder to dump tensorboard state",
        )

        # model configs
        parser.add_argument(
            "--model.name",
            type=str,
            default="llama",
            help="which model to train",
        )
        parser.add_argument(
            "--model.flavor",
            type=str,
            default="debugmodel",
            help="which model config to train",
        )
        parser.add_argument(
            "--model.tokenizer_path",
            type=str,
            default="./torchtrain/datasets/tokenizer/tokenizer.model",
            help="tokenizer path",
        )

        # optimizer configs
        parser.add_argument(
            "--optimizer.name", type=str, default="AdamW", help="optimizer to use"
        )
        parser.add_argument(
            "--optimizer.lr", type=float, default=8e-4, help="learning rate to use"
        )

        # training configs
        parser.add_argument(
            "--training.dataset", type=str, default="alpaca", help="dataset to use"
        )
        parser.add_argument(
            "--training.batch_size", type=int, default=8, help="batch size"
        )
        parser.add_argument(
            "--training.seq_len", type=int, default=2048, help="sequence length"
        )
        parser.add_argument(
            "--training.warmup_pct",
            type=float,
            default=0.20,
            help="percentage of total training steps to use for warmup",
        )
        parser.add_argument(
            "--training.max_norm",
            type=float,
            default=1.0,
            help="max norm for gradient clipping",
        )
        parser.add_argument(
            "--training.steps", type=int, default=-1, help="how many train steps to run"
        )
        parser.add_argument(
            "--training.data_parallel_degree",
            type=int,
            default=-1,
            help="Data Parallelism degree. -1 means leftover ranks will be used (after SP/PP). 1 means disabled.",
        )
        parser.add_argument(
            "--training.sequence_parallel_degree",
            type=int,
            default=1,
            help="Sequence Parallelism degree. 1 means disabled.",
        )
        parser.add_argument(
            "--training.pipeline_parallel_degree",
            type=int,
            default=1,
            help="Pipeline Parallelism degree (default of 1 means disabled)",
        )
        parser.add_argument(
            "--training.compile",
            action="store_true",
            help="Whether to compile the model.",
        )
        parser.add_argument(
            "--training.checkpoint_interval",
            type=int,
            default=3600,
            help=(
                "Checkpointing interval. The unit of measurement is in seconds or "
                "steps depending on --training.checkpoint_interval_type."
            ),
        )
        parser.add_argument(
            "--training.checkpoint_interval_type",
            type=str,
            default="steps",
            help=(
                "The checkpointing interval unit of measurement. "
                "The default value is steps."
            ),
        )
        parser.add_argument(
            "--training.checkpoint_folder",
            type=str,
            default="",
            help=(
                "The folder to store the checkpoints. If this is not specified or "
                "is an empty string, checkpointing is disabled."
            ),
        )
        return parser.parse_args(args_list)
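Putting the parser and parse_args together, a minimal usage sketch (illustrative values): section.option flags on the command line become attributes on the JobConfig instance.

    from torchtrain.config_manager import JobConfig

    config = JobConfig()
    config.parse_args(["--model.name", "llama", "--training.steps", "10"])
    print(config.model.name)      # "llama"
    print(config.training.steps)  # 10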
3 changes: 2 additions & 1 deletion torchtrain/datasets/tokenizer.py
@@ -8,8 +8,8 @@
 
 import os
 from abc import ABC, abstractmethod
-from typing import List
 from logging import getLogger
+from typing import List
 
 from sentencepiece import SentencePieceProcessor
 
@@ -48,6 +48,7 @@ def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> TokenizerIf:

 class SentencePieceTokenizer(TokenizerIf):
     """tokenizing and encoding/decoding text using SentencePiece."""
+
     def __init__(self, tokenizer_path: str):
         """
         Initializes the Tokenizer with a SentencePiece model.
3 changes: 2 additions & 1 deletion torchtrain/logging_utils.py
@@ -1,6 +1,7 @@
-import torch
 import logging
 
+import torch
+
 logger = logging.getLogger()
9 changes: 6 additions & 3 deletions torchtrain/lr_scheduling.py
@@ -2,6 +2,7 @@
 # All rights reserved.
 
 from torch.optim.lr_scheduler import LambdaLR
+from torchtrain.config_manager import JobConfig
 
 # global states for scheduling
 # these are needed as LambdaLR does not support argument passing
@@ -29,11 +30,13 @@ def linear_warmup_linear_decay(current_step: int) -> float:
     return curr_adjustment
 
 
-def get_lr_scheduler(optimizer, args):
+def get_lr_scheduler(optimizer, job_config: JobConfig):
     """Build a linear warmup and linear decay scheduler"""
     global _warmup_steps, _decay_steps
-    _warmup_steps = max(int(args.steps * args.warmup_pct), 2)
-    _decay_steps = float(max(1, args.steps - _warmup_steps))
+    _warmup_steps = max(
+        int(job_config.training.steps * job_config.training.warmup_pct), 2
+    )
+    _decay_steps = float(max(1, job_config.training.steps - _warmup_steps))
 
     warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
     return warmup_scheduler
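For intuition: with, say, training.steps = 1000 and the default warmup_pct of 0.20 (illustrative step count, not from this PR), this computes _warmup_steps = max(int(1000 * 0.20), 2) = 200 and _decay_steps = float(max(1, 1000 - 200)) = 800.0.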
12 changes: 5 additions & 7 deletions torchtrain/metrics.py
@@ -12,9 +12,9 @@
 import torch
 import torch.nn as nn
 from torch.utils.tensorboard import SummaryWriter
+from torchtrain.config_manager import JobConfig
 
 from torchtrain.logging_utils import rank0_log
-from torchtrain.profiling import get_config_from_toml
 
 _gb_in_bytes = 1024 * 1024 * 1024
 _mb_in_bytes = 1024 * 1024
@@ -214,16 +214,14 @@ def close(self):
         self.writer.close()
 
 
-def build_metric_logger(tag: Optional[str] = None):
-    config = get_config_from_toml()
-
-    dump_dir = config["global"]["dump_folder"]
-    save_tb_folder = config["metrics"]["save_tb_folder"]
+def build_metric_logger(config: JobConfig, tag: Optional[str] = None):
+    dump_dir = config.job.dump_folder
+    save_tb_folder = config.metrics.save_tb_folder
     # since we don't have run id yet, use current minute as identifier
     datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
     log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
 
-    enable_tb = config["metrics"].get("enable_tensorboard", False)
+    enable_tb = config.metrics.enable_tensorboard
     if enable_tb:
         rank0_log(
             f"Metrics logging active. Tensorboard logs will be saved at {log_dir}."
10 changes: 6 additions & 4 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -32,6 +32,7 @@
     PrepareModuleInput,
     RowwiseParallel,
 )
+from torchtrain.config_manager import JobConfig
 
 from torchtrain.logging_utils import rank0_log
 
@@ -67,13 +68,14 @@ def partition_fn(name, module, device_mesh):


 # Uses PTD FSDP AC wrapper
-def checkpoint_wrapper(module, config):
+# TODO: why is config needed here?
+def checkpoint_wrapper(module, job_config: JobConfig):
Collaborator: This is so that we could later add an option "selective_ac" and do either full AC or selective AC.

     return ptd_checkpoint_wrapper(
         module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
     )
 
 
-def parallelize_llama(model, world_mesh, parallel_dims, args):
+def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply parallelisms to the model, including PTD parallelisms, and AC.
 
@@ -87,7 +89,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
     if parallel_dims.sp_enabled:
         # First we apply Sequence Parallelism if it's enabled
         tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
-        sp_degree = args.sp_degree
+        sp_degree = job_config.training.sequence_parallel_degree
         # First:
         # 1. parallelize the first embedding and the last linear proj layer
         # 2. shard the first layer of transformer block
@@ -169,7 +171,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
             # apply AC to each layer
             # before wrapping with FSDP, we need to make sure the layer is on GPU
             transformer_block = transformer_block.cuda()
-            transformer_block = checkpoint_wrapper(transformer_block, args)
+            transformer_block = checkpoint_wrapper(transformer_block, job_config)
 
             # Wraps each layer with FSDP
             model.layers[layer_id] = wrap(transformer_block)