Skip to content

AOT Autograd fails to get correct grads for view and Inplace Relu #514

@anijain2305

Description

@anijain2305

While working on the TorchDynamo + AOT Autograd integration, I came across the following bug:

import torch
from torch.nn import *
from functorch.compile import print_compile, aot_module
import copy

class Bar(torch.nn.Module):
    """Tiny module: double the input, view it as [1, 3, 128, 128], in-place ReLU.

    The ReLU is constructed with ``inplace=True`` and is applied to the
    result of ``.view(...)``, so it mutates the viewed tensor in place.
    """

    def __init__(self):
        super().__init__()
        # Conv2d / InstanceNorm2d stages are disabled in this version:
        # self.conv = Conv2d(3, 2, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        # self.instance_norm = InstanceNorm2d(2, affine=True, track_running_stats=True)
        self.relu = ReLU(inplace=True)

    def forward(self, x : torch.Tensor):
        """Return ``relu_(view(x * 2))`` with a fixed [1, 3, 128, 128] shape."""
        # (disabled) self_main_0 = self.conv(x)
        # (disabled) self_main_1 = self.instance_norm(self_main_0)
        doubled = x * 2
        reshaped = doubled.view([1, 3, 128, 128])
        # In-place op on a view: the ReLU output is returned directly.
        return self.relu(reshaped)



# Repro driver: run the same module eagerly and through AOT Autograd on
# identical inputs, then compare the forward outputs, the parameter grads
# and the input grads.
mod = Bar().to(device="cuda")
# Reduce randomness bits
mod.eval()

inp0 = torch.randn(1, 3, 128, 128, device='cuda', requires_grad=True)
inputs = (inp0, )

# Detached clone with its own requires_grad, so the AOT run accumulates
# into a separate .grad from the reference run.
cloned_inp0 = inp0.clone().detach().requires_grad_(True)
cloned_inputs = (cloned_inp0, )

# Reference calculation
mod.zero_grad()
duplicated_mod = copy.deepcopy(mod)
ref = duplicated_mod(*inputs)
ref.sum().backward()
ref_grads = [param.grad for param in duplicated_mod.parameters()]


# AOT stuff
fx_mod = torch.fx.symbolic_trace(mod)
aot_mod = aot_module(fx_mod, print_compile)
aot_mod.zero_grad()
with torch.jit.fuser("fuser2"):
    res = aot_mod(*cloned_inputs)
    res.sum().backward()

res_grads = [param.grad for param in aot_mod.parameters()]


assert torch.allclose(ref, res)

# The assertion messages are real strings: the original passed print(...),
# which evaluates to None and left the AssertionError without a message.
for (a, b) in zip(ref_grads, res_grads):
    assert torch.allclose(a, b, atol=1e-4, rtol=1e-4), (
        f"parameter grads differ:\n{a}\n{b}"
    )

for (a, b) in zip(inputs, cloned_inputs):
    assert torch.allclose(a.grad, b.grad, atol=1e-4, rtol=1e-4), (
        f"input grads differ:\n{a.grad}\n{b.grad}"
    )

Combining `view` with an in-place ReLU (`ReLU(inplace=True)`) seems to produce an incorrect backward trace.

@Chillee @jansel

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions