
Commit 23e2221

add bunch of cleanups and design principle section (#71)
1 parent: c396c1f

4 files changed: +8 -84 lines changed


README.md

Lines changed: 8 additions & 0 deletions
@@ -4,6 +4,14 @@ Note: This repository is currently under heavy development.
 
 torchtrain contains PyTorch native parallelisms, tools and utilities to train large models.
 
+## Design Principles
+
+TorchTrain is a native PyTorch library with various training techniques. While it utilizes the PyTorch ecosystem for things such as data loading (e.g., HuggingFace datasets), the core functionality is written in PyTorch.
+
+* Designed to be easy to understand, use, and extend for different training purposes.
+* Minimal changes to the model code when applying 1D, 2D, or 3D parallelisms.
+* Modular components instead of a monolithic codebase.
+
 # Installation
 
 Install PyTorch from source or install the latest pytorch nightly, then install requirements by
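
The design-principles text added above says data loading leans on the HuggingFace `datasets` package while the core stays in plain PyTorch. The following is a minimal sketch of that pattern, not code from this commit; the dataset name (`tatsu-lab/alpaca`) and its `text` field are assumptions used only for illustration.

```python
# Sketch only: HF `datasets` for loading, plain PyTorch for everything else.
# The dataset name and field are assumptions, not taken from this commit.
from datasets import load_dataset
from torch.utils.data import DataLoader

raw = load_dataset("tatsu-lab/alpaca", split="train")  # assumed dataset

def collate(batch):
    # keep only the raw text; a real loader would tokenize and pad here
    return [example["text"] for example in batch]

loader = DataLoader(raw, batch_size=8, shuffle=True, collate_fn=collate)
texts = next(iter(loader))  # list of 8 strings ready for tokenization
```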

torchtrain/datasets/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 from torchtrain.datasets.alpaca import build_alpaca_data_loader
-from torchtrain.datasets.pad_batch_sequence import pad_batch_to_longest_seq
 from torchtrain.datasets.tokenizer import create_tokenizer
 
 __all__ = ["build_alpaca_data_loader", "create_tokenizer", "pad_batch_to_longest_seq"]

torchtrain/datasets/pad_batch_sequence.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

torchtrain/parallelisms/parallelize_llama.py

Lines changed: 0 additions & 6 deletions
@@ -130,12 +130,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
         "feed_forward.w3": ColwiseParallel(),
     }
-    # if layer_id == 0:
-    #     # in first transformer block we need to shard the input
-    #     layer_plan[""] = PrepareModuleInput(
-    #         input_layouts=(Replicate(), None),
-    #         desired_input_layouts=(Shard(0), None),
-    #     )
 
     # adjust num_heads in attention layer to local heads
     attn_layer = transformer_block.attention
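
For context on the hunk above, here is a hedged, self-contained sketch of how a layer plan of this shape is applied with PyTorch tensor parallelism. None of it is code from this commit: the toy modules, the mesh size, and the `feed_forward.w1` entry are assumptions, and it is meant to run under `torchrun` with a process group already initialized.

```python
import torch.nn as nn
from torch.distributed._tensor import Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class FeedForward(nn.Module):
    # stand-in for the Llama MLP: w1/w3 project up, w2 projects back down
    def __init__(self, dim=512, hidden=2048):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.feed_forward = FeedForward()

tp_mesh = init_device_mesh("cuda", (8,))  # assumed 8-way tensor-parallel mesh

# Same shape of plan as in the hunk: column-shard the up-projections,
# row-shard the down-projection (w1 is assumed; only w2/w3 appear above).
layer_plan = {
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
    "feed_forward.w3": ColwiseParallel(),
}

# parallelize_module swaps the matching submodules in place, so the block's
# own forward() never changes -- the "minimal model changes" principle.
block = parallelize_module(Block(), tp_mesh, layer_plan)
```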
