Description
To reproduce
import torch
from torch import nn
from torchao import quantize_
from torchao.quantization import int8_weight_only

linear = nn.Linear(1024, 1024)
quantize_(linear, int8_weight_only())  # weight becomes a quantized tensor subclass
linear.cuda()
linear.compile()
linear(torch.randn(1, 1024, device="cuda"))  # run once so compilation actually happens
linear.cpu()  # this will error
linear.cuda()  # this will also error
Error
Traceback (most recent call last):
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 945, in _apply
    torch.utils.swap_tensors(param, param_applied)
  File "/home/xxx/python3.10/site-packages/torch/utils/__init__.py", line 51, in swap_tensors
    raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
RuntimeError: Cannot swap t1 because it has weakref associated with it

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xxx/debug.py", line 11, in <module>
    linear.cpu()
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 1118, in cpu
    return self._apply(lambda t: t.cpu())
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 949, in _apply
    raise RuntimeError(
RuntimeError: _apply(): Couldn't swap Linear.weight
This looks like a general problem with tensor subclasses combined with torch.compile, not something specific to AffineQuantizedTensor (AQT). Even calling compile(disable=False) still produces this error.
cc: @jerryzh168
torchao: 0.7.0+git26648c2c (installed from source)
pytorch: tested with 2.5.0 and 2.6.0.dev20241102+cu124