diff --git a/examples/libtorchaudio/speech_recognition/build_pipeline_from_fairseq.py b/examples/libtorchaudio/speech_recognition/build_pipeline_from_fairseq.py
index a6da0ae1e1..a87ea3cf4e 100644
--- a/examples/libtorchaudio/speech_recognition/build_pipeline_from_fairseq.py
+++ b/examples/libtorchaudio/speech_recognition/build_pipeline_from_fairseq.py
@@ -6,6 +6,7 @@
 import os
 import argparse
 import logging
+from typing import Tuple
 
 import torch
 from torch.utils.mobile_optimizer import optimize_for_mobile
@@ -15,6 +16,12 @@
 from greedy_decoder import Decoder
 
 
+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 10):
+    import torch.ao.quantization as tq
+else:
+    import torch.quantization as tq
+
 _LG = logging.getLogger(__name__)
 
 
@@ -149,7 +156,7 @@ def _main():
     if args.quantize:
         _LG.info('Quantizing the model')
         model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
-        encoder = torch.quantization.quantize_dynamic(
+        encoder = tq.quantize_dynamic(
             encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
         _LG.info(encoder)
 
diff --git a/examples/libtorchaudio/speech_recognition/build_pipeline_from_huggingface_transformers.py b/examples/libtorchaudio/speech_recognition/build_pipeline_from_huggingface_transformers.py
index 10323d96f7..f7cc49e2ee 100644
--- a/examples/libtorchaudio/speech_recognition/build_pipeline_from_huggingface_transformers.py
+++ b/examples/libtorchaudio/speech_recognition/build_pipeline_from_huggingface_transformers.py
@@ -2,12 +2,19 @@
 import argparse
 import logging
 import os
+from typing import Tuple
 
 import torch
 import torchaudio
 from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
 
 from greedy_decoder import Decoder
 
+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 10):
+    import torch.ao.quantization as tq
+else:
+    import torch.quantization as tq
+
 _LG = logging.getLogger(__name__)
 
@@ -90,7 +97,7 @@ def _main():
     if args.quantize:
         _LG.info('Quantizing the model')
         model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
-        encoder = torch.quantization.quantize_dynamic(
+        encoder = tq.quantize_dynamic(
             encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
         _LG.info(encoder)
 
diff --git a/test/torchaudio_unittest/models/wav2vec2/model_test.py b/test/torchaudio_unittest/models/wav2vec2/model_test.py
index f029d970df..d8719ca632 100644
--- a/test/torchaudio_unittest/models/wav2vec2/model_test.py
+++ b/test/torchaudio_unittest/models/wav2vec2/model_test.py
@@ -2,6 +2,7 @@
 
 import torch
 import torch.nn.functional as F
+from typing import Tuple
 
 from torchaudio.models.wav2vec2 import (
     wav2vec2_asr_base,
@@ -24,6 +25,12 @@
 )
 from parameterized import parameterized
 
+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 10):
+    import torch.ao.quantization as tq
+else:
+    import torch.quantization as tq
+
 
 def _name_func(testcase_func, i, param):
     return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}"
@@ -206,7 +213,7 @@ def _test_quantize_smoke_test(self, model):
 
         # Remove the weight normalization forward hook
         model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
-        quantized = torch.quantization.quantize_dynamic(
+        quantized = tq.quantize_dynamic(
             model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
 
         # A lazy way to check that Modules are different
@@ -237,7 +244,7 @@ def _test_quantize_torchscript(self, model):
 
         # Remove the weight normalization forward hook
         model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
-        quantized = torch.quantization.quantize_dynamic(
+        quantized = tq.quantize_dynamic(
             model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
 
         # A lazy way to check that Modules are different
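All three files apply the same pattern: detect the installed PyTorch version and import the quantization namespace from torch.ao.quantization (its home as of PyTorch 1.10) when available, falling back to the legacy torch.quantization path otherwise, so that a single alias (tq) works across versions. Below is a minimal standalone sketch of that pattern; the toy torch.nn.Sequential model is illustrative only and is not part of the patch.

    from typing import Tuple

    import torch

    # Pick the quantization namespace by version: torch.quantization moved
    # to torch.ao.quantization in PyTorch 1.10, with the old path kept for
    # backward compatibility.
    TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
    if TORCH_VERSION >= (1, 10):
        import torch.ao.quantization as tq
    else:
        import torch.quantization as tq

    # Dynamically quantize the Linear layers to int8, as the patched scripts
    # do for the wav2vec2 encoder. The toy model is a stand-in, not the
    # model from the patch.
    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
    quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
    print(quantized)

With dynamic quantization, quantize_dynamic replaces each nn.Linear with a quantized counterpart whose weights are stored in int8, while activations are quantized on the fly at inference time; no calibration data is needed.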