Bring LLaMa 3.1 405B to TorchTitan family #481
Changes from all commits: 58cd625, 94b609a, 6327d35, 1cc3ac9, a6a6507, 57ba91b, 4bea482, 97b2de7, 6a862ee, 441e22c, 659a854, 4b1693d
@@ -0,0 +1,53 @@
# torchtitan Config.toml
# NOTE: this toml config is a preset for 128 H100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 405B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "405B"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 0.8e-4

[training]
batch_size = 2
seq_len = 8192
warmup_steps = 600  # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0  # grad norm clipping
steps = 3000
data_parallel_degree = -1
tensor_parallel_degree = 8  # 8-way TP
enable_float8_linear = false
compile = false
Inline review discussion on compile = false:

- Wondering if there were any blockers on enabling compile?
- I don't know. Any reason why we didn't enable it for 70B? @tianyu-l For 405B, the ideal toml will have 3D parallelism for sure; whether to have compile enabled is something we can discuss.
- Previously there were issues with compile. Now there's no blocker, maybe just extra compile time?
- Let me try with compile; if it works, I will update the toml file in a different PR.
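For context on what flipping this flag would change: enabling compile corresponds conceptually to wrapping the model with torch.compile, so the first iterations pay a one-time compilation cost. A minimal, hypothetical Python sketch, not torchtitan's actual code path (nn.Linear stands in for the real transformer):

import torch
import torch.nn as nn

# Hypothetical illustration of what compile = true amounts to conceptually.
# torchtitan's real integration may differ; this only shows the core API call.
model = nn.Linear(4096, 4096)     # stand-in for the actual 405B transformer
compiled = torch.compile(model)   # first calls trigger (slow) graph compilation

x = torch.randn(2, 4096)
y = compiled(x)                   # later calls reuse the compiled graph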
dataset = "c4"

[experimental]
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'full'  # ['none', 'selective', 'full']
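As a sanity check on how the parallelism degrees fit the 128-GPU preset: assuming data_parallel_degree = -1 means "infer the data-parallel degree from the remaining ranks" (an assumption based on the preset comment at the top of the file, not something this diff states), the layout works out as in the Python sketch below. Note also that the warmup comment is internally consistent: 600 warmup steps is exactly 20% of the 3000 training steps.

# Hypothetical sanity check of the parallel layout implied by this preset.
# Assumes data_parallel_degree = -1 means "infer from world size" and the
# 128-GPU note at the top of the config; all other numbers mirror the toml.
world_size = 128                      # 128 H100s, per the preset comment
tp = 8                                # tensor_parallel_degree
pp = 1                                # pipeline_parallel_degree (experimental)
dp = world_size // (tp * pp)          # -1 in the toml, inferred here as 16
batch_size = 2                        # per data-parallel rank
seq_len = 8192
tokens_per_step = dp * batch_size * seq_len
print(dp, tokens_per_step)            # 16, 262144 tokens per optimizer step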