Skip to content

Commit 3cee95e

Browse files
authored
Merge pull request #5 from colesbury/setattr
Assign child modules via attributes
2 parents c0079d1 + 5de71ea commit 3cee95e

File tree

1 file changed

+30
-32
lines changed

1 file changed

+30
-32
lines changed

imagenet/resnet.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@ class BasicBlock(nn.Container):
1212
expansion = 1
1313

1414
def __init__(self, inplanes, planes, stride=1, downsample=None):
15-
super(BasicBlock, self).__init__(
16-
conv1=conv3x3(inplanes, planes, stride),
17-
bn1=nn.BatchNorm2d(planes),
18-
relu=nn.ReLU(inplace=True),
19-
conv2=conv3x3(planes, planes),
20-
bn2=nn.BatchNorm2d(planes),
21-
downsample=downsample,
22-
)
15+
super(BasicBlock, self).__init__()
16+
self.conv1 = conv3x3(inplanes, planes, stride)
17+
self.bn1 = nn.BatchNorm2d(planes)
18+
self.relu = nn.ReLU(inplace=True)
19+
self.conv2 = conv3x3(planes, planes)
20+
self.bn2 = nn.BatchNorm2d(planes)
21+
self.downsample = downsample
2322
self.stride = stride
2423

2524
def forward(self, x):
@@ -45,17 +44,16 @@ class Bottleneck(nn.Container):
4544
expansion = 4
4645

4746
def __init__(self, inplanes, planes, stride=1, downsample=None):
48-
super(Bottleneck, self).__init__(
49-
conv1=nn.Conv2d(inplanes, planes, kernel_size=1, bias=False),
50-
bn1=nn.BatchNorm2d(planes),
51-
conv2=nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
52-
padding=1, bias=False),
53-
bn2=nn.BatchNorm2d(planes),
54-
conv3=nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False),
55-
bn3=nn.BatchNorm2d(planes * 4),
56-
relu=nn.ReLU(inplace=True),
57-
downsample=downsample,
58-
)
47+
super(Bottleneck, self).__init__()
48+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
49+
self.bn1 = nn.BatchNorm2d(planes)
50+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
51+
padding=1, bias=False)
52+
self.bn2 = nn.BatchNorm2d(planes)
53+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
54+
self.bn3 = nn.BatchNorm2d(planes * 4)
55+
self.relu = nn.ReLU(inplace=True)
56+
self.downsample = downsample
5957
self.stride = stride
6058

6159
def forward(self, x):
@@ -84,19 +82,19 @@ def forward(self, x):
8482
class ResNet(nn.Container):
8583
def __init__(self, block, layers):
8684
self.inplanes = 64
87-
super(ResNet, self).__init__(
88-
conv1=nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
89-
bias=False),
90-
bn1=nn.BatchNorm2d(64),
91-
relu=nn.ReLU(inplace=True),
92-
maxpool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
93-
layer1=self._make_layer(block, 64, layers[0]),
94-
layer2=self._make_layer(block, 128, layers[1], stride=2),
95-
layer3=self._make_layer(block, 256, layers[2], stride=2),
96-
layer4=self._make_layer(block, 512, layers[3], stride=2),
97-
avgpool=nn.AvgPool2d(7),
98-
fc=nn.Linear(512 * block.expansion, 1000),
99-
)
85+
super(ResNet, self).__init__()
86+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
87+
bias=False)
88+
self.bn1 = nn.BatchNorm2d(64)
89+
self.relu = nn.ReLU(inplace=True)
90+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
91+
self.layer1 = self._make_layer(block, 64, layers[0])
92+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
93+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
94+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
95+
self.avgpool = nn.AvgPool2d(7)
96+
self.fc = nn.Linear(512 * block.expansion, 1000)
97+
10098
for m in self.modules():
10199
if isinstance(m, nn.Conv2d):
102100
n = m.kw * m.kh * m.out_channels

0 commit comments

Comments
 (0)