diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 6c9b9006e5d..071511d0c40 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -138,14 +138,16 @@ def forward(self, x): class QuantizableInceptionE(inception_module.InceptionE): def __init__(self, *args, **kwargs): super(QuantizableInceptionE, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs) - self.myop = nn.quantized.FloatFunctional() + self.myop1 = nn.quantized.FloatFunctional() + self.myop2 = nn.quantized.FloatFunctional() + self.myop3 = nn.quantized.FloatFunctional() def _forward(self, x): branch1x1 = self.branch1x1(x) branch3x3 = self.branch3x3_1(x) branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)] - branch3x3 = self.myop.cat(branch3x3, 1) + branch3x3 = self.myop1.cat(branch3x3, 1) branch3x3dbl = self.branch3x3dbl_1(x) branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) @@ -153,7 +155,7 @@ def _forward(self, x): self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl), ] - branch3x3dbl = self.myop.cat(branch3x3dbl, 1) + branch3x3dbl = self.myop2.cat(branch3x3dbl, 1) branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) branch_pool = self.branch_pool(branch_pool) @@ -163,7 +165,7 @@ def _forward(self, x): def forward(self, x): outputs = self._forward(x) - return self.myop.cat(outputs, 1) + return self.myop3.cat(outputs, 1) class QuantizableInceptionAux(inception_module.InceptionAux):