@@ -1,13 +1,3 @@
-"""
-helper class that supports empty tensors on some nn functions.
-
-Ideally, add support directly in PyTorch to empty tensors in
-those functions.
-
-This can be removed once https://github.com/pytorch/pytorch/issues/12013
-is implemented
-"""
-
 import warnings
 from typing import Callable, List, Optional
 
@@ -53,8 +43,11 @@ def __init__(self, *args, **kwargs):
 # This is not in nn
 class FrozenBatchNorm2d(torch.nn.Module):
     """
-    BatchNorm2d where the batch statistics and the affine parameters
-    are fixed
+    BatchNorm2d where the batch statistics and the affine parameters are fixed
+
+    Args:
+        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
+        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
     """
 
     def __init__(
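
A minimal usage sketch of the frozen layer documented above (assuming a recent torchvision where the class lives in `torchvision.ops.misc`; the channel count and input shape are illustrative):

```python
import torch
from torchvision.ops.misc import FrozenBatchNorm2d

# 16-channel frozen batch norm: weight, bias, running_mean and running_var
# are registered as buffers, so the layer has no trainable parameters.
bn = FrozenBatchNorm2d(16)

x = torch.randn(2, 16, 32, 32)
out = bn(x)  # same normalization as an eval-mode BatchNorm2d with fixed statistics
print(out.shape)                                       # torch.Size([2, 16, 32, 32])
print(any(p.requires_grad for p in bn.parameters()))   # False: buffers only, nothing to train
```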
@@ -109,6 +102,23 @@ def __repr__(self) -> str:
 
 
 class ConvNormActivation(torch.nn.Sequential):
+    """
+    Configurable block used for Convolution-Normalization-Activation blocks.
+
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
+        kernel_size (int, optional): Size of the convolving kernel. Default: 3
+        stride (int, optional): Stride of the convolution. Default: 1
+        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
+        dilation (int): Spacing between kernel elements. Default: 1
+        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default: ``True``
+
+    """
+
     def __init__(
         self,
         in_channels: int,
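
A rough sketch of how the block above is meant to be used (assuming the import path `torchvision.ops.misc.ConvNormActivation`; shapes and channel counts are made up). With the defaults it stacks Conv2d, BatchNorm2d and ReLU, and the auto-computed padding keeps the spatial size when ``stride=1``:

```python
import torch
from torchvision.ops.misc import ConvNormActivation

# Defaults: 3x3 conv, BatchNorm2d, ReLU(inplace=True); padding = (3 - 1) // 2 = 1.
block = ConvNormActivation(in_channels=3, out_channels=64, kernel_size=3, stride=2)

x = torch.randn(1, 3, 224, 224)
print(block(x).shape)  # torch.Size([1, 64, 112, 112]) because of stride=2

# Per the docstring, passing None drops the corresponding layer,
# e.g. a plain conv + ReLU without normalization:
conv_relu = ConvNormActivation(64, 128, norm_layer=None)
```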
@@ -146,6 +156,17 @@ def __init__(
 
 
 class SqueezeExcitation(torch.nn.Module):
+    """
+    This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
+    Parameters ``activation`` and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.
+
+    Args:
+        input_channels (int): Number of channels in the input image
+        squeeze_channels (int): Number of squeeze channels
+        activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
+        scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
+    """
+
     def __init__(
         self,
         input_channels: int,
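
A small sketch of the squeeze-and-excitation block documented above (assuming the import path `torchvision.ops.misc.SqueezeExcitation`; the channel counts are illustrative). The block computes a per-channel scale in (0, 1) and multiplies the input by it, so the output shape matches the input:

```python
import torch
from torchvision.ops.misc import SqueezeExcitation

# Squeeze 64 input channels down to a 16-channel bottleneck, then re-expand
# to 64 per-channel scales via the sigma (Sigmoid) activation.
se = SqueezeExcitation(input_channels=64, squeeze_channels=16)

x = torch.randn(8, 64, 56, 56)
out = se(x)
print(out.shape)  # torch.Size([8, 64, 56, 56]): same shape, channels rescaled
```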