From 00c3efe925373ac6d46b71083f4651d17c3d7198 Mon Sep 17 00:00:00 2001 From: ajava Date: Sat, 2 May 2020 17:13:29 +0200 Subject: [PATCH] Making ASPP-Layer in DeepLab more generic At the moment in the ASPP-Layer the number of output channels are predefined as a constant, which is good for DeepLab but not necessairly in other projects, where another out-channel Nr. is required. Also the number of "atrous rates" is fixed to three, which also could be sometimes more or less depending on the notwork-arch. Again these fixed values may make sense in DeepLab-Model but not necessarily in other type of models. This pull-req. contains the needed changes to make ASPP-Layer generic. --- torchvision/models/segmentation/deeplabv3.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index ae652cd7d2a..ee5c0c7fe64 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -63,19 +63,18 @@ def forward(self, x): class ASPP(nn.Module): - def __init__(self, in_channels, atrous_rates): + def __init__(self, in_channels, atrous_rates, out_channels=256): super(ASPP, self).__init__() - out_channels = 256 modules = [] modules.append(nn.Sequential( nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU())) - rate1, rate2, rate3 = tuple(atrous_rates) - modules.append(ASPPConv(in_channels, out_channels, rate1)) - modules.append(ASPPConv(in_channels, out_channels, rate2)) - modules.append(ASPPConv(in_channels, out_channels, rate3)) + rates = tuple(atrous_rates) + for rate in rates: + modules.append(ASPPConv(in_channels, out_channels, rate)) + modules.append(ASPPPooling(in_channels, out_channels)) self.convs = nn.ModuleList(modules)