@@ -55,15 +55,14 @@ def __init__(self, in_planes, planes, stride=1):
55
55
self .bn2 = nn .BatchNorm2d (planes )
56
56
self .conv2 = conv3x3 (planes , planes )
57
57
58
- self .shortcut = nn .Sequential ()
59
58
if stride != 1 or in_planes != self .expansion * planes :
60
59
self .shortcut = nn .Sequential (
61
60
nn .Conv2d (in_planes , self .expansion * planes , kernel_size = 1 , stride = stride , bias = False )
62
61
)
63
62
64
63
def forward (self , x ):
65
64
out = F .relu (self .bn1 (x ))
66
- shortcut = self .shortcut (out )
65
+ shortcut = self .shortcut (out ) if hasattr ( self , 'shortcut' ) else x
67
66
out = self .conv1 (out )
68
67
out = self .conv2 (F .relu (self .bn2 (out )))
69
68
out += shortcut
@@ -111,15 +110,14 @@ def __init__(self, in_planes, planes, stride=1):
111
110
self .bn3 = nn .BatchNorm2d (planes )
112
111
self .conv3 = nn .Conv2d (planes , self .expansion * planes , kernel_size = 1 , bias = False )
113
112
114
- self .shortcut = nn .Sequential ()
115
113
if stride != 1 or in_planes != self .expansion * planes :
116
114
self .shortcut = nn .Sequential (
117
115
nn .Conv2d (in_planes , self .expansion * planes , kernel_size = 1 , stride = stride , bias = False )
118
116
)
119
117
120
118
def forward (self , x ):
121
119
out = F .relu (self .bn1 (x ))
122
- shortcut = self .shortcut (out )
120
+ shortcut = self .shortcut (out ) if hasattr ( self , 'shortcut' ) else x
123
121
out = self .conv1 (out )
124
122
out = self .conv2 (F .relu (self .bn2 (out )))
125
123
out = self .conv3 (F .relu (self .bn3 (out )))
0 commit comments