Skip to content

Commit 2bc1e81

Browse files
committed
Adding more norm layers in split_normalization_params.
1 parent 6c12921 commit 2bc1e81

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

torchvision/ops/_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import torch
44
from torch import nn, Tensor
55

6+
from .misc import FrozenBatchNorm2d
7+
68

79
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
810
"""
@@ -43,7 +45,13 @@ def split_normalization_params(
4345
) -> Tuple[List[Tensor], List[Tensor]]:
4446
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
4547
if not norm_classes:
46-
norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
48+
norm_classes = [
49+
nn.modules.batchnorm._BatchNorm,
50+
nn.LayerNorm,
51+
nn.GroupNorm,
52+
nn.modules.instancenorm._InstanceNorm,
53+
nn.LocalResponseNorm,
54+
]
4755

4856
for t in norm_classes:
4957
if not issubclass(t, nn.Module):

0 commit comments

Comments
 (0)