
Added support for arbitrary tensors for FRN. #1496


Closed
wants to merge 12 commits
83 changes: 49 additions & 34 deletions tensorflow_addons/layers/normalizations.py
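
For orientation, a sketch of the usage this change targets. Note that `channel_idx` is introduced by this diff and is not part of the released `tensorflow_addons` API (the PR was closed without being merged), so treat the example as illustrative only:

```python
import tensorflow as tf
from tensorflow_addons.layers import FilterResponseNormalization

# Previous behavior: only 4-D NHWC inputs, with axis fixed to [1, 2].
frn_2d = FilterResponseNormalization(axis=[1, 2])
print(frn_2d(tf.random.normal([2, 32, 32, 16])).shape)  # (2, 32, 32, 16)

# With this change: arbitrary-rank inputs, e.g. a 5-D NDHWC volume,
# with the spatial axes and the channel axis given explicitly.
frn_3d = FilterResponseNormalization(axis=[1, 2, 3], channel_idx=-1)
print(frn_3d(tf.random.normal([2, 8, 32, 32, 16])).shape)  # (2, 8, 32, 32, 16)
```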
@@ -18,6 +18,7 @@
 import logging
 import tensorflow as tf
 from typeguard import typechecked
+from typing import Union
 
 from tensorflow_addons.utils import types

@@ -242,8 +243,8 @@ def _check_axis(self):
 
         if self.axis == 0:
             raise ValueError(
-                "You are trying to normalize your batch axis. Do you want to "
-                "use tf.layer.batch_normalization instead"
+                "You are trying to normalize your batch axis. "
+                "Use tf.layers.batch_normalization instead."
             )
 
     def _create_input_spec(self, input_shape):
@@ -354,6 +355,7 @@ class FilterResponseNormalization(tf.keras.layers.Layer):
     Arguments
         axis: List of axes that should be normalized. This should represent the
             spatial dimensions.
+        channel_idx: Index of the channel axis in `input_shape`.
         epsilon: Small positive float value added to variance to avoid dividing by zero.
         beta_initializer: Initializer for the beta weight.
         gamma_initializer: Initializer for the gamma weight.
@@ -368,10 +370,10 @@ class FilterResponseNormalization(tf.keras.layers.Layer):
     Input shape
         Arbitrary. Use the keyword argument `input_shape`
         (tuple of integers, does not include the samples axis)
-        when using this layer as the first layer in a model. This layer, as of now,
-        works on a 4-D tensor where the tensor should have the shape [N X H X W X C]
-
-        TODO: Add support for NCHW data format and FC layers.
+        when using this layer as the first layer in a model. This layer supports
+        arbitrary tensors with the following assumptions:
+        - Expected input tensor to be at least 3D.
+        - 0th index in tensor shape is expected to be the batch dimension.
 
     Output shape
         Same shape as input.
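
For context, Filter Response Normalization (Singh & Krishnan, 2019) divides each channel by its mean squared activation over the spatial axes; unlike batch normalization there is no mean subtraction and no dependence on the batch. A reference sketch of the computation (names are illustrative, not taken from this diff):

```python
import tensorflow as tf

def frn_reference(x, gamma, beta, spatial_axes=(1, 2), epsilon=1e-6):
    # nu2: mean squared activation per channel over the spatial axes.
    nu2 = tf.reduce_mean(tf.square(x), axis=spatial_axes, keepdims=True)
    # Normalize, then apply the learned per-channel scale and offset.
    x = x * tf.math.rsqrt(nu2 + epsilon)
    return gamma * x + beta
```

Generalizing `spatial_axes` beyond the fixed `[1, 2]` is the point of this diff.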
@@ -385,7 +387,8 @@ class FilterResponseNormalization(tf.keras.layers.Layer):
     def __init__(
         self,
         epsilon: float = 1e-6,
-        axis: list = [1, 2],
+        axis: Union[int, list] = [1, 2],
+        channel_idx: int = -1,
         beta_initializer: types.Initializer = "zeros",
         gamma_initializer: types.Initializer = "ones",
         beta_regularizer: types.Regularizer = None,
@@ -398,6 +401,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__(name=name, **kwargs)
+        self.channel_idx = channel_idx
         self.epsilon = epsilon
         self.beta_initializer = tf.keras.initializers.get(beta_initializer)
         self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
@@ -428,11 +432,6 @@ def __init__(
         self._check_axis(axis)
 
     def build(self, input_shape):
-        if len(tf.TensorShape(input_shape)) != 4:
-            raise ValueError(
-                """Only 4-D tensors (CNNs) are supported
-                as of now."""
-            )
         self._check_if_input_shape_is_none(input_shape)
         self._create_input_spec(input_shape)
         self._add_gamma_weight(input_shape)
@@ -453,6 +452,7 @@ def compute_output_shape(self, input_shape):
     def get_config(self):
         config = {
             "axis": self.axis,
+            "channel_idx": self.channel_idx,
             "epsilon": self.epsilon,
             "learned_epsilon": self.use_eps_learned,
             "beta_initializer": tf.keras.initializers.serialize(self.beta_initializer),
@@ -473,39 +473,52 @@ def get_config(self):
         return dict(**base_config, **config)
 
     def _create_input_spec(self, input_shape):
-        ndims = len(tf.TensorShape(input_shape))
+        ndims = len(input_shape)
+
         for idx, x in enumerate(self.axis):
             if x < 0:
                 self.axis[idx] = ndims + x
 
         # Validate axes
         for x in self.axis:
-            if x < 0 or x >= ndims:
+            if x < 0 or x >= len(input_shape):
                 raise ValueError("Invalid axis: %d" % x)
 
-        if len(self.axis) != len(set(self.axis)):
-            raise ValueError("Duplicate axis: %s" % self.axis)
+            elif x == 0:
+                raise ValueError(
+                    "You are trying to normalize your batch axis. "
+                    "Use tf.layers.batch_normalization instead."
+                )
+
+            elif x == self.channel_idx:
+                raise ValueError(
+                    "You are trying to normalize over your channel axis. Expected spatial dimensions."
+                )
 
         axis_to_dim = {x: input_shape[x] for x in self.axis}
         self.input_spec = tf.keras.layers.InputSpec(ndim=ndims, axes=axis_to_dim)
 
     def _check_axis(self, axis):
-        if not isinstance(axis, list):
-            raise TypeError(
-                """Expected a list of values but got {}.""".format(type(axis))
-            )
-        else:
+        if isinstance(axis, list):
             self.axis = axis
 
-        if self.axis != [1, 2]:
-            raise ValueError(
-                """FilterResponseNormalization operates on per-channel basis.
-                Axis values should be a list of spatial dimensions."""
+        elif isinstance(axis, int):
+            self.axis = [axis]
+
+        else:
+            raise TypeError(
+                """Expected a list of values or int but got {}.""".format(type(axis))
             )
 
+        if len(self.axis) != len(set(self.axis)):
+            raise ValueError("Duplicate axis: %s" % self.axis)
+
     def _check_if_input_shape_is_none(self, input_shape):
-        dim1, dim2 = input_shape[self.axis[0]], input_shape[self.axis[1]]
-        if dim1 is None or dim2 is None:
+        dims = [input_shape[i] for i in self.axis]
+
+        if len(input_shape) < 3:
+            raise ValueError("Expected input tensor to be at least 3D.")
+
+        if None in dims:
             raise ValueError(
                 """Axis {} of input tensor should have a defined dimension but
                 the layer received an input with shape {}.""".format(
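
A standalone sketch of the validation flow above, using hypothetical values rather than code from the diff:

```python
input_shape = (None, 8, 32, 32, 16)  # batch, spatial dims, channels
axis, channel_idx = [1, 2, -2], 4

ndims = len(input_shape)
axis = [ndims + x if x < 0 else x for x in axis]  # resolve negatives -> [1, 2, 3]

for x in axis:
    assert 0 < x < ndims       # the batch axis (0) and out-of-range axes are rejected
    assert x != channel_idx    # normalizing over the channel axis is rejected
assert len(axis) == len(set(axis))  # duplicate axes are rejected
```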
@@ -515,9 +528,10 @@ def _check_if_input_shape_is_none(self, input_shape):
 
     def _add_gamma_weight(self, input_shape):
         # Get the channel dimension
-        dim = input_shape[-1]
-        shape = [1, 1, 1, dim]
-        # Initialize gamma with shape (1, 1, 1, C)
+        dim = input_shape[self.channel_idx]
+        shape = [1] * len(input_shape)
+        shape[self.channel_idx] = dim
+        # Gamma's shape is all ones except at the channel axis
         self.gamma = self.add_weight(
             shape=shape,
             name="gamma",
@@ -529,9 +543,10 @@
 
     def _add_beta_weight(self, input_shape):
         # Get the channel dimension
-        dim = input_shape[-1]
-        shape = [1, 1, 1, dim]
-        # Initialize beta with shape (1, 1, 1, C)
+        dim = input_shape[self.channel_idx]
+        shape = [1] * len(input_shape)
+        shape[self.channel_idx] = dim
+        # Beta's shape is all ones except at the channel axis
         self.beta = self.add_weight(
             shape=shape,
             name="beta",