From 5de71ea2a4a5108289e8869cce8e11cf3ea7e869 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Mon, 17 Oct 2016 20:43:03 -0700 Subject: [PATCH] Assign child modules via attributes --- imagenet/resnet.py | 62 ++++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/imagenet/resnet.py b/imagenet/resnet.py index d5827db878..d647dd2c28 100644 --- a/imagenet/resnet.py +++ b/imagenet/resnet.py @@ -12,14 +12,13 @@ class BasicBlock(nn.Container): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): - super(BasicBlock, self).__init__( - conv1=conv3x3(inplanes, planes, stride), - bn1=nn.BatchNorm2d(planes), - relu=nn.ReLU(inplace=True), - conv2=conv3x3(planes, planes), - bn2=nn.BatchNorm2d(planes), - downsample=downsample, - ) + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample self.stride = stride def forward(self, x): @@ -45,17 +44,16 @@ class Bottleneck(nn.Container): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None): - super(Bottleneck, self).__init__( - conv1=nn.Conv2d(inplanes, planes, kernel_size=1, bias=False), - bn1=nn.BatchNorm2d(planes), - conv2=nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False), - bn2=nn.BatchNorm2d(planes), - conv3=nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False), - bn3=nn.BatchNorm2d(planes * 4), - relu=nn.ReLU(inplace=True), - downsample=downsample, - ) + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample self.stride = stride def forward(self, x): @@ -84,19 +82,19 @@ def forward(self, x): class ResNet(nn.Container): def __init__(self, block, layers): self.inplanes = 64 - super(ResNet, self).__init__( - conv1=nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, - bias=False), - bn1=nn.BatchNorm2d(64), - relu=nn.ReLU(inplace=True), - maxpool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1), - layer1=self._make_layer(block, 64, layers[0]), - layer2=self._make_layer(block, 128, layers[1], stride=2), - layer3=self._make_layer(block, 256, layers[2], stride=2), - layer4=self._make_layer(block, 512, layers[3], stride=2), - avgpool=nn.AvgPool2d(7), - fc=nn.Linear(512 * block.expansion, 1000), - ) + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7) + self.fc = nn.Linear(512 * block.expansion, 1000) + for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kw * m.kh * m.out_channels