Skip to content

Commit b1de0c4

Browse files
authored
Update conv1d (#499)
* update conv1d with tf.layers * update conv1d fun --> class * yapf * fix super
1 parent c226cb3 commit b1de0c4

File tree

2 files changed

+67
-51
lines changed

2 files changed

+67
-51
lines changed

docs/modules/layers.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ APIs may better for you.
484484

485485
1D Convolution
486486
^^^^^^^^^^^^^^^^^^^^^^^
487-
.. autofunction:: Conv1d
487+
.. autoclass:: Conv1d
488488

489489
2D Convolution
490490
^^^^^^^^^^^^^^^^^^^^^^^

tensorlayer/layers/convolution.py

Lines changed: 66 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -85,24 +85,21 @@ def __init__(
8585

8686
if act is None:
8787
act = tf.identity
88+
8889
if W_init_args is None:
8990
W_init_args = {}
9091
if b_init_args is None:
9192
b_init_args = {}
9293

9394
with tf.variable_scope(name):
9495
W = tf.get_variable(name='W_conv1d', shape=shape, initializer=W_init, dtype=LayersConfig.tf_dtype, **W_init_args)
95-
self.outputs = tf.nn.convolution(
96-
self.inputs, W, strides=(stride, ), padding=padding, dilation_rate=(dilation_rate, ), data_format=data_format) # 1.2
96+
self.outputs = tf.nn.convolution(self.inputs, W, strides=(stride, ), padding=padding, dilation_rate=(dilation_rate, )) # 1.2
9797
if b_init:
9898
b = tf.get_variable(name='b_conv1d', shape=(shape[-1]), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args)
9999
self.outputs = self.outputs + b
100100

101101
self.outputs = act(self.outputs)
102102

103-
# self.all_layers = list(layer.all_layers)
104-
# self.all_params = list(layer.all_params)
105-
# self.all_drop = dict(layer.all_drop)
106103
self.all_layers.append(self.outputs)
107104
if b_init:
108105
self.all_params.extend([W, b])
@@ -1260,22 +1257,7 @@ def deconv2d_bilinear_upsampling_initializer(shape):
12601257
return bilinear_weights_init
12611258

12621259

1263-
@deprecated_alias(layer='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release
1264-
def conv1d(
1265-
prev_layer,
1266-
n_filter=32,
1267-
filter_size=5,
1268-
stride=1,
1269-
dilation_rate=1,
1270-
act=tf.identity,
1271-
padding='SAME',
1272-
data_format="NWC",
1273-
W_init=tf.truncated_normal_initializer(stddev=0.02),
1274-
b_init=tf.constant_initializer(value=0.0),
1275-
W_init_args=None,
1276-
b_init_args=None,
1277-
name='conv1d',
1278-
):
1260+
class Conv1d(Layer):
12791261
"""Simplified version of :class:`Conv1dLayer`.
12801262
12811263
Parameters
@@ -1301,17 +1283,12 @@ def conv1d(
13011283
b_init : initializer or None
13021284
The initializer for the bias vector. If None, skip biases.
13031285
W_init_args : dictionary
1304-
The arguments for the weight matrix initializer.
1286+
The arguments for the weight matrix initializer (deprecated).
13051287
b_init_args : dictionary
1306-
The arguments for the bias vector initializer.
1288+
The arguments for the bias vector initializer (deprecated).
13071289
name : str
13081290
A unique layer name
13091291
1310-
Returns
1311-
-------
1312-
:class:`Layer`
1313-
A :class:`Conv1dLayer` object.
1314-
13151292
Examples
13161293
---------
13171294
>>> x = tf.placeholder(tf.float32, (batch_size, width))
@@ -1331,25 +1308,67 @@ def conv1d(
13311308
13321309
"""
13331310

1334-
if W_init_args is None:
1335-
W_init_args = {}
1336-
if b_init_args is None:
1337-
b_init_args = {}
1311+
@deprecated_alias(layer='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release
1312+
def __init__(self,
1313+
prev_layer,
1314+
n_filter=32,
1315+
filter_size=5,
1316+
stride=1,
1317+
dilation_rate=1,
1318+
act=tf.identity,
1319+
padding='SAME',
1320+
data_format="channels_last",
1321+
W_init=tf.truncated_normal_initializer(stddev=0.02),
1322+
b_init=tf.constant_initializer(value=0.0),
1323+
W_init_args=None,
1324+
b_init_args=None,
1325+
name='conv1d'):
13381326

1339-
return Conv1dLayer(
1340-
prev_layer=prev_layer,
1341-
act=act,
1342-
shape=(filter_size, int(prev_layer.outputs.get_shape()[-1]), n_filter),
1343-
stride=stride,
1344-
dilation_rate=dilation_rate,
1345-
padding=padding,
1346-
data_format=data_format,
1347-
W_init=W_init,
1348-
b_init=b_init,
1349-
W_init_args=W_init_args,
1350-
b_init_args=b_init_args,
1351-
name=name,
1352-
)
1327+
super(Conv1d, self).__init__(prev_layer=prev_layer, name=name)
1328+
logging.info("Conv1d %s: n_filter:%d filter_size:%s stride:%d pad:%s act:%s dilation_rate:%d" % (name, n_filter, filter_size, stride, padding,
1329+
act.__name__, dilation_rate))
1330+
1331+
self.inputs = prev_layer.outputs
1332+
if tf.__version__ > '1.3':
1333+
con1d = tf.layers.Conv1D(
1334+
filters=n_filter,
1335+
kernel_size=filter_size,
1336+
strides=stride,
1337+
padding=padding,
1338+
data_format=data_format,
1339+
dilation_rate=dilation_rate,
1340+
activation=act,
1341+
use_bias=(True if b_init else False),
1342+
kernel_initializer=W_init,
1343+
bias_initializer=b_init,
1344+
name=name)
1345+
# con1d.dtype = LayersConfig.tf_dtype # unsupport, it will use the same dtype of inputs
1346+
self.outputs = con1d(self.inputs)
1347+
new_variables = con1d.weights # new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
1348+
self.all_layers.append(self.outputs)
1349+
self.all_params.extend(new_variables)
1350+
else:
1351+
raise RuntimeError("please update TF > 1.3 or downgrade TL < 1.8.4")
1352+
# if W_init_args is None:
1353+
# W_init_args = {}
1354+
# if b_init_args is None:
1355+
# b_init_args = {}
1356+
# data_format='HWC'
1357+
# return Conv1dLayer(
1358+
1359+
# prev_layer=prev_layer,
1360+
# act=act,
1361+
# shape=(filter_size, int(prev_layer.outputs.get_shape()[-1]), n_filter),
1362+
# stride=stride,
1363+
# dilation_rate=dilation_rate,
1364+
# padding=padding,
1365+
# data_format=data_format,
1366+
# W_init=W_init,
1367+
# b_init=b_init,
1368+
# W_init_args=W_init_args,
1369+
# b_init_args=b_init_args,
1370+
# name=name,
1371+
# )
13531372

13541373

13551374
# TODO: DeConv1d
@@ -1682,9 +1701,6 @@ def __init__(self,
16821701
)
16831702
new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
16841703

1685-
# self.all_layers = list(layer.all_layers)
1686-
# self.all_params = list(layer.all_params)
1687-
# self.all_drop = dict(layer.all_drop)
16881704
self.all_layers.append(self.outputs)
16891705
self.all_params.extend(new_variables)
16901706

@@ -2010,6 +2026,6 @@ def __init__(
20102026

20112027
# Alias
20122028
AtrousConv1dLayer = atrous_conv1d
2013-
Conv1d = conv1d
2029+
# Conv1d = conv1d
20142030
# Conv2d = conv2d
20152031
# DeConv2d = deconv2d

0 commit comments

Comments
 (0)