diff --git a/tensorlayer/layers/normalization.py b/tensorlayer/layers/normalization.py index 7df44c2e4..161d6e018 100644 --- a/tensorlayer/layers/normalization.py +++ b/tensorlayer/layers/normalization.py @@ -226,7 +226,6 @@ def __init__( self.moving_var_init = moving_var_init self.num_features = num_features - self.channel_axis = -1 if data_format == 'channels_last' else 1 self.axes = None if num_features is not None: @@ -288,6 +287,7 @@ def build(self, inputs_shape): def forward(self, inputs): self._check_input_shape(inputs) + self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1 if self.axes is None: self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]