`gather_backward` uses an in-place op, `scatter_add_`. The joint graph produced by tracing contains this op, but the default partitioner eliminates it, which leads to incorrect results. The graphs are below, followed by a repro.
```python
## JOINT GRAPH
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    gather = torch.ops.aten.gather(primals_1, 3, primals_2); primals_1 = None
    _tensor_constant0 = self._tensor_constant0
    scatter_add_ = torch.ops.aten.scatter_add_(_tensor_constant0, 3, primals_2, tangents_1); _tensor_constant0 = primals_2 = tangents_1 = None
    _tensor_constant0_1 = self._tensor_constant0
    return pytree.tree_unflatten([gather, _tensor_constant0_1, None], self._out_spec)

# Partitioned fwd
def forward(self, primals_1, primals_2):
    gather = torch.ops.aten.gather(primals_1, 3, primals_2); primals_1 = primals_2 = None
    return [gather]

# Partitioned bwd (where is my scatter_add_ :O )
def forward(self, tangents_1):
    _tensor_constant0_1 = self._tensor_constant0
    return [_tensor_constant0_1, None]
```
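For context, the backward of `gather` is a scatter-add of the incoming gradient into a zero tensor shaped like the input. A minimal out-of-place sketch (illustrative only; the function name and signature are made up, not the actual aten decomposition):

```python
import torch

def gather_backward_sketch(grad_out, input_shape, dim, index):
    # Route each element of grad_out back to the input position it was
    # gathered from. Using the out-of-place scatter_add keeps the data
    # dependency visible as a value flowing through the graph.
    grad_input = torch.zeros(input_shape, dtype=grad_out.dtype, device=grad_out.device)
    return grad_input.scatter_add(dim, index, grad_out)
```

In the joint graph above, the same computation is written in-place against the graph constant `_tensor_constant0`, and the node's output value is never read (the graph returns a fresh `self._tensor_constant0` access instead). A partitioner that does dead-code elimination without modeling mutation will therefore drop the `scatter_add_` node, which is exactly what the partitioned backward shows.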
```python
import torch
from torch.nn import *
from functorch.compile import memory_efficient_fusion, print_compile, aot_module, decomposition_table

###################
### HELPERS #######
###################

def clone_me(x):
    # Detached clone that preserves requires_grad.
    if x is None:
        return None
    return x.detach().clone().requires_grad_(x.requires_grad)

def collect_results(model, prediction, loss, example_inputs):
    # Collect everything we want to compare: outputs, loss,
    # parameter gradients, and input gradients.
    results = []
    results.append(prediction)
    results.append(loss)
    for param in model.parameters():
        results.append(clone_me(param.grad))
    for example in example_inputs:
        if isinstance(example, list):
            for inp in example:
                results.append(clone_me(inp.grad))
        else:
            results.append(clone_me(example.grad))
    return results

def same(a, b):
    """Check correctness to see if a and b match"""
    if isinstance(a, (list, tuple, torch.nn.ParameterList)):
        if not isinstance(b, (list, tuple)):
            return False
        return all(same(ai, bi) for ai, bi in zip(a, b))
    elif isinstance(a, torch.Tensor):
        assert isinstance(b, torch.Tensor)
        if not torch.allclose(a, b, atol=1e-5, rtol=1e-5):
            print(a.flatten()[1], b.flatten()[1])
            print(a.size())
        return torch.allclose(a, b, atol=1e-5, rtol=1e-5)
    elif isinstance(a, (int, float, type(None), bool, torch.device)):
        return a == b
    else:
        raise RuntimeError(f"unsupported type: {type(a).__name__}")

def clone_inputs(inputs):
    clones = [clone_me(x) for x in inputs]
    for c in clones:
        c.grad = None
    return clones

def get_results(mod, inputs):
    # Run a fresh forward/backward on cloned inputs and collect results.
    cloned_inputs = clone_inputs(inputs)
    mod.zero_grad(True)
    ref = mod(*cloned_inputs)
    l = ref.sum()
    l.backward()
    ref_results = collect_results(mod, ref, l, cloned_inputs)
    return ref_results

#####################
#### ACTUAL TEST ####
#####################

class FxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, expand, unsqueeze_3):
        gather = torch.gather(expand, 3, unsqueeze_3); expand = unsqueeze_3 = None
        return gather

a = torch.randn(torch.Size([40, 29, 30, 4771]), requires_grad=True)
b = torch.ones(torch.Size([40, 29, 30, 1]), dtype=torch.int64)
inputs = [a, b]
mod = FxModule().to(device="cuda")

orig_mod_results = get_results(mod, inputs)
aot_mod = aot_module(mod, fw_compiler=print_compile)
aot_mod_results = get_results(aot_mod, inputs)
assert same(orig_mod_results, aot_mod_results)
```
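As a sanity check (treating eager autograd as ground truth): the loss is a plain sum, so the incoming gradient is all ones, and the gradient of `a` collected by `get_results` should be ones scattered along dim 3 at the indices in `b`. The partitioned backward instead returns the untouched zeros constant. A hypothetical way to spell out the expected value:

```python
# Expected gradient of `a` for loss = gather(a, 3, b).sum()
# (assumption: this mirrors what eager autograd produces).
expected_grad = torch.zeros_like(a).scatter_add(
    3, b, torch.ones(b.shape, dtype=a.dtype)
)
```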