Skip to content

Update INT8 mixed-precision training test to be less flaky #950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 19 additions & 32 deletions test/prototype/test_quantized_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,44 +161,31 @@ def test_int8_weight_only_training(self, compile, device):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_int8_mixed_precision_training(self, compile, config):
_reset()
bsize = 4
embed_dim = 32
bsize = 64
embed_dim = 64
device = "cuda"

# only use 1 matmul shape to reduce triton autotune time
model_ref = nn.Sequential(
nn.Linear(embed_dim, embed_dim, bias=False),
nn.GELU(),
nn.Linear(embed_dim, embed_dim),
).to(device)
model_int8mp = copy.deepcopy(model_ref)
quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
linear = nn.Linear(embed_dim, embed_dim).cuda()
linear_int8mp = copy.deepcopy(linear)
quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)

if compile:
model_ref.compile()
model_int8mp.compile()
linear.compile()
linear_int8mp.compile()

optim_ref = torch.optim.AdamW(model_ref.parameters())
optim_int8mp = torch.optim.AdamW(model_int8mp.parameters())
inputs = torch.randn(bsize, embed_dim, device=device)
grad_outputs = torch.randn(bsize, embed_dim, device=device)

for i in range(5):
inputs = torch.randn(bsize, embed_dim, device=device)
labels = torch.randint(embed_dim, size=(bsize,), device=device)
loss_ref = F.cross_entropy(model_ref(inputs), labels)
loss_int8mp = F.cross_entropy(model_int8mp(inputs), labels)

rel_error = abs(loss_int8mp.item() - loss_ref.item()) / abs(loss_ref.item())
assert rel_error < 3e-3, (i, rel_error)

loss_ref.backward()
optim_ref.step()
optim_ref.zero_grad()

loss_int8mp.backward()
for p in model_int8mp.parameters():
assert p.grad is not None
optim_int8mp.step()
optim_int8mp.zero_grad()
inputs_ref, outputs_ref = self._forward_and_backward(linear, inputs, grad_outputs)
inputs_int8mp, outputs_int8mp = self._forward_and_backward(linear_int8mp, inputs, grad_outputs)

def snr(ref, actual):
error = actual - ref
return 20 * torch.log10(ref.norm() / error.norm())

assert snr(outputs_ref, outputs_int8mp) > 20
assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
assert snr(linear.weight.grad, linear_int8mp.weight.grad) > 20


_FSDP_WORLD_SIZE = 2
Expand Down
Loading