
Commit db3e3d3

Update
[ghstack-poisoned]
2 parents c038451 + 8c81863 commit db3e3d3


56 files changed: +1928 −1683 lines

.github/workflows/torchao_experimental_test.yml

Lines changed: 10 additions & 1 deletion
@@ -33,7 +33,10 @@ jobs:
       - name: Install requirements
         run: |
           conda activate venv
-          pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu"
+          # Install executorch first because it installs its own version
+          # of torch and torchao, which we do not want to use
+          pip install executorch
+          pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall
           pip install numpy
           pip install pytest
           pip install parameterized
@@ -57,6 +60,12 @@ jobs:
           sh build_and_run_tests.sh
           rm -rf /tmp/cmake-out
           popd
+      - name: ET ops build
+        run: |
+          conda activate venv
+          pushd torchao/experimental
+          sh build_torchao_ops.sh executorch
+          popd
 
   test-mps-ops:
     strategy:

README.md

Lines changed: 3 additions & 3 deletions
@@ -115,13 +115,13 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
 ADAM takes 2x as much memory as the model params so we can quantize the optimizer state to either 8 or 4 bit effectively reducing the optimizer VRAM requirements by 2x or 4x respectively over an fp16 baseline
 
 ```python
-from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8
+from torchao.optim import AdamW8bit, AdamW4bit, AdamWFp8
 optim = AdamW8bit(model.parameters()) # replace with Adam4bit and AdamFp8 for the 4 / fp8 versions
 ```
 
-In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim)
+In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/optim)
 
-We also have support for [single GPU CPU offloading](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload) where both the gradients (same size as weights) and the optimizers will be efficiently sent to the CPU. This alone can **reduce your VRAM requirements by 60%**
+We also have support for [single GPU CPU offloading](https://github.com/pytorch/ao/tree/main/torchao/optim#optimizer-cpu-offload) where both the gradients (same size as weights) and the optimizers will be efficiently sent to the CPU. This alone can **reduce your VRAM requirements by 60%**
 
 ```python
 optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
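For reference, a minimal end-to-end sketch of the renamed API, assuming a CUDA device and the `torchao.optim` module introduced by this commit; the toy model, shapes, and learning rate are illustrative and not taken from the diff:

```python
import torch
from torch import nn

from torchao.optim import AdamW8bit, CPUOffloadOptimizer

# Toy model; any nn.Module works here.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).cuda()

# 8-bit optimizer state, used as a drop-in replacement for torch.optim.AdamW.
optimizer = AdamW8bit(model.parameters(), lr=1e-3)

# Alternative from the README snippet above: keep optimizer state on the CPU.
# optimizer = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)

x = torch.randn(8, 512, device="cuda")
model(x).sum().backward()
optimizer.step()
optimizer.zero_grad()
```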

benchmarks/benchmark_low_bit_adam.py

Lines changed: 6 additions & 8 deletions
@@ -34,7 +34,7 @@
 from torchvision.transforms import v2
 from tqdm import tqdm
 
-from torchao.prototype import low_bit_optim
+from torchao import optim
 from torchao.utils import get_available_devices
 
 _DEVICE = get_available_devices()[-1]
@@ -43,9 +43,9 @@
 OPTIM_MAP = dict(
     AdamW=partial(torch.optim.AdamW, fused=True),
     AdamW8bitBnb=bnb.optim.AdamW8bit,
-    AdamW8bitAo=low_bit_optim.AdamW8bit,
-    AdamWFp8Ao=low_bit_optim.AdamWFp8,
-    AdamW4bitAo=low_bit_optim.AdamW4bit,
+    AdamW8bitAo=optim.AdamW8bit,
+    AdamWFp8Ao=optim.AdamWFp8,
+    AdamW4bitAo=optim.AdamW4bit,
 )
 
 try:
@@ -249,12 +249,10 @@ def evaluate_model(model, args):
     optim_cls = OPTIM_MAP[args.optim]
 
     if args.optim_cpu_offload == "ao":
-        optim_cls = partial(
-            low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls
-        )
+        optim_cls = partial(optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
     elif args.optim_cpu_offload == "ao_offload_grads":
         optim_cls = partial(
-            low_bit_optim.CPUOffloadOptimizer,
+            optim.CPUOffloadOptimizer,
             optimizer_class=optim_cls,
             offload_gradients=True,
         )
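The CPU-offload wiring above is a small factory pattern. A standalone sketch of the same idea under stated assumptions: the toy model, flag values, and `build_optimizer` helper are illustrative, and only the `torchao.optim` names come from the diff:

```python
from functools import partial
from typing import Optional

import torch
from torch import nn

from torchao import optim

# Map benchmark names to optimizer constructors, as in OPTIM_MAP above.
OPTIM_MAP = dict(
    AdamW=partial(torch.optim.AdamW, fused=True),
    AdamW8bitAo=optim.AdamW8bit,
)


def build_optimizer(model: nn.Module, name: str, cpu_offload: Optional[str] = None):
    optim_cls = OPTIM_MAP[name]
    if cpu_offload == "ao":
        # Wrap the base optimizer so its state lives on the CPU.
        optim_cls = partial(optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
    elif cpu_offload == "ao_offload_grads":
        # Additionally stream gradients to the CPU as they are produced.
        optim_cls = partial(
            optim.CPUOffloadOptimizer,
            optimizer_class=optim_cls,
            offload_gradients=True,
        )
    return optim_cls(model.parameters())


model = nn.Linear(64, 64).cuda()
optimizer = build_optimizer(model, "AdamW8bitAo", cpu_offload="ao")
```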

benchmarks/float8/float8_roofline.py

Lines changed: 4 additions & 1 deletion
@@ -184,8 +184,11 @@ def get_gemm_times(
     elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"):
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
+    elif mx_recipe_name == "mxfp8_cublas":
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
     else:
-        assert False, "TODO add mx gemm here"
+        assert False, "TODO add cutlass mx gemm here"
 
     def do_matmul(A, B):
         return torch._scaled_mm(
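For context, mxfp8 uses one exponent-only (e8m0) scale per 32-element block along the reduction dimension K, which is why the new scale tensors above have `K // 32` columns. A minimal sketch of constructing such block scales, assuming a PyTorch build that exposes `torch.float8_e8m0fnu`; the shapes are illustrative:

```python
import torch

M, K, N = 2048, 4096, 1024  # illustrative GEMM shapes; K is a multiple of 32
BLOCK = 32  # mx block size along the reduction dimension
device = "cuda" if torch.cuda.is_available() else "cpu"

# A is (M, K) and B is (N, K), so each operand gets one e8m0 scale
# per 32-wide block of the K dimension.
scale_a = torch.ones(M, K // BLOCK, device=device, dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // BLOCK, device=device, dtype=torch.float8_e8m0fnu)

assert scale_a.shape == (M, K // BLOCK)
assert scale_b.shape == (N, K // BLOCK)
```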

benchmarks/float8/training/README.md

Lines changed: 5 additions & 5 deletions
@@ -4,15 +4,15 @@ The `float8_training_benchmark.sh` script in this directory can be used to launc
 
 ## Usage
 
-Example: `TORCHTITAN_ROOT=${HOME}/torchtitan FLOAT8_RECIPE=rowwise ./float8_training_benchmark.sh`
+Example: `TORCHTITAN_ROOT=${HOME}/torchtitan FLOAT8_RECIPE_WITH_BEST_SETTINGS=rowwise ./float8_training_benchmark.sh`
 
 Training parameters can be configured via environment variables.
 
 - Required:
-  - `TORCHTITAN_ROOT`
+  - `TORCHTITAN_ROOT`: Root directory of torchtitan in your local filesystem
 - Optional:
-  - `RECIPE`: rowwise|tensorwise. defaults to tensorwise.
-  - `BATCH_SIZE`: defaults to 1.
-  - `STEPS`: defaults to 100.
+  - `FLOAT8_RECIPE_WITH_BEST_SETTINGS`: "rowwise" or "tensorwise". Applies float8 training with the specified scaling recipe, as well as additional training configs which are optimal for that scaling recipe. See `float8_training_benchmark.sh` for more details.
+  - `BATCH_SIZE`: Defaults to 1.
+  - `STEPS`: Defaults to 100.
 
 **NOTE**: `torch.compile` and FSDP2 are always used. Other forms of parallelism supported in torchtitan are not yet supported in this script.

benchmarks/quantized_training/pretrain_llama2.py

Lines changed: 6 additions & 7 deletions
@@ -22,14 +22,13 @@
 from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
 
-from torchao import quantize_
+from torchao import optim, quantize_
 from torchao._models.llama.model import (
     ModelArgs,
     RMSNorm,
     Transformer,
     transformer_configs,
 )
-from torchao.prototype import low_bit_optim
 from torchao.prototype.quantized_training import (
     bitnet_training,
     int8_mixed_precision_training,
@@ -190,10 +189,10 @@ def insert_rmsnorm(module: torch.nn.Module):
     print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")
     torch.cuda.reset_peak_memory_stats()  # don't count memory occupied by unquantized weights
 
-    # only use optimizers from torchao.prototype.low_bit_optim to support quantized training
+    # only use optimizers from torchao.optim to support quantized training
     if args.optim == "AdamW":
         args.optim = "_AdamW"
-    optim = getattr(low_bit_optim, args.optim)(
+    optimizer = getattr(optim, args.optim)(
         model.parameters(),
         lr=args.lr,
         weight_decay=args.weight_decay,
@@ -228,15 +227,15 @@ def insert_rmsnorm(module: torch.nn.Module):
         if step % args.log_interval == 0:
             log_dict = dict(
                 loss=loss.item(),
-                lr=optim.param_groups[0]["lr"],
+                lr=optimizer.param_groups[0]["lr"],
                 max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
                 max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9,
             )
             run.log(log_dict, step=step)
             pbar.set_postfix(loss=log_dict["loss"])
 
-        optim.step()
-        optim.zero_grad()
+        optimizer.step()
+        optimizer.zero_grad()
 
         step += 1
         pbar.update()
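The rename above keeps the existing selection pattern: the optimizer class is looked up by name on `torchao.optim`, and the resulting object is used like any other PyTorch optimizer. A standalone sketch of that pattern under stated assumptions; the toy model, learning rate, and `build_low_bit_optimizer` helper are illustrative, and only the `torchao.optim` names come from the diff:

```python
import torch
from torch import nn

from torchao import optim


def build_low_bit_optimizer(model: nn.Module, name: str, lr: float = 3e-4):
    # Resolve the optimizer class by name on torchao.optim,
    # e.g. "AdamW8bit", "AdamW4bit", or "AdamWFp8".
    optim_cls = getattr(optim, name)
    return optim_cls(model.parameters(), lr=lr)


model = nn.Linear(128, 128).cuda()
optimizer = build_low_bit_optimizer(model, "AdamW8bit")

x = torch.randn(4, 128, device="cuda")
model(x).sum().backward()
optimizer.step()
optimizer.zero_grad()
```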

test/dtypes/test_bitnet.py

Lines changed: 0 additions & 91 deletions
This file was deleted.

test/dtypes/test_uint2.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 17 additions & 2 deletions
@@ -67,7 +67,10 @@ def init_multi_module(self) -> nn.Module:
         return module
 
     def init_transformer(
-        self, weight_tying: bool, dtype: Optional[torch.dtype] = None
+        self,
+        weight_tying: bool,
+        dtype: Optional[torch.dtype] = None,
+        requires_grad: bool = True,
     ) -> nn.Module:
         torch.manual_seed(42)
         args = ModelArgs(
@@ -81,6 +84,13 @@ def init_transformer(
         module = Transformer(args).cuda()
         if dtype is not None:
             module = module.to(dtype=dtype)
+
+        # if requires_grad=False, just set requires_grad to False
+        # in the first layer to ensure we still train some params.
+        if requires_grad is False:
+            for param in module.layers[0].parameters():
+                param.requires_grad = requires_grad
+
         self.broadcast_module(module)
         return module
 
@@ -107,6 +117,7 @@ def test_transformer_parity(self):
             ],
             "compile_transformer_block": [False, True],
             "dtype": [torch.float32, torch.bfloat16],
+            "requires_grad": [True, False],
         },
         self._test_transformer_parity,
     )
@@ -117,6 +128,7 @@ def _test_transformer_parity(
         precompute: bool,
         scaling_type_weight: ScalingType,
         compile_transformer_block: bool,
+        requires_grad: bool,
         dtype: Optional[torch.dtype] = None,
     ):
         if not enable_fsdp_float8_all_gather and precompute:
@@ -127,7 +139,10 @@ def _test_transformer_parity(
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_float8_all_gather
-        module = self.init_transformer(weight_tying=weight_tying, dtype=dtype)
+        module = self.init_transformer(
+            weight_tying=weight_tying, dtype=dtype, requires_grad=requires_grad
+        )
+
         ref_module = copy.deepcopy(module)
         float8_linear_config1 = Float8LinearConfig(
             cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
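The new `requires_grad` case exercises FSDP2 float8 training when part of the model is frozen. A minimal standalone sketch of the freezing pattern added in `init_transformer`, using a toy module in place of the test's `Transformer`; the module and shapes are illustrative:

```python
import torch
from torch import nn

# Toy stand-in for the test's Transformer: a stack of "layers".
module = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16), nn.Linear(16, 4))

# Freeze only the first layer so some parameters still receive gradients,
# mirroring the requires_grad=False branch added above.
for param in module[0].parameters():
    param.requires_grad = False

x = torch.randn(2, 16)
module(x).sum().backward()

assert module[0].weight.grad is None       # frozen layer: no gradient
assert module[1].weight.grad is not None   # remaining layers still train
```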
