Unified config manager for toml and command line #76
@@ -0,0 +1,21 @@
import pytest
from torchtrain.config_manager import JobConfig


class TestJobConfig:
    def test_command_line_args(self):
        config = JobConfig()
        config.parse_args([])
        assert config.model.name == "llama"

    def test_job_config_file(self):
        config = JobConfig()
        config.parse_args(
            ["--job.config_file", "./torchtrain/train_configs/train_config.toml"]
        )
        assert config.model.name == "llama"

    def test_job_file_does_not_exist(self):
        with pytest.raises(FileNotFoundError):
            config = JobConfig()
            config.parse_args(["--job.config_file", "ohno.toml"])
@@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import argparse
import sys
from collections import defaultdict
from typing import Union

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib


class JobConfig:
    """
    A helper class to manage the train configuration.
    Semantics:
    - Default config is loaded from a toml file. If no toml file is provided,
      then the default config is loaded from argparse defaults.
    """

    def parse_args(self, args_list: list = sys.argv[1:]):
        args = JobConfig.init_args_from_command_line(args_list)
        config_file = getattr(args, "job.config_file", None)
        if config_file is None:
            args_dict = self._args_to_two_level_dict(args)
        else:
            with open(config_file, "rb") as f:
                args_dict = tomllib.load(f)
        for k, v in args_dict.items():
            class_type = type(k.title(), (), v)
Review comment: what's the
Reply: Ohh got it
            setattr(self, k, class_type())
        self._validate_config()
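For context on the dynamic class creation discussed in the comment thread above, here is a sketch (not part of the diff) of what `type(k.title(), (), v)` followed by `setattr` produces for one section:

    # Sketch: each top-level section becomes a dynamically created class whose
    # class attributes are that section's options; an instance is then attached
    # to the JobConfig object under the section name.
    section, options = "training", {"batch_size": 8, "seq_len": 2048}
    Training = type(section.title(), (), options)  # a class named "Training"
    cfg = Training()
    assert cfg.batch_size == 8 and cfg.seq_len == 2048
    # parse_args then does setattr(self, "training", cfg), so callers can read
    # config.training.batch_size, config.training.seq_len, etc.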

    def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
        args_dict = defaultdict(defaultdict)
        for k, v in vars(args).items():
            first_level_key, second_level_key = k.split(".", 1)
            args_dict[first_level_key][second_level_key] = v
        return args_dict

    def _validate_config(self):
        # TODO: Add more mandatory validations
        assert self.model.name and self.model.flavor and self.model.tokenizer_path
        return True

    @staticmethod
    def init_args_from_command_line(
        args_list: list = sys.argv[1:],
    ) -> argparse.Namespace:
        """
        Each argument starts with <prefix>. which is the section name in the toml file,
        followed by the name of the option in the toml file. For example,
        model.name translates to:
            [model]
            name
        in the toml file.
        """
        parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
        parser.add_argument(
            "--job.config_file",
            type=str,
            default=None,
            help="job config file",
        )

        # misc configs
        parser.add_argument(
            "--job.dump_folder",
            type=str,
            default="./torchtrain/outputs",
            help="folder to dump job outputs",
        )

        # profiling configs
        parser.add_argument(
            "--profiling.run_profiler",
            action="store_true",
            help="enable pytorch profiler",
        )
        parser.add_argument(
            "--profiling.save_traces_folder",
            type=str,
            default="profiling/traces",
            help="trace file location",
        )
        parser.add_argument(
            "--profiling.profile_every_x_iter",
            type=int,
            default=10,
            help="collect profiler traces every x iterations",
        )
        # metrics configs
        parser.add_argument(
            "--metrics.log_freq",
            type=int,
            default=10,
            help="how often to log metrics to TensorBoard",
        )
        parser.add_argument(
            "--metrics.enable_tensorboard",
            action="store_true",
            help="whether to log metrics to TensorBoard",
        )
        parser.add_argument(
            "--metrics.save_tb_folder",
            type=str,
            default="tb",
            help="folder to dump tensorboard state",
        )

        # model configs
        parser.add_argument(
            "--model.name",
            type=str,
            default="llama",
            help="which model to train",
        )
        parser.add_argument(
            "--model.flavor",
            type=str,
            default="debugmodel",
            help="which model config to train",
        )
        parser.add_argument(
            "--model.tokenizer_path",
            type=str,
            default="./torchtrain/datasets/tokenizer/tokenizer.model",
            help="tokenizer path",
        )

        # optimizer configs
        parser.add_argument(
            "--optimizer.name", type=str, default="AdamW", help="optimizer to use"
        )
        parser.add_argument(
            "--optimizer.lr", type=float, default=8e-4, help="learning rate to use"
        )

        # training configs
        parser.add_argument(
            "--training.dataset", type=str, default="alpaca", help="dataset to use"
        )
        parser.add_argument(
            "--training.batch_size", type=int, default=8, help="batch size"
        )
        parser.add_argument(
            "--training.seq_len", type=int, default=2048, help="sequence length"
        )
        parser.add_argument(
            "--training.warmup_pct",
            type=float,
            default=0.20,
            help="percentage of total training steps to use for warmup",
        )
        parser.add_argument(
            "--training.max_norm",
            type=float,
            default=1.0,
            help="max norm for gradient clipping",
        )
        parser.add_argument(
            "--training.steps", type=int, default=-1, help="how many train steps to run"
        )
        parser.add_argument(
            "--training.data_parallel_degree",
            type=int,
            default=-1,
            help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
        )
        parser.add_argument(
            "--training.sequence_parallel_degree",
            type=int,
            default=1,
            help="Sequence Parallelism degree. 1 means disabled.",
        )
        parser.add_argument(
            "--training.pipeline_parallel_degree",
            type=int,
            default=1,
            help="Pipeline Parallelism degree (default of 1 means disabled)",
        )
        parser.add_argument(
            "--training.compile",
            action="store_true",
            help="Whether to compile the model.",
        )
        parser.add_argument(
            "--training.checkpoint_interval",
            type=int,
            default=3600,
            help=(
                "Checkpointing interval. The unit of measurement is in seconds or "
                "steps depending on --training.checkpoint_interval_type."
            ),
        )
        parser.add_argument(
            "--training.checkpoint_interval_type",
            type=str,
            default="steps",
            help=(
                "The checkpointing interval unit of measurement. "
                "The default value is steps."
            ),
        )
        parser.add_argument(
            "--training.checkpoint_folder",
            type=str,
            default="",
            help=(
                "The folder to store the checkpoints. If this is not specified or "
                "is an empty string, checkpointing is disabled."
            ),
        )
        return parser.parse_args(args_list)
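The toml file mirrors the argparse sections above: each [section] corresponds to a --section.option flag. The snippet below is an illustrative sketch, since the actual contents of ./torchtrain/train_configs/train_config.toml are not shown in this diff:

    # train_config.toml (illustrative sketch)
    [job]
    dump_folder = "./torchtrain/outputs"

    [model]
    name = "llama"
    flavor = "debugmodel"
    tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model"

    [training]
    batch_size = 8
    seq_len = 2048

Passing such a file via --job.config_file makes parse_args read the toml instead of the argparse defaults; note that _validate_config requires at least model.name, model.flavor, and model.tokenizer_path to be present.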
@@ -1,6 +1,7 @@
import torch
import logging

import torch

logger = logging.getLogger()
@@ -32,6 +32,7 @@
    PrepareModuleInput,
    RowwiseParallel,
)
from torchtrain.config_manager import JobConfig

from torchtrain.logging_utils import rank0_log


@@ -67,13 +68,14 @@ def partition_fn(name, module, device_mesh):


# Uses PTD FSDP AC wrapper
def checkpoint_wrapper(module, config):
# TODO: why is config needed here?
def checkpoint_wrapper(module, job_config: JobConfig):
Review comment: this is so that we could later add an option "selective_ac" and do either full AC or selective AC
    return ptd_checkpoint_wrapper(
        module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )
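One possible shape for the selective AC option mentioned in the comment above; the flag name and its section are assumptions, not part of this diff:

    # Hypothetical follow-up sketch: gate full vs. selective AC on a job_config flag.
    def checkpoint_wrapper(module, job_config: JobConfig):
        if getattr(job_config.training, "selective_ac", False):
            # a selective checkpointing policy would be applied here (elided)
            pass
        # full AC, as wired up in this PR
        return ptd_checkpoint_wrapper(
            module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
        )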


def parallelize_llama(model, world_mesh, parallel_dims, args):
def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
    """
    Apply parallelisms to the model, including PTD parallelisms, and AC.


@@ -87,7 +89,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
    if parallel_dims.sp_enabled:
        # First we apply Sequence Parallelism if it's enabled
        tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
        sp_degree = args.sp_degree
        sp_degree = job_config.training.sequence_parallel_degree
        # First:
        # 1. parallelize the first embedding and the last linear proj layer
        # 2. shard the first layer of transformer block

@@ -169,7 +171,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
        # apply AC to each layer
        # before wrapping with FSDP, we need to make sure the layer is on GPU
        transformer_block = transformer_block.cuda()
        transformer_block = checkpoint_wrapper(transformer_block, args)
        transformer_block = checkpoint_wrapper(transformer_block, job_config)

        # Wraps each layer with FSDP
        model.layers[layer_id] = wrap(transformer_block)
Review comment: For the case where the toml file does not have all the defaults, currently we would not populate those fields, iiuc?
Reply: Yes, this assumes the toml file is complete and we don't want to mix defaults. If we want to implicitly pull missing defaults, we can do that, but I was not sure if that's a good idea.
Reply: We can probably do a follow-up later if we find that's useful :)
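As a sketch of the follow-up discussed here (not part of this PR), parse_args could layer toml values on top of the argparse defaults so a partial toml file still yields a fully populated config; the helper name below is hypothetical and would live next to parse_args in the config manager:

    # Hypothetical sketch: merge a (possibly partial) toml dict over argparse defaults.
    def merge_with_defaults(toml_dict: dict) -> dict:
        default_args = JobConfig.init_args_from_command_line([])
        merged = defaultdict(defaultdict)
        for k, v in vars(default_args).items():
            section, name = k.split(".", 1)
            merged[section][name] = v
        for section, options in toml_dict.items():
            for name, value in options.items():
                merged[section][name] = value  # toml values win over defaults
        return merged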