
AOT Autograd - LSTM - grads not generated (model tts_angular) #586


Description

@anijain2305

This is a subgraph from the tts_angular model.

The generated backward pass has many None outputs, suggesting that requires_grad is somehow not propagated correctly when an LSTM cell is used.
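For contrast, a minimal eager-mode check (a sketch, not part of the original repro; the LSTM shape matches the module below) shows that requires_grad flows through an LSTM normally, so every gradient comes back non-None:

import torch

lstm = torch.nn.LSTM(40, 768, batch_first=True)
x = torch.randn(4, 50, 40, requires_grad=True)
out, _ = lstm(x)
out.sum().backward()
# Eager mode: all of these gradients exist. Under AOT Autograd the
# analogous gradients come back as None, which is the bug.
assert x.grad is not None
assert all(p.grad is not None for p in lstm.parameters())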

import copy
import itertools

import torch
from torch.nn import *

from functorch.compile import aot_module, print_compile

class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_lstm = LSTM(40, 768, batch_first=True)
        # Projection weight applied to the LSTM output (768 -> 256).
        self.weight = Parameter(torch.randn(256, 768, requires_grad=True))

    def forward(self, x):
        self_lstm = self.self_lstm(x);  x = None
        getitem = self_lstm[0];  self_lstm = None
        linear = torch.nn.functional.linear(getitem, self.weight, bias=None);  getitem = None
        return (linear,)

def reduce_out(out):
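    """Reduce a model output (tensor, or nested tuple/list of tensors) to a scalar loss."""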
    if isinstance(out, torch.Tensor):
        return torch.sigmoid(out).sum()
    elif isinstance(out, (tuple, list)):
        return sum([reduce_out(x) for x in out])
    raise NotImplementedError(f"Don't know how to reduce {type(out)}")


def checkpoint_params(gm):
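    """Snapshot the RNG state and all parameters/buffers of gm; returns a restore() callback that undoes in-place updates."""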
    rng_state = torch.clone(torch.random.get_rng_state())
    saved_state = []
    for param in itertools.chain(gm.parameters(), gm.buffers()):
        saved_state.append((param, param._version, torch.clone(param)))

    def restore():
        with torch.no_grad():
            torch.random.set_rng_state(rng_state)
            for param, version, original_value in saved_state:
                if param._version != version:
                    param.copy_(original_value)

    return restore


def clone_me(x):
    if x is None:
        return None
    return x.detach().clone().requires_grad_(x.requires_grad)

def collect_results(model, prediction, loss, example_inputs):
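    """Collect prediction, loss, parameter gradients, and input gradients for later comparison."""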
    results = []
    results.append(prediction)
    results.append(loss)
    for param in model.parameters():
        results.append(clone_me(param.grad))
    for example in example_inputs:
        if isinstance(example, list):
            for inp in example:
                results.append(clone_me(inp.grad))
        else:
            results.append(clone_me(example.grad))
    return results

def same(a, b):
    """Check correctness to see if a and b match"""
    if isinstance(a, (list, tuple, torch.nn.ParameterList)):
        if not isinstance(b, (list, tuple)):
            return False
        return len(a) == len(b) and all(same(ai, bi) for ai, bi in zip(a, b))
    elif isinstance(a, torch.Tensor):
        assert isinstance(b, torch.Tensor)
        close = torch.allclose(a, b, atol=1e-5, rtol=1e-5)
        if not close:
            print(a.flatten()[1], b.flatten()[1])
            print(a.size())
        return close
    elif isinstance(a, (int, float, type(None), bool, torch.device)):
        return a == b
    else:
        raise RuntimeError(f"unsupported type: {type(a).__name__}")


def clone_inputs(inputs):
    clones = [clone_me(x) for x in inputs]
    for c in clones:
        c.grad = None
    return clones


def get_results(mod, inputs):
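    """Run a forward/backward pass of mod on cloned inputs and collect the results."""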
    cloned_inputs = clone_inputs(inputs)
    mod.zero_grad(True)
    ref = mod(*cloned_inputs)
    l = reduce_out(ref)
    l.backward()
    ref_results = collect_results(mod, ref, l, cloned_inputs)
    return ref_results

def test_module():
    inp0 = torch.randn(64, 50, 40, device="cuda", requires_grad=True)
    inputs = [inp0, ]

    mod = Bar().to(device="cuda")
    restore = checkpoint_params(mod)
    orig_mod_results = get_results(mod, inputs)

    restore()
    new_mod = copy.deepcopy(mod)
    copy_mod_results = get_results(new_mod, inputs)
    print("Are Orig_mod and Copy_mod same:", same(orig_mod_results, copy_mod_results))
    # assert same(orig_mod_results, copy_mod_results), "Deepcopy of a mod fails, what the hell"

    restore()
    aot_mod = aot_module(mod, fw_compiler=print_compile)
    aot_mod_results = get_results(aot_mod, inputs)

    print("Recheck Are Orig_mod and Copy_mod same:", same(orig_mod_results, copy_mod_results))
    print("Are Orig_mod and AOT_mod same:", same(orig_mod_results, aot_mod_results))
    print("Are Copy_mod and AOT_mod same:", same(copy_mod_results, aot_mod_results))

test_module()
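For reference, the backward graph can be printed as well: aot_module also accepts a bw_compiler, so a variant of the call above (a sketch, untested here) would make the None outputs visible directly:

# Hypothetical variant of the aot_module call in test_module(),
# printing the generated backward graph in addition to the forward one.
aot_mod = aot_module(mod, fw_compiler=print_compile, bw_compiler=print_compile)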
