Commit 6e77d57

5422 update attentionunet parameters (#5423)
Signed-off-by: Wenqi Li <[email protected]>

Fixes #5422

### Description

The kernel size and strides are hard-coded:
https://github.com/Project-MONAI/MONAI/blob/a209b06438343830e561a0afd41b1025516a8977/monai/networks/nets/attentionunet.py#L151

This PR makes these values tunable. The `level` parameter is unused and is removed in this PR.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Wenqi Li <[email protected]>
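For context, a minimal usage sketch of the updated constructor. It mirrors the updated unit test in this commit; the specific values (`spatial_dims=2`, `channels=(3, 4, 5)`, `up_kernel_size=5`, `strides=(1, 2)`) are illustrative only, and the import path assumes `AttentionUnet` is exported from `monai.networks.nets`:

```python
import torch
from monai.networks.nets import AttentionUnet

# With this change, the transposed-convolution kernel size (`up_kernel_size`)
# and the per-level `strides` are constructor arguments instead of being
# hard-coded to kernel_size=3 and strides=2.
model = AttentionUnet(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
    channels=(3, 4, 5),
    strides=(1, 2),
    up_kernel_size=5,
)

x = torch.rand(3, 1, 92, 92)
y = model(x)
print(y.shape)  # expected: torch.Size([3, 2, 92, 92]) -- spatial size preserved, as asserted in the test
```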
1 parent a209b06 commit 6e77d57

File tree: 2 files changed, +28 −9

monai/networks/nets/attentionunet.py

Lines changed: 27 additions & 8 deletions
@@ -143,12 +143,27 @@ def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:


 class AttentionLayer(nn.Module):
-    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0):
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        submodule: nn.Module,
+        up_kernel_size=3,
+        strides=2,
+        dropout=0.0,
+    ):
         super().__init__()
         self.attention = AttentionBlock(
             spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2
         )
-        self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2)
+        self.upconv = UpConv(
+            spatial_dims=spatial_dims,
+            in_channels=out_channels,
+            out_channels=in_channels,
+            strides=strides,
+            kernel_size=up_kernel_size,
+        )
         self.merge = Convolution(
             spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout
         )
@@ -174,7 +189,7 @@ class AttentionUnet(nn.Module):
         channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2.
         strides (Sequence[int]): stride to use for convolutions.
         kernel_size: convolution kernel size.
-        upsample_kernel_size: convolution kernel size for transposed convolution layers.
+        up_kernel_size: convolution kernel size for transposed convolution layers.
         dropout: dropout ratio. Defaults to no dropout.
     """

@@ -210,9 +225,9 @@ def __init__(
         )
         self.up_kernel_size = up_kernel_size

-        def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = 0) -> nn.Module:
+        def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:
             if len(channels) > 2:
-                subblock = _create_block(channels[1:], strides[1:], level=level + 1)
+                subblock = _create_block(channels[1:], strides[1:])
                 return AttentionLayer(
                     spatial_dims=spatial_dims,
                     in_channels=channels[0],
@@ -227,17 +242,19 @@ def _create_block(channels: Sequence[int], strides: Sequence[int], level: int =
                         ),
                         subblock,
                     ),
+                    up_kernel_size=self.up_kernel_size,
+                    strides=strides[0],
                     dropout=dropout,
                 )
             else:
                 # the next layer is the bottom so stop recursion,
-                # create the bottom layer as the sublock for this layer
-                return self._get_bottom_layer(channels[0], channels[1], strides[0], level=level + 1)
+                # create the bottom layer as the subblock for this layer
+                return self._get_bottom_layer(channels[0], channels[1], strides[0])

         encdec = _create_block(self.channels, self.strides)
         self.model = nn.Sequential(head, encdec, reduce_channels)

-    def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, level: int) -> nn.Module:
+    def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -> nn.Module:
         return AttentionLayer(
             spatial_dims=self.dimensions,
             in_channels=in_channels,
@@ -249,6 +266,8 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, l
                 strides=strides,
                 dropout=self.dropout,
             ),
+            up_kernel_size=self.up_kernel_size,
+            strides=strides,
             dropout=self.dropout,
         )
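To make the parameter flow concrete: after this change, `up_kernel_size` is shared by every level, while each nested `AttentionLayer` receives the stride for its own level (`strides[0]` of the remaining sequence, consumed as the recursion descends). A minimal standalone sketch of that pairing; the `pair_levels` helper is purely illustrative and not part of MONAI:

```python
from typing import List, Sequence, Tuple


def pair_levels(channels: Sequence[int], strides: Sequence[int]) -> List[Tuple[int, int, int]]:
    """Illustrative only: list the (in_channels, out_channels, stride) triple that each
    nested AttentionLayer would receive, mirroring the updated _create_block recursion."""
    levels = []
    chans, strds = list(channels), list(strides)
    while len(chans) > 2:
        # outer levels: AttentionLayer(in=channels[0], out=channels[1], strides=strides[0])
        levels.append((chans[0], chans[1], strds[0]))
        chans, strds = chans[1:], strds[1:]
    # bottom of the recursion: _get_bottom_layer(channels[0], channels[1], strides[0])
    levels.append((chans[0], chans[1], strds[0]))
    return levels


print(pair_levels((3, 4, 5), (1, 2)))  # [(3, 4, 1), (4, 5, 2)]
```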

tests/test_attentionunet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_attentionunet(self):
3939
shape = (3, 1) + (92,) * dims
4040
input = torch.rand(*shape)
4141
model = att.AttentionUnet(
42-
spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2)
42+
spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), up_kernel_size=5, strides=(1, 2)
4343
)
4444
output = model(input)
4545
self.assertEqual(output.shape[2:], input.shape[2:])
