We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6c12921 commit 2bc1e81Copy full SHA for 2bc1e81
torchvision/ops/_utils.py
@@ -3,6 +3,8 @@
3
import torch
4
from torch import nn, Tensor
5
6
+from .misc import FrozenBatchNorm2d
7
+
8
9
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
10
"""
@@ -43,7 +45,13 @@ def split_normalization_params(
43
45
) -> Tuple[List[Tensor], List[Tensor]]:
44
46
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
47
if not norm_classes:
- norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
48
+ norm_classes = [
49
+ nn.modules.batchnorm._BatchNorm,
50
+ nn.LayerNorm,
51
+ nn.GroupNorm,
52
+ nn.modules.instancenorm._InstanceNorm,
53
+ nn.LocalResponseNorm,
54
+ ]
55
56
for t in norm_classes:
57
if not issubclass(t, nn.Module):
0 commit comments