
Commit ecdea1c

Merge branch 'main' into prototype/cleanup_api
2 parents 4b1715b + b3cdec1 commit ecdea1c

File tree: 14 files changed, +88 -48 lines

.circleci/config.yml (+2 -1)

Some generated files are not rendered by default.

.circleci/config.yml.in (+2 -1)

@@ -278,8 +278,9 @@ jobs:
       background: true
       command: |
         sudo apt update -qy && sudo apt install -qy parallel wget
+        mkdir -p ~/.cache/torch/hub/checkpoints
         python scripts/collect_model_urls.py torchvision/prototype/models \
-          | parallel -j0 wget --no-verbose -P ~/.cache/torch/hub/checkpoints {}
+          | parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci'
   - run:
       name: Install torchvision
       command: pip install --user --progress-bar off --no-build-isolation .
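For context, the updated CI step saves each checkpoint under its basename inside the hub cache and appends a `?source=ci` query string so the downloads can be attributed to CI. Below is a rough, illustrative Python equivalent of what happens per URL; it is not part of this commit, and the example URL is only an assumption for demonstration.

```python
# Illustrative sketch only (not part of this commit): roughly what the updated
# CI command does for each URL emitted by scripts/collect_model_urls.py.
import os
import urllib.request

cache_dir = os.path.expanduser("~/.cache/torch/hub/checkpoints")
os.makedirs(cache_dir, exist_ok=True)  # mkdir -p ~/.cache/torch/hub/checkpoints

url = "https://download.pytorch.org/models/resnet50-0676ba61.pth"  # example URL, assumed
target = os.path.join(cache_dir, os.path.basename(url))  # -O .../`basename {}`
urllib.request.urlretrieve(url + "?source=ci", target)   # {}\?source=ci
```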

references/classification/train_quantization.py (+10 -10)

@@ -4,7 +4,7 @@
 import time
 
 import torch
-import torch.quantization
+import torch.ao.quantization
 import torch.utils.data
 import torchvision
 import utils
@@ -62,8 +62,8 @@ def main(args):
 
     if not (args.test_only or args.post_training_quantize):
         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
-        torch.quantization.prepare_qat(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+        torch.ao.quantization.prepare_qat(model, inplace=True)
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -96,12 +96,12 @@ def main(args):
         )
         model.eval()
         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qconfig(args.backend)
-        torch.quantization.prepare(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+        torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.quantization.convert(model, inplace=True)
+        torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
@@ -114,8 +114,8 @@ def main(args):
         evaluate(model, criterion, data_loader_test, device=device)
         return
 
-    model.apply(torch.quantization.enable_observer)
-    model.apply(torch.quantization.enable_fake_quant)
+    model.apply(torch.ao.quantization.enable_observer)
+    model.apply(torch.ao.quantization.enable_fake_quant)
     start_time = time.time()
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
@@ -126,7 +126,7 @@ def main(args):
         with torch.inference_mode():
             if epoch >= args.num_observer_update_epochs:
                 print("Disabling observer for subseq epochs, epoch = ", epoch)
-                model.apply(torch.quantization.disable_observer)
+                model.apply(torch.ao.quantization.disable_observer)
             if epoch >= args.num_batch_norm_update_epochs:
                 print("Freezing BN for subseq epochs, epoch = ", epoch)
                 model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
@@ -136,7 +136,7 @@ def main(args):
             quantized_eval_model = copy.deepcopy(model_without_ddp)
             quantized_eval_model.eval()
             quantized_eval_model.to(torch.device("cpu"))
-            torch.quantization.convert(quantized_eval_model, inplace=True)
+            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
 
             print("Evaluate Quantized model")
             evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
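The script above is the eager-mode QAT reference; for orientation, a condensed sketch of the same flow under the new `torch.ao.quantization` namespace is shown below. It is only an illustration: the choice of mobilenet_v2 and the "fbgemm" backend are assumptions, not part of this commit.

```python
# Minimal QAT sketch using torch.ao.quantization (illustrative; the model and
# backend below are assumptions, not taken from this commit).
import torch
import torchvision

model = torchvision.models.quantization.mobilenet_v2(pretrained=False, quantize=False)
model.fuse_model()  # fuse conv/bn/relu blocks before preparing for QAT
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
torch.ao.quantization.prepare_qat(model, inplace=True)

model.apply(torch.ao.quantization.enable_observer)
model.apply(torch.ao.quantization.enable_fake_quant)
# ... fine-tune here; later epochs typically freeze observers and BN stats:
model.apply(torch.ao.quantization.disable_observer)
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

quantized = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
```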

references/classification/utils.py (+2 -2)

@@ -345,8 +345,8 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
         # Quantized Classification
         model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
-        _ = torch.quantization.prepare_qat(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
+        _ = torch.ao.quantization.prepare_qat(model, inplace=True)
         print(store_model_weights(model, './qat.pth'))
 
         # Object Detection

test/test_functional_tensor.py (+30)

@@ -795,6 +795,36 @@ def test_solarize2(device, dtype, config, channels):
     )
 
 
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize("threshold", [0.0, 0.25, 0.5, 0.75, 1.0])
+def test_solarize_threshold1_bound(threshold, device):
+    img = torch.rand((3, 12, 23)).to(device)
+    F_t.solarize(img, threshold)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize("threshold", [1.5])
+def test_solarize_threshold1_upper_bound(threshold, device):
+    img = torch.rand((3, 12, 23)).to(device)
+    with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
+        F_t.solarize(img, threshold)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize("threshold", [0, 64, 128, 192, 255])
+def test_solarize_threshold2_bound(threshold, device):
+    img = torch.randint(0, 256, (3, 12, 23)).to(device)
+    F_t.solarize(img, threshold)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize("threshold", [260])
+def test_solarize_threshold2_upper_bound(threshold, device):
+    img = torch.randint(0, 256, (3, 12, 23)).to(device)
+    with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
+        F_t.solarize(img, threshold)
+
+
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
 @pytest.mark.parametrize("config", [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])

test/test_models.py (+5 -5)

@@ -781,19 +781,19 @@ def test_quantized_classification_model(model_fn):
     model = model_fn(**kwargs)
     if eval_mode:
         model.eval()
-        model.qconfig = torch.quantization.default_qconfig
+        model.qconfig = torch.ao.quantization.default_qconfig
     else:
         model.train()
-        model.qconfig = torch.quantization.default_qat_qconfig
+        model.qconfig = torch.ao.quantization.default_qat_qconfig
 
     model.fuse_model()
     if eval_mode:
-        torch.quantization.prepare(model, inplace=True)
+        torch.ao.quantization.prepare(model, inplace=True)
     else:
-        torch.quantization.prepare_qat(model, inplace=True)
+        torch.ao.quantization.prepare_qat(model, inplace=True)
         model.eval()
 
-    torch.quantization.convert(model, inplace=True)
+    torch.ao.quantization.convert(model, inplace=True)
 
     try:
         torch.jit.script(model)
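The test exercises both the eager post-training path and the QAT path and then scripts the converted model. A standalone sketch of just the eval-mode path is shown below; the model choice (resnet18) and the dummy calibration input are assumptions for illustration only.

```python
# Illustrative sketch of the eval-mode (post-training) path covered by the test
# above; model and calibration data are assumed, not taken from the test.
import torch
import torchvision

model = torchvision.models.quantization.resnet18(pretrained=False, quantize=False)
model.eval()
model.qconfig = torch.ao.quantization.default_qconfig
model.fuse_model()
torch.ao.quantization.prepare(model, inplace=True)
model(torch.rand(1, 3, 224, 224))                  # calibration with dummy data
torch.ao.quantization.convert(model, inplace=True)
torch.jit.script(model)                            # converted model stays scriptable
```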

torchvision/models/quantization/googlenet.py (+3 -3)

@@ -31,7 +31,7 @@ def forward(self, x: Tensor) -> Tensor:
         return x
 
     def fuse_model(self) -> None:
-        torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
 
 
 class QuantizableInception(Inception):
@@ -74,8 +74,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(  # type: ignore[misc]
             blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], *args, **kwargs
         )
-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()
 
     def forward(self, x: Tensor) -> GoogLeNetOutputs:
         x = self._transform_input(x)

torchvision/models/quantization/inception.py (+3 -3)

@@ -36,7 +36,7 @@ def forward(self, x: Tensor) -> Tensor:
         return x
 
     def fuse_model(self) -> None:
-        torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
 
 
 class QuantizableInceptionA(inception_module.InceptionA):
@@ -144,8 +144,8 @@ def __init__(
                 QuantizableInceptionAux,
             ],
         )
-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()
 
     def forward(self, x: Tensor) -> InceptionOutputs:
         x = self._transform_input(x)

torchvision/models/quantization/mobilenetv2.py (+1 -1)

@@ -2,7 +2,7 @@
 
 from torch import Tensor
 from torch import nn
-from torch.quantization import QuantStub, DeQuantStub, fuse_modules
+from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
 from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
 
 from ..._internally_replaced_utils import load_state_dict_from_url

torchvision/models/quantization/mobilenetv3.py (+4 -4)

@@ -2,7 +2,7 @@
 
 import torch
 from torch import nn, Tensor
-from torch.quantization import QuantStub, DeQuantStub, fuse_modules
+from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
 
 from ..._internally_replaced_utils import load_state_dict_from_url
 from ...ops.misc import ConvNormActivation, SqueezeExcitation
@@ -136,13 +136,13 @@ def _mobilenet_v3_model(
         backend = "qnnpack"
 
         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
-        torch.quantization.prepare_qat(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
+        torch.ao.quantization.prepare_qat(model, inplace=True)
 
         if pretrained:
             _load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress)
 
-        torch.quantization.convert(model, inplace=True)
+        torch.ao.quantization.convert(model, inplace=True)
         model.eval()
     else:
         if pretrained:

torchvision/models/quantization/resnet.py (+6 -6)

@@ -3,7 +3,7 @@
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.quantization import fuse_modules
+from torch.ao.quantization import fuse_modules
 from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
 
 from ..._internally_replaced_utils import load_state_dict_from_url
@@ -42,9 +42,9 @@ def forward(self, x: Tensor) -> Tensor:
         return out
 
     def fuse_model(self) -> None:
-        torch.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
+        torch.ao.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
         if self.downsample:
-            torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+            torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
 
 
 class QuantizableBottleneck(Bottleneck):
@@ -75,15 +75,15 @@ def forward(self, x: Tensor) -> Tensor:
     def fuse_model(self) -> None:
         fuse_modules(self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True)
         if self.downsample:
-            torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+            torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
 
 
 class QuantizableResNet(ResNet):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
 
-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()
 
     def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
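Across these model files the pattern is the same: QuantStub/DeQuantStub bracket the forward pass and fuse_modules merges conv-bn(-relu) groups, all now imported from torch.ao.quantization. A minimal, self-contained sketch of that pattern follows; the TinyQuantizable module below is hypothetical and not part of torchvision.

```python
# Hypothetical minimal module showing the QuantStub/DeQuantStub + fuse_modules
# pattern used by the quantizable models (illustrative only).
import torch
from torch import nn
from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules


class TinyQuantizable(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.quant = QuantStub()      # float -> quantized at the input
        self.dequant = DeQuantStub()  # quantized -> float at the output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        return self.dequant(x)

    def fuse_model(self) -> None:
        # merge conv + bn + relu into a single fused module before quantization
        fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
```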

torchvision/models/quantization/shufflenetv2.py (+5 -5)

@@ -41,8 +41,8 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
     # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs)  # type: ignore[misc]
-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()
 
     def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
@@ -60,12 +60,12 @@ def fuse_model(self) -> None:
 
         for name, m in self._modules.items():
             if name in ["conv1", "conv5"]:
-                torch.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
+                torch.ao.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
         for m in self.modules():
             if type(m) is QuantizableInvertedResidual:
                 if len(m.branch1._modules.items()) > 0:
-                    torch.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
-                torch.quantization.fuse_modules(
+                    torch.ao.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
+                torch.ao.quantization.fuse_modules(
                     m.branch2,
                     [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],
                     inplace=True,

torchvision/models/quantization/utils.py (+7 -7)

@@ -24,19 +24,19 @@ def quantize_model(model: nn.Module, backend: str) -> None:
     model.eval()
     # Make sure that weight qconfig matches that of the serialized models
     if backend == "fbgemm":
-        model.qconfig = torch.quantization.QConfig(  # type: ignore[assignment]
-            activation=torch.quantization.default_observer,
-            weight=torch.quantization.default_per_channel_weight_observer,
+        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
+            activation=torch.ao.quantization.default_observer,
+            weight=torch.ao.quantization.default_per_channel_weight_observer,
         )
     elif backend == "qnnpack":
-        model.qconfig = torch.quantization.QConfig(  # type: ignore[assignment]
-            activation=torch.quantization.default_observer, weight=torch.quantization.default_weight_observer
+        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
+            activation=torch.ao.quantization.default_observer, weight=torch.ao.quantization.default_weight_observer
         )
 
     # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
     model.fuse_model()  # type: ignore[operator]
-    torch.quantization.prepare(model, inplace=True)
+    torch.ao.quantization.prepare(model, inplace=True)
     model(_dummy_input_data)
-    torch.quantization.convert(model, inplace=True)
+    torch.ao.quantization.convert(model, inplace=True)
 
     return

torchvision/transforms/functional_tensor.py (+8)

@@ -16,6 +16,12 @@ def _assert_image_tensor(img: Tensor) -> None:
         raise TypeError("Tensor is not a torch image.")
 
 
+def _assert_threshold(img: Tensor, threshold: float) -> None:
+    bound = 1 if img.is_floating_point() else 255
+    if threshold > bound:
+        raise TypeError("Threshold should be less than bound of img.")
+
+
 def get_image_size(img: Tensor) -> List[int]:
     # Returns (w, h) of tensor image
     _assert_image_tensor(img)
@@ -882,6 +888,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
 
     _assert_channels(img, [1, 3])
 
+    _assert_threshold(img, threshold)
+
     inverted_img = invert(img)
     return torch.where(img >= threshold, inverted_img, img)