Skip to content

Commit bd27e94

Browse files
authored
Making ASPP-Layer in DeepLab more generic (#2174)
At the moment in the ASPP-Layer the number of output channels are predefined as a constant, which is good for DeepLab but not necessairly in other projects, where another out-channel Nr. is required. Also the number of "atrous rates" is fixed to three, which also could be sometimes more or less depending on the notwork-arch. Again these fixed values may make sense in DeepLab-Model but not necessarily in other type of models. This pull-req. contains the needed changes to make ASPP-Layer generic.
1 parent 1a40d9c commit bd27e94

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

torchvision/models/segmentation/deeplabv3.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,18 @@ def forward(self, x):
6363

6464

6565
class ASPP(nn.Module):
66-
def __init__(self, in_channels, atrous_rates):
66+
def __init__(self, in_channels, atrous_rates, out_channels=256):
6767
super(ASPP, self).__init__()
68-
out_channels = 256
6968
modules = []
7069
modules.append(nn.Sequential(
7170
nn.Conv2d(in_channels, out_channels, 1, bias=False),
7271
nn.BatchNorm2d(out_channels),
7372
nn.ReLU()))
7473

75-
rate1, rate2, rate3 = tuple(atrous_rates)
76-
modules.append(ASPPConv(in_channels, out_channels, rate1))
77-
modules.append(ASPPConv(in_channels, out_channels, rate2))
78-
modules.append(ASPPConv(in_channels, out_channels, rate3))
74+
rates = tuple(atrous_rates)
75+
for rate in rates:
76+
modules.append(ASPPConv(in_channels, out_channels, rate))
77+
7978
modules.append(ASPPPooling(in_channels, out_channels))
8079

8180
self.convs = nn.ModuleList(modules)

0 commit comments

Comments
 (0)