
Commit 7a5b258

lllcho authored and zsdonghao committed
Enable skipping biases in Conv3dLayer, the same as beta and gamma in the BN layer (#421)
* update by lllcho on March 15
* update logging
* b_init in c3d can be None, and gamma/beta in the BN layer can be skipped
* fix some comments
* add comments in the BN layer
1 parent ba71d18 commit 7a5b258

File tree

2 files changed: +52 −30 lines


tensorlayer/layers/convolution.py

Lines changed: 24 additions & 13 deletions
@@ -123,9 +123,9 @@ class Conv2dLayer(Layer):
     padding : str
         The padding algorithm type: "SAME" or "VALID".
     W_init : initializer
-        The initializer for the the weight matrix.
+        The initializer for the weight matrix.
     b_init : initializer or None
-        The initializer for the the bias vector. If None, skip biases.
+        The initializer for the bias vector. If None, skip biases.
     W_init_args : dictionary
         The arguments for the weight matrix initializer.
     b_init_args : dictionary
@@ -358,8 +358,8 @@ class Conv3dLayer(Layer):
         The padding algorithm type: "SAME" or "VALID".
     W_init : initializer
         The initializer for the weight matrix.
-    b_init : initializer
-        The initializer for the bias vector.
+    b_init : initializer or None
+        The initializer for the bias vector. If None, skip biases.
     W_init_args : dictionary
         The arguments for the weight matrix initializer.
     b_init_args : dictionary
@@ -403,16 +403,22 @@ def __init__(
             # W = tf.Variable(W_init(shape=shape, **W_init_args), name='W_conv')
             # b = tf.Variable(b_init(shape=[shape[-1]], **b_init_args), name='b_conv')
             W = tf.get_variable(name='W_conv3d', shape=shape, initializer=W_init, dtype=LayersConfig.tf_dtype, **W_init_args)
-            b = tf.get_variable(name='b_conv3d', shape=(shape[-1]), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args)
-            self.outputs = act(tf.nn.conv3d(self.inputs, W, strides=strides, padding=padding, name=None) + b)
+            if b_init:
+                b = tf.get_variable(name='b_conv3d', shape=(shape[-1]), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args)
+                self.outputs = act(tf.nn.conv3d(self.inputs, W, strides=strides, padding=padding, name=None) + b)
+            else:
+                self.outputs = act(tf.nn.conv3d(self.inputs, W, strides=strides, padding=padding, name=None))

             # self.outputs = act( tf.nn.conv3d(self.inputs, W, strides=strides, padding=padding, name=None) + b )

             # self.all_layers = list(layer.all_layers)
             # self.all_params = list(layer.all_params)
             # self.all_drop = dict(layer.all_drop)
             self.all_layers.append(self.outputs)
-            self.all_params.extend([W, b])
+            if b_init:
+                self.all_params.extend([W, b])
+            else:
+                self.all_params.extend([W])


 class DeConv3dLayer(Layer):
@@ -435,8 +441,8 @@ class DeConv3dLayer(Layer):
         The padding algorithm type: "SAME" or "VALID".
     W_init : initializer
         The initializer for the weight matrix.
-    b_init : initializer
-        The initializer for the bias vector.
+    b_init : initializer or None
+        The initializer for the bias vector. If None, skip biases.
     W_init_args : dictionary
         The arguments for the weight matrix initializer.
     b_init_args : dictionary
@@ -474,15 +480,20 @@ def __init__(

         with tf.variable_scope(name):
             W = tf.get_variable(name='W_deconv3d', shape=shape, initializer=W_init, dtype=LayersConfig.tf_dtype, **W_init_args)
-            b = tf.get_variable(name='b_deconv3d', shape=(shape[-2]), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args)
-
-            self.outputs = act(tf.nn.conv3d_transpose(self.inputs, W, output_shape=output_shape, strides=strides, padding=padding) + b)
+            if b_init:
+                b = tf.get_variable(name='b_deconv3d', shape=(shape[-2]), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args)
+                self.outputs = act(tf.nn.conv3d_transpose(self.inputs, W, output_shape=output_shape, strides=strides, padding=padding) + b)
+            else:
+                self.outputs = act(tf.nn.conv3d_transpose(self.inputs, W, output_shape=output_shape, strides=strides, padding=padding))

             # self.all_layers = list(layer.all_layers)
             # self.all_params = list(layer.all_params)
             # self.all_drop = dict(layer.all_drop)
             self.all_layers.append(self.outputs)
-            self.all_params.extend([W, b])
+            if b_init:
+                self.all_params.extend([W, b])
+            else:
+                self.all_params.extend([W])


 class UpSampling2dLayer(Layer):
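
To illustrate the new behavior, here is a minimal usage sketch, assuming the TensorLayer 1.x graph-building API; the input shape, filter sizes, and layer names are illustrative, not taken from the commit:

import tensorflow as tf
import tensorlayer as tl

# Illustrative input: batches of 16-frame 112x112 RGB clips.
x = tf.placeholder(tf.float32, shape=[None, 16, 112, 112, 3])
net = tl.layers.InputLayer(x, name='input')

# With b_init=None, no 'b_conv3d' variable is created: the layer
# computes act(conv3d(inputs, W)) and extends all_params with [W] only.
net = tl.layers.Conv3dLayer(
    net,
    act=tf.nn.relu,
    shape=(3, 3, 3, 3, 32),  # (depth, height, width, in_channels, out_channels)
    strides=(1, 1, 1, 1, 1),
    padding='SAME',
    b_init=None,
    name='conv3d_no_bias',
)

DeConv3dLayer accepts the same b_init=None to build a bias-free transposed convolution.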

tensorlayer/layers/normalization.py

Lines changed: 28 additions & 17 deletions
@@ -75,10 +75,13 @@ class BatchNormLayer(Layer):
         The activation function of this layer.
     is_train : boolean
         Is being used for training or inference.
-    beta_init : initializer
-        The initializer for initializing beta.
-    gamma_init : initializer
-        The initializer for initializing gamma.
+    beta_init : initializer or None
+        The initializer for initializing beta; if None, skip beta.
+        Usually you should not skip beta unless you know what you are doing.
+    gamma_init : initializer or None
+        The initializer for initializing gamma; if None, skip gamma.
+        When the batch normalization layer is used instead of 'biases', or the next layer is linear, this can be
+        disabled since the scaling can be done by the next layer. See `Inception-ResNet-v2 <https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py>`__
     dtype : TensorFlow dtype
         tf.float32 (default) or tf.float16.
     name : str
@@ -112,19 +115,27 @@ def __init__(

         with tf.variable_scope(name):
             axis = list(range(len(x_shape) - 1))
-
             # 1. beta, gamma
-            if tf.__version__ > '0.12.1' and beta_init == tf.zeros_initializer:
-                beta_init = beta_init()
-            beta = tf.get_variable('beta', shape=params_shape, initializer=beta_init, dtype=LayersConfig.tf_dtype, trainable=is_train)
-
-            gamma = tf.get_variable(
-                'gamma',
-                shape=params_shape,
-                initializer=gamma_init,
-                dtype=LayersConfig.tf_dtype,
-                trainable=is_train,
-            )
+            variables = []
+            if beta_init:
+                if tf.__version__ > '0.12.1' and beta_init == tf.zeros_initializer:
+                    beta_init = beta_init()
+                beta = tf.get_variable('beta', shape=params_shape, initializer=beta_init, dtype=LayersConfig.tf_dtype, trainable=is_train)
+                variables.append(beta)
+            else:
+                beta = None
+
+            if gamma_init:
+                gamma = tf.get_variable(
+                    'gamma',
+                    shape=params_shape,
+                    initializer=gamma_init,
+                    dtype=LayersConfig.tf_dtype,
+                    trainable=is_train,
+                )
+                variables.append(gamma)
+            else:
+                gamma = None

             # 2.
             if tf.__version__ > '0.12.1':
@@ -163,7 +174,7 @@ def mean_var_with_update():
             else:
                 self.outputs = act(tf.nn.batch_normalization(self.inputs, moving_mean, moving_variance, beta, gamma, epsilon))

-            variables = [beta, gamma, moving_mean, moving_variance]
+            variables.extend([moving_mean, moving_variance])

             # logging.info(len(variables))
             # for idx, v in enumerate(variables):