AOT Autograd - Default partitioner fails for gather op (model pytorch_struct) #591

@anijain2305

gather_backward uses an in-place op, scatter_add_. The joint graph produced by tracing contains this op, but the default partitioner drops it, which leads to incorrect gradients. The graphs are below, followed by the repro.
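For context on why the dropped op matters: the gradient of gather is obtained by accumulating the incoming gradient back into a zero tensor of the input's shape at the gathered indices, which is exactly what scatter_add_ expresses. A minimal eager-mode sketch of that rule (illustrative only, not the actual aten implementation):

import torch

x = torch.randn(2, 3, requires_grad=True)
idx = torch.tensor([[0, 2], [1, 0]])

out = torch.gather(x, 1, idx)
grad_out = torch.ones_like(out)

# Reference gradient from autograd.
(ref_grad,) = torch.autograd.grad(out, x, grad_out)

# The same gradient, written as an in-place scatter_add_ into zeros.
manual_grad = torch.zeros_like(x).scatter_add_(1, idx, grad_out)
assert torch.allclose(ref_grad, manual_grad)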


## JOINT GRAPH
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    gather = torch.ops.aten.gather(primals_1, 3, primals_2);  primals_1 = None
    _tensor_constant0 = self._tensor_constant0
    scatter_add_ = torch.ops.aten.scatter_add_(_tensor_constant0, 3, primals_2, tangents_1);  _tensor_constant0 = primals_2 = tangents_1 = None
    _tensor_constant0_1 = self._tensor_constant0
    return pytree.tree_unflatten([gather, _tensor_constant0_1, None], self._out_spec)

# Partitioned fwd
def forward(self, primals_1, primals_2):
    gather = torch.ops.aten.gather(primals_1, 3, primals_2);  primals_1 = primals_2 = None
    return [gather]

# Partitioned bwd (where is my scatter_add_ :O )
def forward(self, tangents_1):
    _tensor_constant0_1 = self._tensor_constant0
    return [_tensor_constant0_1, None]
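For reference, a correct partitioned backward would still produce the gradient for primals_1 by scattering tangents_1 into zeros of the input's shape, which is what the joint graph's scatter_add_ into _tensor_constant0 does. A rough sketch of what one would expect (illustrative only, not actual partitioner output; it also assumes primals_2 is saved by the forward):

# Expected bwd (sketch)
def forward(self, primals_2, tangents_1):
    zeros = torch.zeros([40, 29, 30, 4771], dtype=tangents_1.dtype, device=tangents_1.device)
    grad_primals_1 = torch.ops.aten.scatter_add(zeros, 3, primals_2, tangents_1)
    return [grad_primals_1, None]

The full repro script follows.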
import torch
from torch.nn import *
from functorch.compile import aot_module, print_compile

###################
### HELPER ########
###################
def clone_me(x):
    if x is None:
        return None
    return x.detach().clone().requires_grad_(x.requires_grad)

def collect_results(model, prediction, loss, example_inputs):
    results = []
    results.append(prediction)
    results.append(loss)
    for param in model.parameters():
        results.append(clone_me(param.grad))
    for example in example_inputs:
        if isinstance(example, list):
            for inp in example:
                results.append(clone_me(inp.grad))
        else:
            results.append(clone_me(example.grad))
    return results

def same(a, b):
    """Check correctness to see if a and b match"""
    if isinstance(a, (list, tuple, torch.nn.ParameterList)):
        if not isinstance(b, (list, tuple)):
            return False
        if len(a) != len(b):
            return False
        return all(same(ai, bi) for ai, bi in zip(a, b))
    elif isinstance(a, torch.Tensor):
        assert isinstance(b, torch.Tensor)
        close = torch.allclose(a, b, atol=1e-5, rtol=1e-5)
        if not close:
            print(a.flatten()[1], b.flatten()[1])
            print(a.size())
        return close
    elif isinstance(a, (int, float, type(None), bool, torch.device)):
        return a == b
    else:
        raise RuntimeError(f"unsupported type: {type(a).__name__}")


def clone_inputs(inputs):
    clones = [clone_me(x) for x in inputs]
    for c in clones:
        c.grad = None
    return clones


def get_results(mod, inputs):
    cloned_inputs = clone_inputs(inputs)
    mod.zero_grad(True)
    ref = mod(*cloned_inputs)
    l = ref.sum()
    l.backward()
    ref_results = collect_results(mod, ref, l, cloned_inputs)
    return ref_results

##############################
#### ACTUAL TEST ############
############################

class FxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()


    def forward(self, expand, unsqueeze_3):
        gather = torch.gather(expand, 3, unsqueeze_3);  expand = unsqueeze_3 = None
        return gather



a = torch.randn([40, 29, 30, 4771], device="cuda", requires_grad=True)
b = torch.ones([40, 29, 30, 1], dtype=torch.int64, device="cuda")
inputs = [a, b]

mod = FxModule().to(device="cuda")
orig_mod_results = get_results(mod, inputs)

aot_mod = aot_module(mod, fw_compiler=print_compile)
aot_mod_results = get_results(aot_mod, inputs)

assert same(orig_mod_results, aot_mod_results)
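A note on inspecting the graphs: to be explicit about dumping the partitioned backward graph (the one missing scatter_add_), one can also pass print_compile as the bw_compiler, assuming that keyword is accepted by this version of functorch.compile:

# Also print the partitioned backward graph (assumes the bw_compiler keyword is accepted).
aot_mod = aot_module(mod, fw_compiler=print_compile, bw_compiler=print_compile)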
