diff --git a/tensorlayer/layers/normalization.py b/tensorlayer/layers/normalization.py index 39db51159..dd7cc2189 100644 --- a/tensorlayer/layers/normalization.py +++ b/tensorlayer/layers/normalization.py @@ -291,9 +291,9 @@ def forward(self, inputs): if self.axes is None: self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis] - mean, var = tf.nn.moments(inputs, self.axes, keepdims=False) if self.is_train: # update moving_mean and moving_var + mean, var = tf.nn.moments(inputs, self.axes, keepdims=False) self.moving_mean = moving_averages.assign_moving_average( self.moving_mean, mean, self.decay, zero_debias=False )