This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit c07f1d9

checkpoint to reduce memory usage, only do dynamic for now
1 parent 713d2db commit c07f1d9

File tree

1 file changed: +31 -3 lines changed

float8_experimental/float8_dynamic_linear.py

Lines changed: 31 additions & 3 deletions
@@ -40,6 +40,27 @@ def backward(ctx, gradY):
         )
 
 
+def cast_weight_linear(
+    x_fp8: Float8Tensor, weight: torch.Tensor, scale: torch.Tensor, bias, emulate: bool
+) -> torch.Tensor:
+    """Cast the weight to fp8_e4m3fn and do the linear.
+    Why a new function for something that can be inlined?
+    Because we want to call torch.utils.checkpoint on this function.
+    We always want to recompute the cast of the weight to fp8, since we can trivially
+    fuse this into the transpose/contiguous of the weight during the backward pass.
+
+    Args:
+        x_fp8 (Float8Tensor): input activation in fp8
+        weight (torch.Tensor): weight tensor in higher precision
+        scale (torch.Tensor): scale tensor for the weight
+        bias: bias tensor in higher precision
+        emulate (bool): whether to emulate fp8 matmul logic in float32
+    """
+    w_fp8 = Float8Tensor.to_float8(weight, scale, torch.float8_e4m3fn, emulate=emulate)
+    y = torch.nn.functional.linear(x_fp8, w_fp8, bias)
+    return y
+
+
 class Float8DynamicLinear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
@@ -48,9 +69,16 @@ class Float8DynamicLinear(torch.nn.Linear):
 
     def forward(self, x):
         x_fp8 = self.cast_to_float8(x)
-        w_fp8 = self.cast_to_float8(self.weight)
-
-        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
+        scale = tensor_to_scale(self.weight, torch.float8_e4m3fn)
+        y = torch.utils.checkpoint.checkpoint(
+            cast_weight_linear,
+            x_fp8,
+            self.weight,
+            scale,
+            self.bias,
+            self.emulate,
+            use_reentrant=False,
+        )
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_to_float8e5m2_bw(y)
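The memory win comes from the checkpoint call: the fp8 copy of the weight produced inside cast_weight_linear is not kept alive for the backward pass; it is recomputed when the checkpointed function is re-run during backward. A minimal, self-contained sketch of the same pattern follows (illustrative only, not code from this repo: cast_and_linear and its bfloat16 round-trip are stand-ins for the fp8 cast).

# Illustrative sketch (not part of this commit): wrapping a weight cast plus matmul
# in torch.utils.checkpoint.checkpoint so the casted copy of the weight is an
# unsaved intermediate. With use_reentrant=False, the function is simply re-run
# during the backward pass to rebuild that copy.
import torch
import torch.utils.checkpoint


def cast_and_linear(x, weight, bias):
    # Stand-in for the fp8 cast in the commit: materialize a lower-precision copy
    # of the weight, then do the matmul with it.
    w_lowp = weight.to(torch.bfloat16).to(weight.dtype)
    return torch.nn.functional.linear(x, w_lowp, bias)


x = torch.randn(16, 32, requires_grad=True)
w = torch.randn(64, 32, requires_grad=True)
b = torch.randn(64, requires_grad=True)

y = torch.utils.checkpoint.checkpoint(cast_and_linear, x, w, b, use_reentrant=False)
y.sum().backward()  # cast_and_linear runs again here to rebuild w_lowp

The commit passes use_reentrant=False, which selects PyTorch's non-reentrant checkpoint implementation, the variant the PyTorch docs recommend.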
