pytorch/torchdynamo#975 (Closed)

This is a subgraph from the tts_angular model. The generated backward pass has many None outputs, suggesting that requires_grad is somehow not passed correctly when an LSTM cell is used.
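As a point of reference (not part of the original report), a plain eager-mode run of an LSTM does produce gradients for its input and parameters, so if the sketch below passes, the missing gradients are specific to the AOT-compiled backward path rather than the module itself:

import torch

lstm = torch.nn.LSTM(40, 768, batch_first=True)
x = torch.randn(2, 5, 40, requires_grad=True)
out, _ = lstm(x)
out.sum().backward()

# In eager mode both the input and every LSTM parameter receive a gradient.
assert x.grad is not None
assert all(p.grad is not None for p in lstm.parameters())

The full repro follows: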
import functorch
import torch
from torch.nn import *
from functorch.compile import memory_efficient_fusion, print_compile, aot_module, decomposition_table
import importlib
import torchdynamo
import copy
import itertools
from torchdynamo.optimizations import backends
class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Extracted subgraph: LSTM followed by a linear projection.
        self.self_lstm = LSTM(40, 768, batch_first=True)
        self.weight = Parameter(torch.randn(torch.Size([256, 768]), requires_grad=True))

    def forward(self, x):
        self_lstm = self.self_lstm(x); x = None
        getitem = self_lstm[0]; self_lstm = None
        linear = torch.nn.functional.linear(getitem, self.weight, bias=None); getitem = None
        return (linear,)
def reduce_out(out):
    if isinstance(out, torch.Tensor):
        return torch.sigmoid(out).sum()
    elif isinstance(out, (tuple, list)):
        return sum([reduce_out(x) for x in out])
    raise NotImplementedError("Don't know how to reduce", type(out))
def checkpoint_params(gm):
    rng_state = torch.clone(torch.random.get_rng_state())
    saved_state = []
    for param in itertools.chain(gm.parameters(), gm.buffers()):
        saved_state.append((param, param._version, torch.clone(param)))

    def restore():
        with torch.no_grad():
            torch.random.set_rng_state(rng_state)
            for param, version, original_value in saved_state:
                if param._version != version:
                    param.copy_(original_value)

    return restore
def clone_me(x):
    if x is None:
        return None
    return x.detach().clone().requires_grad_(x.requires_grad)
def collect_results(model, prediction, loss, example_inputs):
    results = []
    results.append(prediction)
    results.append(loss)
    for param in model.parameters():
        results.append(clone_me(param.grad))
    for example in example_inputs:
        if isinstance(example, list):
            for inp in example:
                results.append(clone_me(inp.grad))
        else:
            results.append(clone_me(example.grad))
    return results
def same(a, b):
    """Check correctness to see if a and b match"""
    if isinstance(a, (list, tuple, torch.nn.ParameterList)):
        if not isinstance(b, (list, tuple)):
            return False
        return all(same(ai, bi) for ai, bi in zip(a, b))
    elif isinstance(a, torch.Tensor):
        assert isinstance(b, torch.Tensor)
        if not torch.allclose(a, b, atol=1e-5, rtol=1e-5):
            print(a.flatten()[1], b.flatten()[1])
            print(a.size())
        return torch.allclose(a, b, atol=1e-5, rtol=1e-5)
    elif isinstance(a, (int, float, type(None), bool, torch.device)):
        return a == b
    else:
        raise RuntimeError(f"unsupported type: {type(a).__name__}")
def clone_inputs(inputs):
    clones = [clone_me(x) for x in inputs]
    for c in clones:
        c.grad = None
    return clones
def get_results(mod, inputs):
    cloned_inputs = clone_inputs(inputs)
    mod.zero_grad(True)
    ref = mod(*cloned_inputs)
    l = reduce_out(ref)
    l.backward()
    ref_results = collect_results(mod, ref, l, cloned_inputs)
    return ref_results
def test_module():
    inp0 = torch.randn(64, 50, 40, device="cuda", requires_grad=True)
    inputs = [inp0, ]
    mod = Bar().to(device="cuda")
    restore = checkpoint_params(mod)

    # Reference run in eager mode.
    orig_mod_results = get_results(mod, inputs)
    restore()

    # Sanity check: a deepcopy of the module should match the original.
    new_mod = copy.deepcopy(mod)
    copy_mod_results = get_results(new_mod, inputs)
    print("Are Orig_mod and Copy_mod same:", same(orig_mod_results, copy_mod_results))
    # assert same(orig_mod_results, copy_mod_results), "Deepcopy of a mod fails, what the hell"
    restore()

    # AOT-compiled run via aot_module; this is where the backward graph with None outputs is generated.
    aot_mod = aot_module(mod, fw_compiler=print_compile)
    aot_mod_results = get_results(aot_mod, inputs)
    print("Recheck Are Orig_mod and Copy_mod same:", same(orig_mod_results, copy_mod_results))
    print("Are Orig_mod and AOT_mod same:", same(orig_mod_results, aot_mod_results))
    print("Are Copy_mod and AOT_mod same:", same(copy_mod_results, aot_mod_results))

test_module()
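
To also see where the None outputs appear, the backward graph can be printed alongside the forward one. This is a sketch, assuming this functorch build's aot_module accepts a bw_compiler keyword next to fw_compiler:

# Sketch: print both the forward and the generated backward graph.
# Assumes aot_module(mod, fw_compiler=..., bw_compiler=...) is accepted here.
dbg_mod = aot_module(Bar().to(device="cuda"), fw_compiler=print_compile, bw_compiler=print_compile)
out = dbg_mod(torch.randn(64, 50, 40, device="cuda", requires_grad=True))
reduce_out(out).backward()  # the backward graph is printed once it is compiled, by this call at the latest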