From 30fc364d9850c6d9773136d14d52398302595394 Mon Sep 17 00:00:00 2001
From: Jakub Gajski <78765372+szperajacyzolw@users.noreply.github.com>
Date: Tue, 22 Jun 2021 20:45:00 +0200
Subject: [PATCH] inception: expose flatten as an overridable nn.Flatten module

Insertion at line 119: self.flatten = nn.Flatten()
Change at line 191: x = torch.flatten(x, 1) -> x = self.flatten(x)

This change makes the flattening step before the built-in dense classifier
overridable, enabling non-dense custom processing heads (e.g. pseudo-embedders
that inject features into transformers for image captioning). Previously the
flattening was hard-coded and inaccessible, forcing users to un-flatten the
output afterwards, which is inconvenient.
---
 torchvision/models/inception.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py
index b9c6ab74534..e7cd8b16c83 100644
--- a/torchvision/models/inception.py
+++ b/torchvision/models/inception.py
@@ -116,6 +116,7 @@ def __init__(
         self.Mixed_7c = inception_e(2048)
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.dropout = nn.Dropout()
+        self.flatten = nn.Flatten()
         self.fc = nn.Linear(2048, num_classes)
         if init_weights:
             for m in self.modules():
@@ -187,7 +188,7 @@ def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
         # N x 2048 x 1 x 1
         x = self.dropout(x)
         # N x 2048 x 1 x 1
-        x = torch.flatten(x, 1)
+        x = self.flatten(x)  # nn.Flatten module, so this stage can be overridden
         # N x 2048
         x = self.fc(x)
         # N x 1000 (num_classes)
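
Illustrative sketch (not part of the patch): assuming the patch is applied, one
possible way to use the now-exposed flatten stage is to swap out pooling,
flatten, and the dense classifier to obtain sequence-like features for a
transformer-style head. The specific replacements below are an example
configuration, not an API defined by torchvision.

    import torch
    import torch.nn as nn
    from torchvision.models import inception_v3

    model = inception_v3(aux_logits=False)
    model.eval()  # disable dropout for a deterministic feature pass

    model.avgpool = nn.Identity()   # keep the spatial grid: N x 2048 x 8 x 8
    model.flatten = nn.Flatten(2)   # flatten only spatial dims: N x 2048 x 64
    model.fc = nn.Identity()        # drop the dense classifier

    x = torch.randn(2, 3, 299, 299)
    feats = model(x)                # N x 2048 x 64
    tokens = feats.permute(0, 2, 1) # N x 64 x 2048, one token per spatial cell

Before this change, torch.flatten(x, 1) was baked into _forward, so the only
hook points were avgpool and fc; the flatten step itself could not be replaced.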