This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Add required example_args argument to prepare_fx and prepare_qat_fx #791

Closed
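This PR migrates the test suite to the FX Graph Mode Quantization API from PyTorch 1.13, in which prepare_fx and prepare_qat_fx take a required example_inputs argument (a tuple of tensors matching the model's forward signature). A minimal before/after sketch of the call-site change, using a toy module rather than anything from this repo:

    import torch
    import torch.ao.quantization as tq
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    model = torch.nn.Sequential(torch.nn.Linear(3, 1)).eval()

    # Old API (torch < 1.13): prepare_fx(model, {"": tq.default_qconfig})
    # New API: example_inputs is a required positional argument used when
    # tracing and observing the model.
    example_inputs = (torch.rand(1, 3),)
    prepared = prepare_fx(model, {"": tq.default_qconfig}, example_inputs)

    prepared(*example_inputs)         # calibration: observers record activation ranges
    quantized = convert_fx(prepared)  # convert the observed model to a quantized one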
17 changes: 14 additions & 3 deletions test/models_densenet_test.py
@@ -113,21 +113,32 @@ def _test_quantize_model(self, model_config):
             _find_block_full_path(model.features, block_name)
             for block_name in heads.keys()
         ]
+        # TODO[quant-example-inputs]: The dimension here is random, if we need to
+        # use dimension/rank in the future we'd need to get the correct dimensions
+        standalone_example_inputs = (torch.randn(1, 3, 3, 3),)
         # we need to keep the modules used in head standalone since
         # it will be accessed with path name directly in execution
         prepare_custom_config_dict["standalone_module_name"] = [
             (
                 head,
                 {"": tq.default_qconfig},
+                standalone_example_inputs,
                 {"input_quantized_idxs": [0], "output_quantized_idxs": []},
+                None,
             )
             for head in head_path_from_blocks
         ]
-        model.initial_block = prepare_fx(model.initial_block, {"": tq.default_qconfig})
+        # TODO[quant-example-inputs]: The dimension here is random, if we need to
+        # use dimension/rank in the future we'd need to get the correct dimensions
+        example_inputs = (torch.randn(1, 3, 3, 3),)
+        model.initial_block = prepare_fx(
+            model.initial_block, {"": tq.default_qconfig}, example_inputs
+        )

         model.features = prepare_fx(
             model.features,
             {"": tq.default_qconfig},
+            example_inputs,
             prepare_custom_config_dict,
         )
         model.set_heads(heads)
@@ -148,8 +159,8 @@ def test_small_densenet(self):
         self._test_model(MODELS["small_densenet"])

     @unittest.skipIf(
-        get_torch_version() < [1, 8],
-        "FX Graph Modee Quantization is only availablee from 1.8",
+        get_torch_version() < [1, 13],
+        "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13",
     )
     def test_quantized_small_densenet(self):
         self._test_quantize_model(MODELS["small_densenet"])
15 changes: 6 additions & 9 deletions test/models_mlp_test.py
@@ -26,23 +26,20 @@ def test_build_model(self):
         self.assertEqual(output.shape, torch.Size([2, 1]))

     @unittest.skipIf(
-        get_torch_version() < [1, 8],
-        "FX Graph Modee Quantization is only availablee from 1.8",
+        get_torch_version() < [1, 13],
+        "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13",
     )
     def test_quantize_model(self):
-        if get_torch_version() >= [1, 11]:
-            import torch.ao.quantization as tq
-            from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
-        else:
-            import torch.quantization as tq
-            from torch.quantization.quantize_fx import convert_fx, prepare_fx
+        import torch.ao.quantization as tq
+        from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

         config = {"name": "mlp", "input_dim": 3, "output_dim": 1, "hidden_dims": [2]}
         model = build_model(config)
         self.assertTrue(isinstance(model, ClassyModel))

         model.eval()
-        model.mlp = prepare_fx(model.mlp, {"": tq.default_qconfig})
+        example_inputs = (torch.rand(1, 3),)
+        model.mlp = prepare_fx(model.mlp, {"": tq.default_qconfig}, example_inputs)
         model.mlp = convert_fx(model.mlp)

         tensor = torch.tensor([[1, 2, 3]], dtype=torch.float)
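Every touched test file uses the same skip gate. A self-contained sketch of the comparison trick (Python compares lists lexicographically, so [1, 12] < [1, 13] holds); the helper below is an assumed minimal stand-in for the repo's get_torch_version, whose exact implementation is not shown in this diff:

    import unittest

    import torch


    def get_torch_version():
        # Assumed stand-in: returns [major, minor] so callers can compare
        # versions lexicographically against literals like [1, 13].
        return [int(x) for x in torch.__version__.split(".")[:2]]


    class ExampleTest(unittest.TestCase):
        @unittest.skipIf(
            get_torch_version() < [1, 13],
            "This test is using a new api of FX Graph Mode Quantization "
            "which is only available after 1.13",
        )
        def test_quantize_model(self):
            pass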
17 changes: 8 additions & 9 deletions test/models_regnet_test.py
@@ -169,19 +169,18 @@ def test_quantize_model(self, config):
         Test that the model builds using a config using either model_params or
         model_name and calls fx graph mode quantization apis
         """
-        if get_torch_version() < [1, 8]:
-            self.skipTest("FX Graph Modee Quantization is only availablee from 1.8")
-        if get_torch_version() >= [1, 11]:
-            import torch.ao.quantization as tq
-            from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
-        else:
-            import torch.quantization as tq
-            from torch.quantization.quantize_fx import convert_fx, prepare_fx
+        if get_torch_version() < [1, 13]:
+            self.skipTest(
+                "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13"
+            )
+        import torch.ao.quantization as tq
+        from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

         model = build_model(config)
         assert isinstance(model, RegNet)
         model.eval()
-        model.stem = prepare_fx(model.stem, {"": tq.default_qconfig})
+        example_inputs = (torch.rand(1, 3, 3, 3),)
+        model.stem = prepare_fx(model.stem, {"": tq.default_qconfig}, example_inputs)
         model.stem = convert_fx(model.stem)

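The PR title also covers prepare_qat_fx, which none of the hunks in this diff exercise; the argument change is the same. A hedged sketch with a toy model, not code from this repo:

    import torch
    import torch.ao.quantization as tq
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).train()
    example_inputs = (torch.rand(1, 3, 8, 8),)

    # prepare_qat_fx takes the same new required example_inputs argument.
    qconfig = tq.get_default_qat_qconfig("fbgemm")
    prepared = prepare_qat_fx(model, {"": qconfig}, example_inputs)

    # ... fine-tune `prepared` with fake-quant enabled, then:
    quantized = convert_fx(prepared.eval())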
27 changes: 19 additions & 8 deletions test/models_resnext_test.py
@@ -107,18 +107,29 @@ def _post_training_quantize(model, input):
         ]
         # we need to keep the modules used in head standalone since
         # it will be accessed with path name directly in execution
+        # TODO[quant-example-inputs]: Fix the shape if it is needed in quantization
+        standalone_example_inputs = (torch.rand(1, 3, 3, 3),)
         prepare_custom_config_dict["standalone_module_name"] = [
             (
                 head,
                 {"": tq.default_qconfig},
+                standalone_example_inputs,
                 {"input_quantized_idxs": [0], "output_quantized_idxs": []},
+                None,
             )
             for head in head_path_from_blocks
         ]
-        model.initial_block = prepare_fx(model.initial_block, {"": tq.default_qconfig})
+        # TODO[quant-example-inputs]: Fix the shape if it is needed in quantization
+        example_inputs = (torch.rand(1, 3, 3, 3),)
+        model.initial_block = prepare_fx(
+            model.initial_block, {"": tq.default_qconfig}, example_inputs
+        )

         model.blocks = prepare_fx(
-            model.blocks, {"": tq.default_qconfig}, prepare_custom_config_dict
+            model.blocks,
+            {"": tq.default_qconfig},
+            example_inputs,
+            prepare_custom_config_dict,
         )
         model.set_heads(heads)

@@ -222,8 +233,8 @@ def test_small_resnext(self):
         self._test_model(MODELS["small_resnext"])

     @unittest.skipIf(
-        get_torch_version() < [1, 8],
-        "FX Graph Modee Quantization is only availablee from 1.8",
+        get_torch_version() < [1, 13],
+        "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13",
     )
     def test_quantized_small_resnext(self):
         self._test_quantize_model(MODELS["small_resnext"])
@@ -232,8 +243,8 @@ def test_small_resnet(self):
         self._test_model(MODELS["small_resnet"])

     @unittest.skipIf(
-        get_torch_version() < [1, 8],
-        "FX Graph Modee Quantization is only availablee from 1.8",
+        get_torch_version() < [1, 13],
+        "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13",
     )
     def test_quantized_small_resnet(self):
         self._test_quantize_model(MODELS["small_resnet"])
@@ -242,8 +253,8 @@ def test_small_resnet_se(self):
         self._test_model(MODELS["small_resnet_se"])

     @unittest.skipIf(
-        get_torch_version() < [1, 8],
-        "FX Graph Modee Quantization is only availablee from 1.8",
+        get_torch_version() < [1, 13],
+        "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13",
     )
     def test_quantized_small_resnet_se(self):
         self._test_quantize_model(MODELS["small_resnet_se"])