While working on the TorchDynamo + AOT integration, I came across the following bug:
import torch
from torch.nn import *
from functorch.compile import print_compile, aot_module
import copy

class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # self.conv = Conv2d(3, 2, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        # self.instance_norm = InstanceNorm2d(2, affine=True, track_running_stats=True)
        self.relu = ReLU(inplace=True)

    def forward(self, x: torch.Tensor):
        # self_main_0 = self.conv(x)
        # self_main_1 = self.instance_norm(self_main_0)
        self_main_0 = x * 2
        self_main_1 = self_main_0.view([1, 3, 128, 128])
        self_main_2 = self.relu(self_main_1)
        return self_main_2

mod = Bar().to(device="cuda")
# Reduce randomness bits
mod.eval()

inp0 = torch.randn(1, 3, 128, 128, device='cuda', requires_grad=True)
inputs = (inp0, )
cloned_inp0 = inp0.clone().detach().requires_grad_(True)
cloned_inputs = (cloned_inp0, )

# Reference calculation
mod.zero_grad()
duplicated_mod = copy.deepcopy(mod)
ref = duplicated_mod(*inputs)
ref.sum().backward()
ref_grads = []
for param in duplicated_mod.parameters():
    ref_grads.append(param.grad)

# AOT stuff
fx_mod = torch.fx.symbolic_trace(mod)
aot_mod = aot_module(fx_mod, print_compile)
aot_mod.zero_grad()
with torch.jit.fuser("fuser2"):
    res = aot_mod(*cloned_inputs)
    res.sum().backward()
res_grads = []
for param in aot_mod.parameters():
    res_grads.append(param.grad)

assert torch.allclose(ref, res)
for (a, b) in zip(ref_grads, res_grads):
    assert torch.allclose(a, b, atol=1e-4, rtol=1e-4), print(a, b)
for (a, b) in zip(inputs, cloned_inputs):
    assert torch.allclose(a.grad, b.grad, atol=1e-4, rtol=1e-4), print(a.grad, b.grad)
view + in-place ReLU seems to give a wrong backward trace.
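For reference, here is a minimal eager-mode sanity check (an illustrative sketch added here, not part of the original repro, with no functorch/AOT involved) of what the backward of this graph should produce: for y = relu((x * 2).view(1, 3, 128, 128)) and y.sum().backward(), the input gradient is 2 where x > 0 and 0 elsewhere, and the in-place and out-of-place ReLU variants agree. This is the reference result the AOT-compiled backward is being compared against.

import torch

# Eager-mode sanity check (illustrative only, no functorch/AOT involved).
# The repro above runs on CUDA; the math is the same on CPU, so no device is set here.
x = torch.randn(1, 3, 128, 128, requires_grad=True)

# In-place variant, matching the repro's ReLU(inplace=True) applied to a view.
y_inplace = torch.relu_((x * 2).view(1, 3, 128, 128))
y_inplace.sum().backward()
grad_inplace = x.grad.clone()

# Out-of-place variant.
x.grad = None
y_out = torch.relu((x * 2).view(1, 3, 128, 128))
y_out.sum().backward()
grad_out = x.grad.clone()

# Expected gradient: d(sum relu(2x))/dx = 2 where x > 0, else 0.
expected = (x.detach() > 0).to(x.dtype) * 2.0
assert torch.allclose(grad_inplace, expected)
assert torch.allclose(grad_out, expected)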