Skip to content

Commit 0acfbc1

Browse files
author
kuangliu
committed
Fix ResNet pre-act block, see issue pytorch#9
1 parent d5c53c0 commit 0acfbc1

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

models/resnet.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,14 @@ def __init__(self, in_planes, planes, stride=1):
5555
self.bn2 = nn.BatchNorm2d(planes)
5656
self.conv2 = conv3x3(planes, planes)
5757

58-
self.shortcut = nn.Sequential()
5958
if stride != 1 or in_planes != self.expansion*planes:
6059
self.shortcut = nn.Sequential(
6160
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
6261
)
6362

6463
def forward(self, x):
6564
out = F.relu(self.bn1(x))
66-
shortcut = self.shortcut(out)
65+
shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
6766
out = self.conv1(out)
6867
out = self.conv2(F.relu(self.bn2(out)))
6968
out += shortcut
@@ -111,15 +110,14 @@ def __init__(self, in_planes, planes, stride=1):
111110
self.bn3 = nn.BatchNorm2d(planes)
112111
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
113112

114-
self.shortcut = nn.Sequential()
115113
if stride != 1 or in_planes != self.expansion*planes:
116114
self.shortcut = nn.Sequential(
117115
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
118116
)
119117

120118
def forward(self, x):
121119
out = F.relu(self.bn1(x))
122-
shortcut = self.shortcut(out)
120+
shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
123121
out = self.conv1(out)
124122
out = self.conv2(F.relu(self.bn2(out)))
125123
out = self.conv3(F.relu(self.bn3(out)))

0 commit comments

Comments
 (0)