
[AQT] Failed to move compiled module with AQT to a different device #1309

Open
@gau-nernst

Description

To reproduce

from torchao import quantize_
from torchao.quantization import int8_weight_only
from torch import nn
import torch

linear = nn.Linear(1024, 1024)
quantize_(linear, int8_weight_only())
linear.cuda()
linear.compile()
linear(torch.randn(1, 1024, device="cuda"))
linear.cpu()  # this will error
linear.cuda()  # this will also error

Error

Traceback (most recent call last):
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 945, in _apply
    torch.utils.swap_tensors(param, param_applied)
  File "/home/xxx/python3.10/site-packages/torch/utils/__init__.py", line 51, in swap_tensors
    raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
RuntimeError: Cannot swap t1 because it has weakref associated with it

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xxx/debug.py", line 11, in <module>
    linear.cpu()
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 1118, in cpu
    return self._apply(lambda t: t.cpu())
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 949, in _apply
    raise RuntimeError(
RuntimeError: _apply(): Couldn't swap Linear.weight

This looks like a general problem with tensor subclasses + compile, not something limited to AQT. Even compile(disable=False) still raises this error.

cc: @jerryzh168

torchao: 0.7.0+git26648c2c (installed from source)
pytorch: tested with 2.5.0 and 2.6.0.dev20241102+cu124

Metadata

Assignees: none
Labels: bug (Something isn't working)