10 changes: 9 additions & 1 deletion torchtitan/experiments/__init__.py
@@ -5,5 +5,13 @@
# LICENSE file in the root directory of this source tree.

 _supported_experiments = frozenset(
-    ["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
+    [
+        "flux",
+        "llama4",
+        "qwen3",
+        "simple_fsdp.llama3",
+        "simple_fsdp.deepseek_v3",
+        "vlm",
+        "joint_graph_runner.llama3",
+    ]
 )
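For context, a hypothetical sketch (not part of this diff) of how a registry like `_supported_experiments` can gate lazy imports; `load_experiment` is an illustrative name, not torchtitan's API:

```python
import importlib


def load_experiment(name: str):
    """Illustrative helper: import an experiment module only if it is supported."""
    if name not in _supported_experiments:
        raise ValueError(f"unsupported experiment: {name!r}")
    # e.g. "joint_graph_runner.llama3" -> torchtitan.experiments.joint_graph_runner.llama3
    return importlib.import_module(f"torchtitan.experiments.{name}")
```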
13 changes: 13 additions & 0 deletions torchtitan/experiments/joint_graph_runner/README.md
@@ -0,0 +1,13 @@
## Joint Graph Runner

This experiment explores toolkit-style use of the PyTorch compiler stack for authoring parallel models.

Joint-graph-based training prototype:

Llama3
- User code: SimpleFSDP + TP
- Trace the joint forward + backward graph
- Apply graph passes to the joint graph
- Run it with the Joint Graph Runner (see the sketch below)
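
The sketch below illustrates that trace/transform/run flow on a toy model. It is illustrative only: it assumes PyTorch's (private) `aot_export_module` API as a stand-in for the experiment's actual tracing entry point, which this diff does not show, and the node walk stands in for real joint-graph passes.

```python
import torch
from torch._functorch.aot_autograd import aot_export_module


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Return a scalar loss so the joint trace has a backward to capture.
        return self.linear(x).sum()


model = TinyModel()
x = torch.randn(4, 8)

# 1) Trace forward and backward together into a single joint FX graph.
joint_gm, signature = aot_export_module(
    model, (x,), trace_joint=True, output_loss_index=0
)

# 2) "Apply passes to the joint": walk/rewrite the FX nodes. A real pass
#    might bucket collectives or reorder compute; this one just counts nodes.
num_nodes = len(list(joint_gm.graph.nodes))
print(f"joint graph has {num_nodes} nodes")

# 3) "Run with the Joint Graph Runner": the runner itself is the subject of
#    this experiment and is not shown here; inspect the joint graph instead.
joint_gm.print_readable()
```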

Run with:

```bash
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" \
    with-proxy ./run_train.sh --model.name joint_graph_runner.llama3 \
    --compile.enable --parallelism.data_parallel_shard_degree=2 \
    --parallelism.tensor_parallel_degree=4
```
35 changes: 35 additions & 0 deletions torchtitan/experiments/joint_graph_runner/llama3/__init__.py
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.experiments.joint_graph_runner.llama3.parallelize import (
    parallelize_llama,
)

from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer
from torchtitan.models.llama3 import llama3_configs, pipeline_llama
from torchtitan.protocols.train_spec import TrainSpec


def get_train_spec() -> TrainSpec:
    return TrainSpec(
        name="joint_graph_runner.llama3",
        model_cls=SimpleFSDPTransformer,
        model_args=llama3_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
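
As a quick sanity check, the spec can be built and inspected directly; a minimal sketch grounded in the file above:

```python
from torchtitan.experiments.joint_graph_runner.llama3 import get_train_spec

spec = get_train_spec()
assert spec.name == "joint_graph_runner.llama3"
# The trainer selects this spec via --model.name joint_graph_runner.llama3,
# as in the README's run command.
print(spec.model_cls.__name__)  # SimpleFSDPTransformer
```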