Commit dbf43d6

jerryzh168 authored and pytorchmergebot committed
[quant][fx] Only do reference module swapping for floating point fused modules (#74231)
Summary:
Pull Request resolved: #74231

Add a check to make sure the weighted module we swap is actually a float fused module. A reference fused module (e.g. the reference version of linear + relu) has the same fused type as the floating point linear + relu, but its linear submodule has a different type, so the outer type alone cannot distinguish the two.

Test Plan: phabricator diff for now; a test case can be added once we know exactly what the problem is.

Reviewed By: andrewor14

Differential Revision: D34888290

fbshipit-source-id: a7f53368a7c17f7d1a82afaa50d14d569b4923df
(cherry picked from commit 458dac9)
1 parent 0471da5 commit dbf43d6
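
To illustrate what the summary describes, here is a minimal sketch (not part of the commit; it only assumes that torch.nn.intrinsic fused modules are Sequential-style containers): the float fused linear + relu and its reference-converted counterpart share the same outer type, so the convert pass has to look at the type of the first submodule.

import torch
import torch.nn.intrinsic as nni

# Float fused linear + relu: the outer type is nni.LinearReLU and the inner
# weighted submodule (index 0 of the Sequential-style container) is still a
# plain float torch.nn.Linear.
fused = nni.LinearReLU(torch.nn.Linear(5, 5), torch.nn.ReLU())
assert type(fused) is nni.LinearReLU
assert isinstance(fused[0], torch.nn.Linear)

# Once the inner linear has been swapped to a reference quantized module
# (for example via an earlier use of the same shared submodule), the outer
# fused type is unchanged, so only the check on fused[0] tells the two apart.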

File tree

2 files changed: +51 -0 lines changed

test/quantization/fx/test_quantize_fx.py

Lines changed: 39 additions & 0 deletions

@@ -3600,6 +3600,45 @@ def forward(self, x):
         ]
         self.checkGraphModuleNodes(m, expected_node_list=node_list)
 
+    @skipIfNoFBGEMM
+    def test_dynamic_with_fusion_multiple_uses(self):
+        """
+        Tests that dynamic quantization APIs work with Linear + Relu fusion
+        """
+        class LinearRelu(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                x = self.linear(x)
+                return self.relu(x)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear_relu = LinearRelu()
+
+            def forward(self, x):
+                x = self.linear_relu(x)
+                x = self.linear_relu(x)
+                return x
+
+        for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]:
+            model = M().eval()
+            qconfig_dict = {
+                "": qconfig
+            }
+            m = prepare_fx(model, qconfig_dict)
+            m = convert_fx(m)
+            m(torch.rand(5, 5))
+            node_list = [
+                ns.call_module(nniqd.LinearReLU),
+                ns.call_module(nniqd.LinearReLU),
+            ]
+            self.checkGraphModuleNodes(m, expected_node_list=node_list)
+
     def test_ref_linear_module(self):
         """ Make sure the numerics for models with ref linear module
             matches models with fbgemm/qnnpack module

torch/ao/quantization/fx/convert.py

Lines changed: 12 additions & 0 deletions

@@ -68,6 +68,13 @@
     torch.nn.intrinsic.ConvReLU3d,
 )
 
+FLOAT_WEIGHTED_MODULE_CLASSES = (
+    torch.nn.Linear,
+    torch.nn.Conv1d,
+    torch.nn.Conv2d,
+    torch.nn.Conv3d,
+)
+
 QAT_MODULE_CLASSES = (
     torch.nn.qat.Linear,
     torch.nn.qat.Conv2d,
@@ -746,6 +753,11 @@ def replace_observer_with_dequantize_node(node: Node, graph: Graph):
                     node, modules, model, is_reference, backend_config_dict)
             elif type(modules[node.target]) in set(
                     weighted_module_classes).union(QAT_MODULE_CLASSES).union(FUSED_MODULE_CLASSES):
+                # extra check for fused module classes to make sure they are fused module classes
+                # of target modules
+                if type(modules[node.target]) in FUSED_MODULE_CLASSES and \
+                        type(modules[node.target][0]) not in FLOAT_WEIGHTED_MODULE_CLASSES:
+                    continue
                 convert_weighted_module(
                     node, modules, observed_node_names, quantized_reference_module_mapping, qconfig_map)
             elif type(modules[node.target]) in custom_module_classes:
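
Read in isolation, the guard added above behaves roughly like the following standalone predicate (a hedged sketch only: is_float_fused_module is a made-up helper name, and the fused-class tuple is an assumption based on the diff context ending in ConvReLU3d, not copied from convert.py):

import torch
import torch.nn.intrinsic as nni

FLOAT_WEIGHTED_MODULE_CLASSES = (
    torch.nn.Linear,
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
)

# Assumed to roughly match FUSED_MODULE_CLASSES in convert.py (the diff
# context only shows that it ends with torch.nn.intrinsic.ConvReLU3d).
FUSED_MODULE_CLASSES = (
    nni.LinearReLU,
    nni.ConvReLU1d,
    nni.ConvReLU2d,
    nni.ConvReLU3d,
)

def is_float_fused_module(module: torch.nn.Module) -> bool:
    # Only fused modules whose first submodule is still a float weighted
    # module should be swapped to the reference quantized fused module;
    # anything else (e.g. an already-converted reference fused module) is
    # skipped by the convert pass.
    return (
        type(module) in FUSED_MODULE_CLASSES
        and type(module[0]) in FLOAT_WEIGHTED_MODULE_CLASSES
    )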
