diff --git a/CHANGELOG.md b/CHANGELOG.md index eca64385a..e932d4a6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -75,6 +75,8 @@ To release a new version, please update the changelog as followed: ### Changed - remove `tl.layers.initialize_global_variables(sess)` (PR #931) +- change `tl.layers.core`, `tl.models.core` (PR #966) + - change `weights` into `all_weights`, `trainable_weights`, `nontrainable_weights` ### Dependencies Update - nltk>=3.3,<3.4 => nltk>=3.3,<3.5 (PR #892) diff --git a/docs/modules/files.rst b/docs/modules/files.rst index 833bcf5ef..c823b4db9 100644 --- a/docs/modules/files.rst +++ b/docs/modules/files.rst @@ -142,14 +142,14 @@ sake of cross-platform. Other file formats such as ``.npz`` are also available. .. code-block:: python ## save model as .h5 - tl.files.save_weights_to_hdf5('model.h5', network.weights) + tl.files.save_weights_to_hdf5('model.h5', network.all_weights) # restore model from .h5 (in order) - tl.files.load_hdf5_to_weights_in_order('model.h5', network.weights) + tl.files.load_hdf5_to_weights_in_order('model.h5', network.all_weights) # restore model from .h5 (by name) - tl.files.load_hdf5_to_weights('model.h5', network.weights) + tl.files.load_hdf5_to_weights('model.h5', network.all_weights) ## save model as .npz - tl.files.save_npz(network.weights , name='model.npz') + tl.files.save_npz(network.all_weights , name='model.npz') # restore model from .npz (method 1) load_params = tl.files.load_npz(name='model.npz') tl.files.assign_weights(sess, load_params, network) diff --git a/docs/user/faq.rst b/docs/user/faq.rst index dc94586ff..f1c98a075 100644 --- a/docs/user/faq.rst +++ b/docs/user/faq.rst @@ -46,19 +46,19 @@ To choose which variables to update, you can do as below. .. code-block:: python - train_params = network.weights[3:] + train_params = network.trainable_weights[3:] The second way is to get the variables by a given name. For example, if you want to get all variables which the layer name contain ``dense``, you can do as below. .. code-block:: python - train_params = network.get_layer('dense').weights + train_params = network.get_layer('dense').trainable_weights After you get the variable list, you can define your optimizer like that so as to update only a part of the variables. .. code-block:: python - train_weights = network.weights + train_weights = network.trainable_weights optimizer.apply_gradients(zip(grad, train_weights)) Logging diff --git a/docs/user/get_start_advance.rst b/docs/user/get_start_advance.rst index 1b2f5b125..20cdaa871 100644 --- a/docs/user/get_start_advance.rst +++ b/docs/user/get_start_advance.rst @@ -36,7 +36,7 @@ Get a part of CNN nn = tl.layers.Dense(n_units=100, name='out')(nn) model = tl.models.Model(inputs=ni, outputs=nn) # train your own classifier (only update the last layer) - train_params = model.get_layer('out').weights + train_params = model.get_layer('out').all_weights Reuse CNN ------------------ diff --git a/docs/user/get_start_model.rst b/docs/user/get_start_model.rst index 807baf112..670d325ef 100644 --- a/docs/user/get_start_model.rst +++ b/docs/user/get_start_model.rst @@ -149,11 +149,11 @@ We can get the specific weights by indexing or naming. .. 
code-block:: python # indexing - all_weights = MLP.weights - some_weights = MLP.weights[1:3] + all_weights = MLP.all_weights + some_weights = MLP.all_weights[1:3] # naming - some_weights = MLP.get_layer('dense1').weights + some_weights = MLP.get_layer('dense1').all_weights Save and restore model diff --git a/examples/basic_tutorials/tutorial_cifar10_cnn_static.py b/examples/basic_tutorials/tutorial_cifar10_cnn_static.py index 896b15c86..c12c791a1 100644 --- a/examples/basic_tutorials/tutorial_cifar10_cnn_static.py +++ b/examples/basic_tutorials/tutorial_cifar10_cnn_static.py @@ -87,7 +87,7 @@ def get_model_batchnorm(inputs_shape): # learning_rate_decay_factor = 0.1 # num_epoch_decay = 350 -train_weights = net.weights +train_weights = net.trainable_weights # learning_rate = tf.Variable(init_learning_rate) optimizer = tf.optimizers.Adam(learning_rate) diff --git a/examples/basic_tutorials/tutorial_mnist_mlp_dynamic.py b/examples/basic_tutorials/tutorial_mnist_mlp_dynamic.py index b1b5909d6..1ffa7fbe0 100644 --- a/examples/basic_tutorials/tutorial_mnist_mlp_dynamic.py +++ b/examples/basic_tutorials/tutorial_mnist_mlp_dynamic.py @@ -46,7 +46,7 @@ def forward(self, x, foo=None): n_epoch = 500 batch_size = 500 print_freq = 5 -train_weights = MLP.weights +train_weights = MLP.trainable_weights optimizer = tf.optimizers.Adam(learning_rate=0.0001) ## the following code can help you understand SGD deeply diff --git a/examples/basic_tutorials/tutorial_mnist_mlp_dynamic_2.py b/examples/basic_tutorials/tutorial_mnist_mlp_dynamic_2.py index b369a0b26..b752012b0 100644 --- a/examples/basic_tutorials/tutorial_mnist_mlp_dynamic_2.py +++ b/examples/basic_tutorials/tutorial_mnist_mlp_dynamic_2.py @@ -65,7 +65,7 @@ def forward(self, x, foo=None): n_epoch = 500 batch_size = 500 print_freq = 5 -train_weights = MLP1.weights + MLP2.weights +train_weights = MLP1.trainable_weights + MLP2.trainable_weights optimizer = tf.optimizers.Adam(learning_rate=0.0001) ## the following code can help you understand SGD deeply diff --git a/examples/basic_tutorials/tutorial_mnist_mlp_static.py b/examples/basic_tutorials/tutorial_mnist_mlp_static.py index 1d99a3a86..c9c15f911 100644 --- a/examples/basic_tutorials/tutorial_mnist_mlp_static.py +++ b/examples/basic_tutorials/tutorial_mnist_mlp_static.py @@ -37,7 +37,7 @@ def get_model(inputs_shape): n_epoch = 500 batch_size = 500 print_freq = 5 -train_weights = MLP.weights +train_weights = MLP.trainable_weights optimizer = tf.optimizers.Adam(lr=0.0001) ## the following code can help you understand SGD deeply diff --git a/examples/basic_tutorials/tutorial_mnist_mlp_static_2.py b/examples/basic_tutorials/tutorial_mnist_mlp_static_2.py index 52568bb88..f0836c528 100644 --- a/examples/basic_tutorials/tutorial_mnist_mlp_static_2.py +++ b/examples/basic_tutorials/tutorial_mnist_mlp_static_2.py @@ -46,7 +46,7 @@ def get_model(inputs_shape, hmodel): n_epoch = 500 batch_size = 500 print_freq = 5 -train_weights = MLP.weights +train_weights = MLP.trainable_weights optimizer = tf.optimizers.Adam(lr=0.0001) ## the following code can help you understand SGD deeply diff --git a/examples/basic_tutorials/tutorial_mnist_siamese.py b/examples/basic_tutorials/tutorial_mnist_siamese.py index 6b744367d..db43f1163 100644 --- a/examples/basic_tutorials/tutorial_mnist_siamese.py +++ b/examples/basic_tutorials/tutorial_mnist_siamese.py @@ -96,7 +96,7 @@ def create_pairs(x, digit_indices): # training settings print_freq = 5 -train_weights = model.weights +train_weights = model.trainable_weights optimizer = 
tf.optimizers.RMSprop() diff --git a/examples/keras_tfslim/tutorial_keras.py b/examples/keras_tfslim/tutorial_keras.py index ba2c4b831..0622bc745 100644 --- a/examples/keras_tfslim/tutorial_keras.py +++ b/examples/keras_tfslim/tutorial_keras.py @@ -38,7 +38,7 @@ n_epoch = 200 learning_rate = 0.0001 -train_params = network.weights +train_params = network.trainable_weights optimizer = tf.optimizers.Adam(learning_rate) for epoch in range(n_epoch): diff --git a/examples/reinforcement_learning/tutorial_atari_pong.py b/examples/reinforcement_learning/tutorial_atari_pong.py index 0e18f93c4..ad8e264df 100644 --- a/examples/reinforcement_learning/tutorial_atari_pong.py +++ b/examples/reinforcement_learning/tutorial_atari_pong.py @@ -84,7 +84,7 @@ def get_model(inputs_shape): M = tl.models.Model(inputs=ni, outputs=nn, name="mlp") return M model = get_model([None, D]) -train_weights = model.weights +train_weights = model.trainable_weights # probs = model(t_states, is_train=True).outputs # sampling_prob = tf.nn.softmax(probs) diff --git a/examples/reinforcement_learning/tutorial_cartpole_ac.py b/examples/reinforcement_learning/tutorial_cartpole_ac.py index e525a3bdb..4d8b6f8ea 100644 --- a/examples/reinforcement_learning/tutorial_cartpole_ac.py +++ b/examples/reinforcement_learning/tutorial_cartpole_ac.py @@ -122,8 +122,8 @@ def learn(self, s, a, td): _logits = self.model([s]).outputs # _probs = tf.nn.softmax(_logits) _exp_v = tl.rein.cross_entropy_reward_loss(logits=_logits, actions=[a], rewards=td[0]) - grad = tape.gradient(_exp_v, self.model.weights) - self.optimizer.apply_gradients(zip(grad, self.model.weights)) + grad = tape.gradient(_exp_v, self.model.trainable_weights) + self.optimizer.apply_gradients(zip(grad, self.model.trainable_weights)) return _exp_v def choose_action(self, s): @@ -178,8 +178,8 @@ def learn(self, s, r, s_): # TD_error = r + lambd * V(newS) - V(S) td_error = r + LAMBDA * v_ - v loss = tf.square(td_error) - grad = tape.gradient(loss, self.model.weights) - self.optimizer.apply_gradients(zip(grad, self.model.weights)) + grad = tape.gradient(loss, self.model.trainable_weights) + self.optimizer.apply_gradients(zip(grad, self.model.trainable_weights)) return td_error diff --git a/examples/reinforcement_learning/tutorial_frozenlake_dqn.py b/examples/reinforcement_learning/tutorial_frozenlake_dqn.py index c905dee4c..9411da423 100644 --- a/examples/reinforcement_learning/tutorial_frozenlake_dqn.py +++ b/examples/reinforcement_learning/tutorial_frozenlake_dqn.py @@ -63,7 +63,7 @@ def get_model(inputs_shape): return tl.models.Model(inputs=ni, outputs=nn, name="Q-Network") qnetwork = get_model([1, 16]) qnetwork.train() -train_weights = qnetwork.weights +train_weights = qnetwork.trainable_weights # chose action greedily with reward. in Q-Learning, policy is greedy, so we use "max" to select the next action. 
# predict = tf.argmax(y, 1) diff --git a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py index e0db623fb..aecc69f61 100644 --- a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py +++ b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_dynamic.py @@ -87,7 +87,7 @@ def forward(self, inputs): learning_rate = 0.0001 print_freq = 10 batch_size = 64 -train_weights = net.weights +train_weights = net.trainable_weights optimizer = tf.optimizers.Adam(lr=learning_rate) ##================== TRAINING ================================================## diff --git a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_static.py b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_static.py index dfc615fc8..c9a93629f 100644 --- a/examples/spatial_transformer_network/tutorial_spatial_transformer_network_static.py +++ b/examples/spatial_transformer_network/tutorial_spatial_transformer_network_static.py @@ -84,7 +84,7 @@ def get_model(inputs_shape): learning_rate = 0.0001 print_freq = 10 batch_size = 64 -train_weights = net.weights +train_weights = net.trainable_weights optimizer = tf.optimizers.Adam(lr=learning_rate) ##================== TRAINING ================================================## diff --git a/examples/text_classification/tutorial_imdb_fasttext.py b/examples/text_classification/tutorial_imdb_fasttext.py index 6d785f402..2c2c7aed0 100644 --- a/examples/text_classification/tutorial_imdb_fasttext.py +++ b/examples/text_classification/tutorial_imdb_fasttext.py @@ -138,8 +138,8 @@ def train_test_and_save_model(): cost = tl.cost.cross_entropy(y_pred, y_batch, name='cost') # backward, calculate gradients and update the weights - grad = tape.gradient(cost, model.weights) - optimizer.apply_gradients(zip(grad, model.weights)) + grad = tape.gradient(cost, model.trainable_weights) + optimizer.apply_gradients(zip(grad, model.trainable_weights)) # calculate the accuracy predictions = tf.argmax(y_pred, axis=1, output_type=tf.int32) diff --git a/examples/text_generation/tutorial_generate_text.py b/examples/text_generation/tutorial_generate_text.py index ef2b2e6c3..22a17ea37 100644 --- a/examples/text_generation/tutorial_generate_text.py +++ b/examples/text_generation/tutorial_generate_text.py @@ -289,7 +289,7 @@ def loss_fn(outputs, targets, batch_size, sequence_length): # tvars = network.all_params $ all parameters # tvars = network.all_params[1:] $ parameters except embedding matrix # Train the whole network. 
- tvars = rnn_model.weights + tvars = rnn_model.trainable_weights # grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), max_grad_norm) # optimizer = tf.train.GradientDescentOptimizer(lr) train_op = tf.train.GradientDescentOptimizer(lr).minimize(cost, var_list=tvars) diff --git a/examples/text_word_embedding/tutorial_word2vec_basic.py b/examples/text_word_embedding/tutorial_word2vec_basic.py index 4285ee992..6310699ad 100644 --- a/examples/text_word_embedding/tutorial_word2vec_basic.py +++ b/examples/text_word_embedding/tutorial_word2vec_basic.py @@ -240,8 +240,8 @@ def main_word2vec_basic(): with tf.GradientTape() as tape: outputs, nce_cost = model([batch_inputs, batch_labels]) - grad = tape.gradient(nce_cost, model.weights) - optimizer.apply_gradients(zip(grad, model.weights)) + grad = tape.gradient(nce_cost, model.trainable_weights) + optimizer.apply_gradients(zip(grad, model.trainable_weights)) average_loss += nce_cost diff --git a/tensorlayer/db.py b/tensorlayer/db.py index 025566dc2..cb8db8e10 100644 --- a/tensorlayer/db.py +++ b/tensorlayer/db.py @@ -148,7 +148,7 @@ def save_model(self, network=None, model_name='model', **kwargs): self._fill_project_info(kwargs) # put project_name into kwargs # params = network.get_all_params() - params = network.weights + params = network.all_weights s = time.time() diff --git a/tensorlayer/files/utils.py b/tensorlayer/files/utils.py index 8cc718f5b..72fcb1824 100644 --- a/tensorlayer/files/utils.py +++ b/tensorlayer/files/utils.py @@ -1907,7 +1907,7 @@ def save_npz(save_list=None, name='model.npz'): -------- Save model to npz - >>> tl.files.save_npz(network.weights, name='model.npz') + >>> tl.files.save_npz(network.all_weights, name='model.npz') Load model from npz (Method 1) @@ -1993,7 +1993,7 @@ def assign_weights(weights, network): """ ops = [] for idx, param in enumerate(weights): - ops.append(network.weights[idx].assign(param)) + ops.append(network.all_weights[idx].assign(param)) return ops @@ -2073,7 +2073,7 @@ def load_and_assign_npz_dict(name='model.npz', network=None, skip=False): if len(weights.keys()) != len(set(weights.keys())): raise Exception("Duplication in model npz_dict %s" % name) - net_weights_name = [w.name for w in network.weights] + net_weights_name = [w.name for w in network.all_weights] for key in weights.keys(): if key not in net_weights_name: @@ -2085,7 +2085,7 @@ def load_and_assign_npz_dict(name='model.npz', network=None, skip=False): "if you want to skip redundant or mismatch weights." 
% key ) else: - assign_tf_variable(network.weights[net_weights_name.index(key)], weights[key]) + assign_tf_variable(network.all_weights[net_weights_name.index(key)], weights[key]) logging.info("[*] Model restored from npz_dict %s" % name) @@ -2549,9 +2549,9 @@ def _save_weights_to_hdf5_group(f, layers): elif isinstance(layer, tl.layers.LayerList): _save_weights_to_hdf5_group(g, layer.layers) elif isinstance(layer, tl.layers.Layer): - if layer.weights is not None: - weight_values = tf_variables_to_numpy(layer.weights) - weight_names = [w.name.encode('utf8') for w in layer.weights] + if layer.all_weights is not None: + weight_values = tf_variables_to_numpy(layer.all_weights) + weight_names = [w.name.encode('utf8') for w in layer.all_weights] else: weight_values = [] weight_names = [] @@ -2593,7 +2593,7 @@ def _load_weights_from_hdf5_group_in_order(f, layers): elif isinstance(layer, tl.layers.Layer): weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] for iid, w_name in enumerate(weight_names): - assign_tf_variable(layer.weights[iid], np.asarray(g[w_name])) + assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name])) else: raise Exception("Only layer or model can be saved into hdf5.") if idx == len(layers) - 1: @@ -2639,7 +2639,7 @@ def _load_weights_from_hdf5_group(f, layers, skip=False): elif isinstance(layer, tl.layers.Layer): weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] for iid, w_name in enumerate(weight_names): - assign_tf_variable(layer.weights[iid], np.asarray(g[w_name])) + assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name])) else: raise Exception("Only layer or model can be saved into hdf5.") diff --git a/tensorlayer/layers/convolution/separable_conv.py b/tensorlayer/layers/convolution/separable_conv.py index a19cabbcc..b6ae62446 100644 --- a/tensorlayer/layers/convolution/separable_conv.py +++ b/tensorlayer/layers/convolution/separable_conv.py @@ -156,7 +156,7 @@ def build(self, inputs_shape): ) # initialize weights outputs_shape = _out.shape # self._add_weights(self.layer.weights) - self._weights = self.layer.weights + self._trainable_weights = self.layer.weights def forward(self, inputs): outputs = self.layer(inputs) @@ -302,7 +302,7 @@ def build(self, inputs_shape): tf.convert_to_tensor(np.random.uniform(size=list(inputs_shape)), dtype=np.float) ) # initialize weights outputs_shape = _out.shape - self._weights = self.layer.weights + self._trainable_weights = self.layer.weights def forward(self, inputs): outputs = self.layer(inputs) diff --git a/tensorlayer/layers/convolution/simplified_deconv.py b/tensorlayer/layers/convolution/simplified_deconv.py index 0dc204caf..847062859 100644 --- a/tensorlayer/layers/convolution/simplified_deconv.py +++ b/tensorlayer/layers/convolution/simplified_deconv.py @@ -141,7 +141,7 @@ def build(self, inputs_shape): tf.convert_to_tensor(np.random.uniform(size=inputs_shape), dtype=np.float32) ) #np.random.uniform([1] + list(inputs_shape))) # initialize weights outputs_shape = _out.shape - self._weights = self.layer.weights + self._trainable_weights = self.layer.weights def forward(self, inputs): outputs = self.layer(inputs) @@ -264,7 +264,7 @@ def build(self, inputs_shape): ) #self.layer(np.random.uniform([1] + list(inputs_shape))) # initialize weights outputs_shape = _out.shape # self._add_weights(self.layer.weights) - self._weights = self.layer.weights + self._trainable_weights = self.layer.weights def forward(self, inputs): outputs = self.layer(inputs) diff --git a/tensorlayer/layers/core.py 
b/tensorlayer/layers/core.py index c5753c9bc..ce98f156c 100644 --- a/tensorlayer/layers/core.py +++ b/tensorlayer/layers/core.py @@ -34,8 +34,12 @@ class Layer(object): Initializing the Layer. __call__() (1) Building the Layer if necessary. (2) Forwarding the computation. - weights() + all_weights() + Return a list of Tensor which are all weights of this Layer. + trainable_weights() Return a list of Tensor which are all trainable weights of this Layer. + nontrainable_weights() + Return a list of Tensor which are all nontrainable weights of this Layer. build() Abstract method. Build the Layer. All trainable weights should be defined in this function. forward() @@ -89,7 +93,9 @@ def __init__(self, name=None, *args, **kwargs): self._nodes_fixed = False # Layer weight state - self._weights = None + self._all_weights = None + self._trainable_weights = None + self._nontrainable_weights = None # Layer training state self.is_train = True @@ -136,8 +142,24 @@ def config(self): return _config @property - def weights(self): - return self._weights + def all_weights(self): + if self._all_weights is not None and len(self._all_weights) > 0: + pass + else: + self._all_weights = list() + if self._trainable_weights is not None: + self._all_weights.extend(self._trainable_weights) + if self._nontrainable_weights is not None: + self._all_weights.extend(self._nontrainable_weights) + return self._all_weights + + @property + def trainable_weights(self): + return self._trainable_weights + + @property + def nontrainable_weights(self): + return self._nontrainable_weights def __call__(self, inputs, *args, **kwargs): """ @@ -218,12 +240,17 @@ def _fix_nodes_for_layers(self): """ fix LayerNodes to stop growing for this layer""" self._nodes_fixed = True - def _get_weights(self, var_name, shape, init=tl.initializers.random_normal()): + def _get_weights(self, var_name, shape, init=tl.initializers.random_normal(), trainable=True): """ Get trainable variables. """ weight = get_variable_with_initializer(scope_name=self.name, var_name=var_name, shape=shape, init=init) - if self._weights is None: - self._weights = list() - self._weights.append(weight) # Add into the weight collection + if trainable is True: + if self._trainable_weights is None: + self._trainable_weights = list() + self._trainable_weights.append(weight) + else: + if self._nontrainable_weights is None: + self._nontrainable_weights = list() + self._nontrainable_weights.append(weight) return weight @abstractmethod @@ -407,7 +434,7 @@ def __init__(self, model, name=None): self._built = True # Layer weight state - self._weights = model.weights + self._all_weights = model.all_weights # Layer training state self.is_train = True @@ -497,12 +524,12 @@ def __init__(self, layers, name=None): for layer in self.layers: if layer._built is False: is_built = False - if layer._built and layer.weights is not None: + if layer._built and layer.all_weights is not None: # some layers in the list passed in have already been built # e.g. 
using input shape to construct layers in dynamic eager - if self._weights is None: - self._weights = list() - self._weights.extend(layer.weights) + if self._all_weights is None: + self._all_weights = list() + self._all_weights.extend(layer.all_weights) if is_built: self._built = True @@ -550,10 +577,10 @@ def build(self, inputs_shape): is_build = layer._built out_tensor = layer(in_tensor) # nlayer = layer(in_layer) - if is_build == False and layer.weights is not None: - if self._weights is None: - self._weights = list() - self._weights.extend(layer.weights) + if is_build is False and layer.all_weights is not None: + if self._all_weights is None: + self._all_weights = list() + self._all_weights.extend(layer.all_weights) layer._built = True in_tensor = out_tensor # in_layer = nlayer diff --git a/tensorlayer/layers/lambda_layers.py b/tensorlayer/layers/lambda_layers.py index 997721e7c..13bc3ecbe 100644 --- a/tensorlayer/layers/lambda_layers.py +++ b/tensorlayer/layers/lambda_layers.py @@ -96,8 +96,8 @@ class Lambda(Layer): >>> pred_y = model(data_x) >>> loss = tl.cost.mean_squared_error(pred_y, data_y) - >>> gradients = tape.gradient(loss, model.weights) - >>> optimizer.apply_gradients(zip(gradients, model.weights)) + >>> gradients = tape.gradient(loss, model.trainable_weights) + >>> optimizer.apply_gradients(zip(gradients, model.trainable_weights)) """ @@ -111,14 +111,14 @@ def __init__( super(Lambda, self).__init__(name=name) self.fn = fn - self._weights = fn_weights if fn_weights is not None else [] + self._trainable_weights = fn_weights if fn_weights is not None else [] self.fn_args = fn_args if fn_args is not None else {} try: fn_name = repr(self.fn) except: fn_name = 'name not available' - logging.info("Lambda %s: func: %s, len_weights: %s" % (self.name, fn_name, len(self._weights))) + logging.info("Lambda %s: func: %s, len_weights: %s" % (self.name, fn_name, len(self._trainable_weights))) self.build() self._built = True @@ -134,7 +134,8 @@ def __repr__(self): except: fn_name = 'name not available' return s.format( - classname=self.__class__.__name__, fn_name=fn_name, len_weights=len(self._weights), **self.__dict__ + classname=self.__class__.__name__, fn_name=fn_name, len_weights=len(self._trainable_weights), + **self.__dict__ ) def build(self, inputs_shape=None): @@ -233,14 +234,16 @@ def __init__( super(ElementwiseLambda, self).__init__(name=name) self.fn = fn - self._weights = fn_weights if fn_weights is not None else [] + self._trainable_weights = fn_weights if fn_weights is not None else [] self.fn_args = fn_args if fn_args is not None else {} try: fn_name = repr(self.fn) except: fn_name = 'name not available' - logging.info("ElementwiseLambda %s: func: %s, len_weights: %s" % (self.name, fn_name, len(self._weights))) + logging.info( + "ElementwiseLambda %s: func: %s, len_weights: %s" % (self.name, fn_name, len(self._trainable_weights)) + ) self.build() self._built = True @@ -256,7 +259,8 @@ def __repr__(self): except: fn_name = 'name not available' return s.format( - classname=self.__class__.__name__, fn_name=fn_name, len_weights=len(self._weights), **self.__dict__ + classname=self.__class__.__name__, fn_name=fn_name, len_weights=len(self._trainable_weights), + **self.__dict__ ) def build(self, inputs_shape=None): diff --git a/tensorlayer/layers/normalization.py b/tensorlayer/layers/normalization.py index 572936228..d8cec274c 100644 --- a/tensorlayer/layers/normalization.py +++ b/tensorlayer/layers/normalization.py @@ -258,8 +258,12 @@ def build(self, inputs_shape): if 
self.gamma_init: self.gamma = self._get_weights("gamma", shape=params_shape, init=self.gamma_init) - self.moving_mean = self._get_weights("moving_mean", shape=params_shape, init=self.moving_mean_init) - self.moving_var = self._get_weights("moving_var", shape=params_shape, init=self.moving_var_init) + self.moving_mean = self._get_weights( + "moving_mean", shape=params_shape, init=self.moving_mean_init, trainable=False + ) + self.moving_var = self._get_weights( + "moving_var", shape=params_shape, init=self.moving_var_init, trainable=False + ) def forward(self, inputs): mean, var = tf.nn.moments(inputs, self.axes, keepdims=True) diff --git a/tensorlayer/layers/recurrent.py b/tensorlayer/layers/recurrent.py index acc6dba05..16b7208d0 100644 --- a/tensorlayer/layers/recurrent.py +++ b/tensorlayer/layers/recurrent.py @@ -149,10 +149,10 @@ def build(self, inputs_shape): with tf.name_scope(self.name) as scope: self.cell.build(tuple(inputs_shape)) - if self._weights is None: - self._weights = list() + if self._trainable_weights is None: + self._trainable_weights = list() for var in self.cell.trainable_variables: - self._weights.append(var) + self._trainable_weights.append(var) # @tf.function def forward(self, inputs, initial_state=None, **kwargs): @@ -341,12 +341,12 @@ def build(self, inputs_shape): self.fw_cell.build(tuple(inputs_shape)) self.bw_cell.build(tuple(inputs_shape)) - if self._weights is None: - self._weights = list() + if self._trainable_weights is None: + self._trainable_weights = list() for var in self.fw_cell.trainable_variables: - self._weights.append(var) + self._trainable_weights.append(var) for var in self.bw_cell.trainable_variables: - self._weights.append(var) + self._trainable_weights.append(var) # @tf.function def forward(self, inputs, fw_initial_state=None, bw_initial_state=None, **kwargs): diff --git a/tensorlayer/models/core.py b/tensorlayer/models/core.py index ee32ddd39..c811b9648 100644 --- a/tensorlayer/models/core.py +++ b/tensorlayer/models/core.py @@ -114,7 +114,7 @@ class Model(object): >>> outputs_s = M_static(data) Save and load weights - + >>> M_static.save_weights('./model_weights.h5') >>> M_static.load_weights('./model_weights.h5') @@ -182,7 +182,9 @@ def __init__(self, inputs=None, outputs=None, name=None): self.is_train = None # Model weights - self._weights = None + self._all_weights = None + self._trainable_weights = None + self._nontrainable_weights = None # Model args of all layers, ordered by all_layers self._config = None @@ -354,7 +356,9 @@ def all_layers(self): # dynamic model self._all_layers = list() attr_list = [attr for attr in dir(self) if attr[:2] != "__"] - attr_list.remove("weights") + attr_list.remove("all_weights") + attr_list.remove("trainable_weights") + attr_list.remove("nontrainable_weights") attr_list.remove("all_layers") for idx, attr in enumerate(attr_list): try: @@ -387,18 +391,46 @@ def all_layers(self): return self._all_layers @property - def weights(self): + def trainable_weights(self): + """Return trainable weights of this network in a list.""" + if self._trainable_weights is not None and len(self._trainable_weights) > 0: + # self._trainable_weights already extracted, so do nothing + pass + else: + self._trainable_weights = [] + for layer in self.all_layers: + if layer.trainable_weights is not None: + self._trainable_weights.extend(layer.trainable_weights) + + return self._trainable_weights + + @property + def nontrainable_weights(self): + """Return nontrainable weights of this network in a list.""" + if 
self._nontrainable_weights is not None and len(self._nontrainable_weights) > 0: + # self._nontrainable_weights already extracted, so do nothing + pass + else: + self._nontrainable_weights = [] + for layer in self.all_layers: + if layer.nontrainable_weights is not None: + self._nontrainable_weights.extend(layer.nontrainable_weights) + + return self._nontrainable_weights + + @property + def all_weights(self): """Return all weights of this network in a list.""" - if self._weights is not None and len(self._weights) > 0: - # self._weights already extracted, so do nothing + if self._all_weights is not None and len(self._all_weights) > 0: + # self._all_weights already extracted, so do nothing pass else: - self._weights = [] + self._all_weights = [] for layer in self.all_layers: - if layer.weights is not None: - self._weights.extend(layer.weights) + if layer.all_weights is not None: + self._all_weights.extend(layer.all_weights) - return self._weights + return self._all_weights @property def config(self): @@ -769,7 +801,7 @@ def save_weights(self, filepath, format=None): >>> net.save_weights('./model.npz', format='npz_dict') """ - if self.weights is None or len(self.weights) == 0: + if self.all_weights is None or len(self.all_weights) == 0: logging.warning("Model contains no weights or layers haven't been built, nothing will be saved") return @@ -783,9 +815,9 @@ def save_weights(self, filepath, format=None): if format == 'hdf5' or format == 'h5': utils.save_weights_to_hdf5(filepath, self) elif format == 'npz': - utils.save_npz(self.weights, filepath) + utils.save_npz(self.all_weights, filepath) elif format == 'npz_dict': - utils.save_npz_dict(self.weights, filepath) + utils.save_npz_dict(self.all_weights, filepath) elif format == 'ckpt': # TODO: enable this when tf save ckpt is enabled raise NotImplementedError("ckpt load/save is not supported now.") @@ -817,7 +849,7 @@ def load_weights(self, filepath, format=None, in_order=True, skip=False): skip : bool Allow skipping weights whose name is mismatched between the file and model. Only useful when 'format' is 'hdf5' or 'npz_dict'. If 'skip' is True, 'in_order' argument will be ignored and those loaded weights - whose name is not found in model weights (self.weights) will be skipped. If 'skip' is False, error will + whose name is not found in model weights (self.all_weights) will be skipped. If 'skip' is False, error will occur when mismatch is found. Default is False. 
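For reference, the renamed weight properties introduced by this change are used roughly as follows — a minimal sketch, assuming a built `tl.models.Model` instance named `model`, placeholder tensors `x_batch`/`y_batch`, and TF2 eager execution; these names are illustrative and do not come from the patch itself.

>>> import tensorflow as tf
>>> import tensorlayer as tl
>>> optimizer = tf.optimizers.Adam(learning_rate=0.0001)
>>> # gradients are computed and applied only over the trainable weights
>>> with tf.GradientTape() as tape:
>>>     y_pred = model(x_batch)                                 # x_batch/y_batch are placeholder tensors
>>>     loss = tl.cost.cross_entropy(y_pred, y_batch, name='loss')
>>> grad = tape.gradient(loss, model.trainable_weights)
>>> optimizer.apply_gradients(zip(grad, model.trainable_weights))
>>> # saving and restoring go through all_weights, which also includes
>>> # nontrainable weights such as BatchNorm moving_mean / moving_var
>>> tl.files.save_npz(model.all_weights, name='model.npz')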
diff --git a/tensorlayer/models/mobilenetv1.py b/tensorlayer/models/mobilenetv1.py index 9c8f7ac51..8065eeef3 100644 --- a/tensorlayer/models/mobilenetv1.py +++ b/tensorlayer/models/mobilenetv1.py @@ -44,10 +44,10 @@ def restore_params(network, path='models'): expected_bytes=25600116 ) # ls -al params = load_npz(name=os.path.join(path, 'mobilenet.npz')) - for idx, net_weight in enumerate(network.weights): + for idx, net_weight in enumerate(network.all_weights): if 'batchnorm' in net_weight.name: params[idx] = params[idx].reshape(1, 1, 1, -1) - assign_weights(params[:len(network.weights)], network) + assign_weights(params[:len(network.all_weights)], network) del params @@ -84,7 +84,7 @@ def MobileNetV1(pretrained=False, end_with='out', name=None): >>> nn = Flatten(name='flatten')(nn) >>> model = tl.models.Model(inputs=ni, outputs=nn) >>> # train your own classifier (only update the last layer) - >>> train_params = model.get_layer('out').weights + >>> train_params = model.get_layer('out').trainable_weights Returns ------- diff --git a/tensorlayer/models/squeezenetv1.py b/tensorlayer/models/squeezenetv1.py index b1a5cecc0..6d6a70535 100644 --- a/tensorlayer/models/squeezenetv1.py +++ b/tensorlayer/models/squeezenetv1.py @@ -38,7 +38,7 @@ def restore_params(network, path='models'): expected_bytes=7405613 ) # ls -al params = load_npz(name=os.path.join(path, 'squeezenet.npz')) - assign_weights(params[:len(network.weights)], network) + assign_weights(params[:len(network.all_weights)], network) del params @@ -75,7 +75,7 @@ def SqueezeNetV1(pretrained=False, end_with='out', name=None): >>> nn = GlobalMeanPool2d(name='globalmeanpool')(nn) >>> model = tl.models.Model(inputs=ni, outputs=nn) >>> # train your own classifier (only update the last layer) - >>> train_params = model.get_layer('conv10').weights + >>> train_params = model.get_layer('conv10').trainable_weights Returns ------- diff --git a/tensorlayer/models/vgg.py b/tensorlayer/models/vgg.py index 7c4bf201e..391878c61 100644 --- a/tensorlayer/models/vgg.py +++ b/tensorlayer/models/vgg.py @@ -163,7 +163,7 @@ def restore_model(model, layer_type): for val in sorted(npz.items()): logging.info(" Loading weights %s in %s" % (str(val[1].shape), val[0])) weights.append(val[1]) - if len(model.weights) == len(weights): + if len(model.all_weights) == len(weights): break elif layer_type == 'vgg19': npz = np.load(os.path.join('models', model_saved_name[layer_type]), encoding='latin1').item() @@ -172,7 +172,7 @@ def restore_model(model, layer_type): logging.info(" Loading %s in %s" % (str(val[1][0].shape), val[0])) logging.info(" Loading %s in %s" % (str(val[1][1].shape), val[0])) weights.extend(val[1]) - if len(model.weights) == len(weights): + if len(model.all_weights) == len(weights): break # assign weight values assign_weights(weights, model) @@ -231,7 +231,7 @@ def vgg16(pretrained=False, end_with='outputs', mode='dynamic', name=None): >>> nn = tl.layers.Dense(n_units=100, name='out')(nn) >>> model = tl.models.Model(inputs=ni, outputs=nn) >>> # train your own classifier (only update the last layer) - >>> train_params = model.get_layer('out').weights + >>> train_params = model.get_layer('out').trainable_weights Reuse model @@ -293,7 +293,7 @@ def vgg19(pretrained=False, end_with='outputs', mode='dynamic', name=None): >>> nn = tl.layers.Dense(n_units=100, name='out')(nn) >>> model = tl.models.Model(inputs=ni, outputs=nn) >>> # train your own classifier (only update the last layer) - >>> train_params = model.get_layer('out').weights + >>> train_params = 
model.get_layer('out').trainable_weights Reuse model diff --git a/tensorlayer/utils.py b/tensorlayer/utils.py index 0d555eae3..d6b8e6d78 100644 --- a/tensorlayer/utils.py +++ b/tensorlayer/utils.py @@ -127,7 +127,7 @@ def fit( tf.summary.scalar('acc', train_acc, step=epoch) # FIXME : there seems to be an internal error in Tensorboard (misuse of tf.name_scope) # if tensorboard_weight_histograms is not None: - # for param in network.weights: + # for param in network.all_weights: # tf.summary.histogram(param.name, param, step=epoch) if (X_val is not None) and (y_val is not None): @@ -138,7 +138,7 @@ def fit( tf.summary.scalar('acc', val_acc, step=epoch) # FIXME : there seems to be an internal error in Tensorboard (misuse of tf.name_scope) # if tensorboard_weight_histograms is not None: - # for param in network.weights: + # for param in network.all_weights: # tf.summary.histogram(param.name, param, step=epoch) if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: @@ -661,8 +661,8 @@ def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(lea y_pred = network(X_batch) _loss = cost(y_pred, y_batch) - grad = tape.gradient(_loss, network.weights) - train_op.apply_gradients(zip(grad, network.weights)) + grad = tape.gradient(_loss, network.trainable_weights) + train_op.apply_gradients(zip(grad, network.trainable_weights)) if acc is not None: _acc = acc(y_pred, y_batch) diff --git a/tests/files/test_utils_saveload.py b/tests/files/test_utils_saveload.py index b3954f5b8..58a1d374a 100644 --- a/tests/files/test_utils_saveload.py +++ b/tests/files/test_utils_saveload.py @@ -67,49 +67,49 @@ def tearDownClass(cls): pass def test_hdf5(self): - modify_val = np.zeros_like(self.static_model.weights[-2].numpy()) - ori_val = self.static_model.weights[-2].numpy() + modify_val = np.zeros_like(self.static_model.all_weights[-2].numpy()) + ori_val = self.static_model.all_weights[-2].numpy() tl.files.save_weights_to_hdf5("./model_basic.h5", self.static_model) - self.static_model.weights[-2].assign(modify_val) + self.static_model.all_weights[-2].assign(modify_val) tl.files.load_hdf5_to_weights_in_order("./model_basic.h5", self.static_model) - self.assertLess(np.max(np.abs(ori_val - self.static_model.weights[-2].numpy())), 1e-7) + self.assertLess(np.max(np.abs(ori_val - self.static_model.all_weights[-2].numpy())), 1e-7) - self.static_model.weights[-2].assign(modify_val) + self.static_model.all_weights[-2].assign(modify_val) tl.files.load_hdf5_to_weights("./model_basic.h5", self.static_model) - self.assertLess(np.max(np.abs(ori_val - self.static_model.weights[-2].numpy())), 1e-7) + self.assertLess(np.max(np.abs(ori_val - self.static_model.all_weights[-2].numpy())), 1e-7) - ori_weights = self.static_model._weights - self.static_model._weights = self.static_model._weights[1:] - self.static_model.weights[-2].assign(modify_val) + ori_weights = self.static_model._all_weights + self.static_model._all_weights = self.static_model._all_weights[1:] + self.static_model.all_weights[-2].assign(modify_val) tl.files.load_hdf5_to_weights("./model_basic.h5", self.static_model, skip=True) - self.assertLess(np.max(np.abs(ori_val - self.static_model.weights[-2].numpy())), 1e-7) - self.static_model._weights = ori_weights + self.assertLess(np.max(np.abs(ori_val - self.static_model.all_weights[-2].numpy())), 1e-7) + self.static_model._all_weights = ori_weights def test_npz(self): - modify_val = np.zeros_like(self.dynamic_model.weights[-2].numpy()) - ori_val = self.dynamic_model.weights[-2].numpy() - 
tl.files.save_npz(self.dynamic_model.weights, "./model_basic.npz") + modify_val = np.zeros_like(self.dynamic_model.all_weights[-2].numpy()) + ori_val = self.dynamic_model.all_weights[-2].numpy() + tl.files.save_npz(self.dynamic_model.all_weights, "./model_basic.npz") - self.dynamic_model.weights[-2].assign(modify_val) + self.dynamic_model.all_weights[-2].assign(modify_val) tl.files.load_and_assign_npz("./model_basic.npz", self.dynamic_model) - self.assertLess(np.max(np.abs(ori_val - self.dynamic_model.weights[-2].numpy())), 1e-7) + self.assertLess(np.max(np.abs(ori_val - self.dynamic_model.all_weights[-2].numpy())), 1e-7) def test_npz_dict(self): - modify_val = np.zeros_like(self.dynamic_model.weights[-2].numpy()) - ori_val = self.dynamic_model.weights[-2].numpy() - tl.files.save_npz_dict(self.dynamic_model.weights, "./model_basic.npz") + modify_val = np.zeros_like(self.dynamic_model.all_weights[-2].numpy()) + ori_val = self.dynamic_model.all_weights[-2].numpy() + tl.files.save_npz_dict(self.dynamic_model.all_weights, "./model_basic.npz") - self.dynamic_model.weights[-2].assign(modify_val) + self.dynamic_model.all_weights[-2].assign(modify_val) tl.files.load_and_assign_npz_dict("./model_basic.npz", self.dynamic_model) - self.assertLess(np.max(np.abs(ori_val - self.dynamic_model.weights[-2].numpy())), 1e-7) + self.assertLess(np.max(np.abs(ori_val - self.dynamic_model.all_weights[-2].numpy())), 1e-7) - ori_weights = self.dynamic_model._weights - self.dynamic_model._weights = self.static_model._weights[1:] - self.dynamic_model.weights[-2].assign(modify_val) + ori_weights = self.dynamic_model._all_weights + self.dynamic_model._all_weights = self.static_model._all_weights[1:] + self.dynamic_model.all_weights[-2].assign(modify_val) tl.files.load_and_assign_npz_dict("./model_basic.npz", self.dynamic_model, skip=True) - self.assertLess(np.max(np.abs(ori_val - self.dynamic_model.weights[-2].numpy())), 1e-7) - self.dynamic_model._weights = ori_weights + self.assertLess(np.max(np.abs(ori_val - self.dynamic_model.all_weights[-2].numpy())), 1e-7) + self.dynamic_model._all_weights = ori_weights if __name__ == '__main__': diff --git a/tests/layers/test_layers_convolution.py b/tests/layers/test_layers_convolution.py index 61beb733c..0f5979d5b 100644 --- a/tests/layers/test_layers_convolution.py +++ b/tests/layers/test_layers_convolution.py @@ -51,7 +51,7 @@ def test_layer_n1(self): # self.assertEqual(len(self.n1.all_layers), 2) # self.assertEqual(len(self.n1.all_params), 2) # self.assertEqual(self.n1.count_params(), 192) - self.assertEqual(len(self.n1._info[0].layer.weights), 2) + self.assertEqual(len(self.n1._info[0].layer.all_weights), 2) self.assertEqual(self.n1.get_shape().as_list()[1:], [50, 32]) def test_layer_n2(self): @@ -59,7 +59,7 @@ def test_layer_n2(self): # self.assertEqual(len(self.n2.all_layers), 3) # self.assertEqual(len(self.n2.all_params), 4) # self.assertEqual(self.n2.count_params(), 5344) - self.assertEqual(len(self.n2._info[0].layer.weights), 2) + self.assertEqual(len(self.n2._info[0].layer.all_weights), 2) self.assertEqual(self.n2.get_shape().as_list()[1:], [25, 32]) def test_layer_n3(self): @@ -67,7 +67,7 @@ def test_layer_n3(self): # self.assertEqual(len(self.n2.all_layers), 3) # self.assertEqual(len(self.n2.all_params), 4) # self.assertEqual(self.n2.count_params(), 5344) - self.assertEqual(len(self.n3._info[0].layer.weights), 2) + self.assertEqual(len(self.n3._info[0].layer.all_weights), 2) self.assertEqual(self.n3.get_shape().as_list()[1:], [50, 64]) def test_layer_n4(self): @@ 
-75,7 +75,7 @@ def test_layer_n4(self): # self.assertEqual(len(self.n2.all_layers), 3) # self.assertEqual(len(self.n2.all_params), 4) # self.assertEqual(self.n2.count_params(), 5344) - self.assertEqual(len(self.n4._info[0].layer.weights), 3) + self.assertEqual(len(self.n4._info[0].layer.all_weights), 3) self.assertEqual(self.n4.get_shape().as_list()[1:], [25, 32]) def test_layer_n5(self): @@ -127,7 +127,7 @@ def test_layer_n5(self): # # self.assertEqual(len(self.n1.all_layers), 2) # # self.assertEqual(len(self.n1.all_params), 2) # # self.assertEqual(self.n1.count_params(), 192) -# self.assertEqual(len(self.n1._info[0].layer.weights), 2) +# self.assertEqual(len(self.n1._info[0].layer.all_weights), 2) # self.assertEqual(self.n1.get_shape().as_list()[1:], [50, 32]) # # def test_layer_n2(self): @@ -135,7 +135,7 @@ def test_layer_n5(self): # # self.assertEqual(len(self.n2.all_layers), 3) # # self.assertEqual(len(self.n2.all_params), 4) # # self.assertEqual(self.n2.count_params(), 5344) -# self.assertEqual(len(self.n2._info[0].layer.weights), 2) +# self.assertEqual(len(self.n2._info[0].layer.all_weights), 2) # self.assertEqual(self.n2.get_shape().as_list()[1:], [25, 32]) # # # def test_layer_n3(self): @@ -223,7 +223,7 @@ def test_layer_n1(self): # self.assertEqual(len(self.n1.all_layers), 2) # self.assertEqual(len(self.n1.all_params), 2) # self.assertEqual(self.n1.count_params(), 2432) - self.assertEqual(len(self.n1._info[0].layer.weights), 2) + self.assertEqual(len(self.n1._info[0].layer.all_weights), 2) self.assertEqual(self.n1.get_shape().as_list()[1:], [200, 200, 32]) def test_layer_n2(self): @@ -231,7 +231,7 @@ def test_layer_n2(self): # self.assertEqual(len(self.n2.all_layers), 3) # self.assertEqual(len(self.n2.all_params), 4) # self.assertEqual(self.n2.count_params(), 11680) - self.assertEqual(len(self.n2._info[0].layer.weights), 2) + self.assertEqual(len(self.n2._info[0].layer.all_weights), 2) self.assertEqual(self.n2.get_shape().as_list()[1:], [100, 100, 32]) def test_layer_n3(self): @@ -239,7 +239,7 @@ def test_layer_n3(self): # self.assertEqual(len(self.n3.all_layers), 4) # self.assertEqual(len(self.n3.all_params), 5) # self.assertEqual(self.n3.count_params(), 20896) - self.assertEqual(len(self.n3._info[0].layer.weights), 1) # b_init is None + self.assertEqual(len(self.n3._info[0].layer.all_weights), 1) # b_init is None self.assertEqual(self.n3.get_shape().as_list()[1:], [50, 50, 32]) def test_layer_n4(self): @@ -247,7 +247,7 @@ def test_layer_n4(self): # self.assertEqual(len(self.n4.all_layers), 5) # self.assertEqual(len(self.n4.all_params), 7) # self.assertEqual(self.n4.count_params(), 46528) - self.assertEqual(len(self.n4._info[0].layer.weights), 2) + self.assertEqual(len(self.n4._info[0].layer.all_weights), 2) self.assertEqual(self.n4.get_shape().as_list()[1:], [100, 100, 32]) def test_layer_n5(self): @@ -255,7 +255,7 @@ def test_layer_n5(self): # self.assertEqual(len(self.n5.all_layers), 6) # self.assertEqual(len(self.n5.all_params), 9) # self.assertEqual(self.n5.count_params(), 55776) - self.assertEqual(len(self.n5._info[0].layer.weights), 2) + self.assertEqual(len(self.n5._info[0].layer.all_weights), 2) self.assertEqual(self.n5.get_shape().as_list()[1:], [200, 200, 32]) def test_layer_n6(self): @@ -263,7 +263,7 @@ def test_layer_n6(self): # self.assertEqual(len(self.n6.all_layers), 7) # self.assertEqual(len(self.n6.all_params), 11) # self.assertEqual(self.n6.count_params(), 56416) - self.assertEqual(len(self.n6._info[0].layer.weights), 2) + 
self.assertEqual(len(self.n6._info[0].layer.all_weights), 2) self.assertEqual(self.n6.get_shape().as_list()[1:], [200, 200, 64]) def test_layer_n7(self): @@ -271,7 +271,7 @@ def test_layer_n7(self): # self.assertEqual(len(self.n7.all_layers), 8) # self.assertEqual(len(self.n7.all_params), 13) # self.assertEqual(self.n7.count_params(), 74880) - self.assertEqual(len(self.n7._info[0].layer.weights), 2) + self.assertEqual(len(self.n7._info[0].layer.all_weights), 2) self.assertEqual(self.n7.get_shape().as_list()[1:], [100, 100, 32]) def test_layer_n8(self): @@ -279,7 +279,7 @@ def test_layer_n8(self): # self.assertEqual(len(self.n7.all_layers), 8) # self.assertEqual(len(self.n7.all_params), 13) # self.assertEqual(self.n7.count_params(), 74880) - self.assertEqual(len(self.n8._info[0].layer.weights), 2) + self.assertEqual(len(self.n8._info[0].layer.all_weights), 2) self.assertEqual(self.n8.get_shape().as_list()[1:], [50, 50, 64]) def test_layer_n9(self): @@ -287,35 +287,35 @@ def test_layer_n9(self): # self.assertEqual(len(self.n7.all_layers), 8) # self.assertEqual(len(self.n7.all_params), 13) # self.assertEqual(self.n7.count_params(), 74880) - self.assertEqual(len(self.n9._info[0].layer.weights), 3) + self.assertEqual(len(self.n9._info[0].layer.all_weights), 3) self.assertEqual(self.n9.get_shape().as_list()[1:], [24, 24, 32]) def test_layer_n10(self): # self.assertEqual(len(self.n7.all_layers), 8) # self.assertEqual(len(self.n7.all_params), 13) # self.assertEqual(self.n7.count_params(), 74880) - self.assertEqual(len(self.n10._info[0].layer.weights), 2) + self.assertEqual(len(self.n10._info[0].layer.all_weights), 2) self.assertEqual(self.n10.get_shape().as_list()[1:], [12, 12, 64]) def test_layer_n11(self): # self.assertEqual(len(self.n7.all_layers), 8) # self.assertEqual(len(self.n7.all_params), 13) # self.assertEqual(self.n7.count_params(), 74880) - self.assertEqual(len(self.n11._info[0].layer.weights), 2) + self.assertEqual(len(self.n11._info[0].layer.all_weights), 2) self.assertEqual(self.n11.get_shape().as_list()[1:], [12, 12, 32]) def test_layer_n12(self): # self.assertEqual(len(self.n7.all_layers), 8) # self.assertEqual(len(self.n7.all_params), 13) # self.assertEqual(self.n7.count_params(), 74880) - self.assertEqual(len(self.n12._info[0].layer.weights), 2) + self.assertEqual(len(self.n12._info[0].layer.all_weights), 2) self.assertEqual(self.n12.get_shape().as_list()[1:], [12, 12, 64]) def test_layer_n13(self): # self.assertEqual(len(self.n7.all_layers), 8) # self.assertEqual(len(self.n7.all_params), 13) # self.assertEqual(self.n7.count_params(), 74880) - self.assertEqual(len(self.n13._info[0].layer.weights), 2) + self.assertEqual(len(self.n13._info[0].layer.all_weights), 2) self.assertEqual(self.n13.get_shape().as_list()[1:], [12, 12, 32]) def test_layer_n14(self): @@ -393,7 +393,7 @@ def test_layer_n1(self): # self.assertEqual(len(self.n1.all_layers), 2) # self.assertEqual(len(self.n1.all_params), 2) # self.assertEqual(self.n1.count_params(), 800) - self.assertEqual(len(self.n1._info[0].layer.weights), 2) + self.assertEqual(len(self.n1._info[0].layer.all_weights), 2) self.assertEqual(self.n1.get_shape().as_list()[1:], [10, 10, 10, 32]) def test_layer_n2(self): @@ -401,7 +401,7 @@ def test_layer_n2(self): # self.assertEqual(len(self.n2.all_layers), 3) # self.assertEqual(len(self.n2.all_params), 4) # self.assertEqual(self.n2.count_params(), 33696) - self.assertEqual(len(self.n2._info[0].layer.weights), 2) + self.assertEqual(len(self.n2._info[0].layer.all_weights), 2) 
self.assertEqual(self.n2.get_shape().as_list()[1:], [20, 20, 20, 128]) def test_layer_n3(self): @@ -409,7 +409,7 @@ def test_layer_n3(self): # self.assertEqual(len(self.n3.all_layers), 4) # self.assertEqual(len(self.n3.all_params), 6) # self.assertEqual(self.n3.count_params(), 144320) - self.assertEqual(len(self.n3._info[0].layer.weights), 1) # b_init is None + self.assertEqual(len(self.n3._info[0].layer.all_weights), 1) # b_init is None self.assertEqual(self.n3.get_shape().as_list()[1:], [7, 7, 7, 64]) def test_layer_n4(self): @@ -417,7 +417,7 @@ def test_layer_n4(self): # self.assertEqual(len(self.n3.all_layers), 4) # self.assertEqual(len(self.n3.all_params), 6) # self.assertEqual(self.n3.count_params(), 144320) - self.assertEqual(len(self.n4._info[0].layer.weights), 2) + self.assertEqual(len(self.n4._info[0].layer.all_weights), 2) self.assertEqual(self.n4.get_shape().as_list()[1:], [14, 14, 14, 32]) diff --git a/tests/layers/test_layers_core_basedense_dropout.py b/tests/layers/test_layers_core_basedense_dropout.py index 2ec0019f3..19178f5d6 100644 --- a/tests/layers/test_layers_core_basedense_dropout.py +++ b/tests/layers/test_layers_core_basedense_dropout.py @@ -73,17 +73,17 @@ def test_net1(self): def test_net2(self): # test weights - self.assertEqual(self.innet._info[0].layer.weights, None) - self.assertEqual(self.dropout1._info[0].layer.weights, None) - self.assertEqual(self.dense1._info[0].layer.weights[0].get_shape().as_list(), [784, 800]) - self.assertEqual(self.dense1._info[0].layer.weights[1].get_shape().as_list(), [ + self.assertEqual(self.innet._info[0].layer.all_weights, []) + self.assertEqual(self.dropout1._info[0].layer.all_weights, []) + self.assertEqual(self.dense1._info[0].layer.all_weights[0].get_shape().as_list(), [784, 800]) + self.assertEqual(self.dense1._info[0].layer.all_weights[1].get_shape().as_list(), [ 800, ]) - self.assertEqual(self.dense2._info[0].layer.weights[0].get_shape().as_list(), [800, 10]) - self.assertEqual(len(self.dense1._info[0].layer.weights), 2) - self.assertEqual(len(self.dense2._info[0].layer.weights), 1) + self.assertEqual(self.dense2._info[0].layer.all_weights[0].get_shape().as_list(), [800, 10]) + self.assertEqual(len(self.dense1._info[0].layer.all_weights), 2) + self.assertEqual(len(self.dense2._info[0].layer.all_weights), 1) - self.assertEqual(len(self.model.weights), 3) + self.assertEqual(len(self.model.all_weights), 3) # a special case self.model.release_memory() @@ -134,7 +134,7 @@ def test_layerlist(self): )(innet) model = Model(inputs=innet, outputs=hlayer) - # for w in model.weights: + # for w in model.all_weights: # print(w.name) data = np.random.normal(size=[self.batch_size, self.inputs_shape[1]]).astype(np.float32) diff --git a/tests/layers/test_layers_deformable_convolution.py b/tests/layers/test_layers_deformable_convolution.py index 41b0ad533..b31d5ce98 100644 --- a/tests/layers/test_layers_deformable_convolution.py +++ b/tests/layers/test_layers_deformable_convolution.py @@ -44,12 +44,12 @@ def tearDownClass(cls): def test_layer_n1(self): - self.assertEqual(len(self.deformconv1._info[0].layer.weights), 2) + self.assertEqual(len(self.deformconv1._info[0].layer.all_weights), 2) self.assertEqual(self.deformconv1.get_shape().as_list()[1:], [10, 10, 32]) def test_layer_n2(self): - self.assertEqual(len(self.deformconv2._info[0].layer.weights), 2) + self.assertEqual(len(self.deformconv2._info[0].layer.all_weights), 2) self.assertEqual(self.deformconv2.get_shape().as_list()[1:], [10, 10, 64]) diff --git 
a/tests/layers/test_layers_lambda.py b/tests/layers/test_layers_lambda.py
index 163339ca5..e7c0bc713 100644
--- a/tests/layers/test_layers_lambda.py
+++ b/tests/layers/test_layers_lambda.py
@@ -58,8 +58,8 @@ def forward(self, x):
                 pred_y = model(self.data_x)
                 loss = tl.cost.mean_squared_error(pred_y, self.data_y)

-            gradients = tape.gradient(loss, model.weights)
-            optimizer.apply_gradients(zip(gradients, model.weights))
+            gradients = tape.gradient(loss, model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, model.trainable_weights))

             print("epoch %d, loss %f" % (epoch, loss))
diff --git a/tests/layers/test_layers_recurrent.py b/tests/layers/test_layers_recurrent.py
index fee1acb27..2e4dbab39 100644
--- a/tests/layers/test_layers_recurrent.py
+++ b/tests/layers/test_layers_recurrent.py
@@ -62,8 +62,8 @@ def test_basic_simplernn(self):
                 pred_y, final_state = rnn_model(self.data_x)
                 loss = tl.cost.mean_squared_error(pred_y, self.data_y)

-            gradients = tape.gradient(loss, rnn_model.weights)
-            optimizer.apply_gradients(zip(gradients, rnn_model.weights))
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
@@ -131,8 +131,8 @@ def forward(self, x):
                 pred_y = rnn_model(self.data_x)
                 loss = tl.cost.mean_squared_error(pred_y, self.data_y)

-            gradients = tape.gradient(loss, rnn_model.weights)
-            optimizer.apply_gradients(zip(gradients, rnn_model.weights))
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
@@ -165,8 +165,8 @@ def forward(self, x):
                 pred_y = rnn_model(self.data_x)
                 loss = tl.cost.mean_squared_error(pred_y, self.data_y)

-            gradients = tape.gradient(loss, rnn_model.weights)
-            optimizer.apply_gradients(zip(gradients, rnn_model.weights))
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
@@ -205,8 +205,8 @@ def forward(self, x):
                 pred_y = rnn_model(self.data_x)
                 loss = tl.cost.mean_squared_error(pred_y, self.data_y)

-            gradients = tape.gradient(loss, rnn_model.weights)
-            optimizer.apply_gradients(zip(gradients, rnn_model.weights))
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
@@ -232,8 +232,8 @@ def test_basic_lstmrnn(self):
                 pred_y, final_h, final_c = rnn_model(self.data_x)
                 loss = tl.cost.mean_squared_error(pred_y, self.data_y)

-            gradients = tape.gradient(loss, rnn_model.weights)
-            optimizer.apply_gradients(zip(gradients, rnn_model.weights))
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
@@ -259,8 +259,8 @@ def test_basic_grurnn(self):
                 pred_y, final_h = rnn_model(self.data_x)
                 loss = tl.cost.mean_squared_error(pred_y, self.data_y)

-            gradients = tape.gradient(loss, rnn_model.weights)
-            optimizer.apply_gradients(zip(gradients, rnn_model.weights))
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
@@ -292,8 +292,8 @@ def test_basic_birnn_simplernncell(self):
                 self.assertEqual(
                     r.get_shape().as_list(), [self.batch_size * self.num_steps, self.hidden_size + self.hidden_size + 1]
                 )

-            gradients = tape.gradient(loss, rnn_model.weights)
-            optimizer.apply_gradients(zip(gradients, rnn_model.weights))
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
@@ -326,8 +326,8 @@ def test_basic_birnn_lstmcell(self):
                 self.assertEqual(
                     r.get_shape().as_list(), [self.batch_size, self.num_steps, self.hidden_size + self.hidden_size + 1]
                 )

-            gradients = tape.gradient(loss, rnn_model.weights)
-            optimizer.apply_gradients(zip(gradients, rnn_model.weights))
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
@@ -362,8 +362,8 @@ def forward(self, x):
                 pred_y = rnn_model(self.data_x)
                 loss = tl.cost.mean_squared_error(pred_y, self.data_y)

-            gradients = tape.gradient(loss, rnn_model.weights)
-            optimizer.apply_gradients(zip(gradients, rnn_model.weights))
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
@@ -396,8 +396,8 @@ def test_stack_simplernn(self):
                 pred_y = rnn_model(self.data_x)
                 loss = tl.cost.mean_squared_error(pred_y, self.data_y)

-            gradients = tape.gradient(loss, rnn_model.weights)
-            optimizer.apply_gradients(zip(gradients, rnn_model.weights))
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
@@ -433,8 +433,8 @@ def test_stack_birnn_simplernncell(self):
                 pred_y = rnn_model(self.data_x)
                 loss = tl.cost.mean_squared_error(pred_y, self.data_y2)

-            gradients = tape.gradient(loss, rnn_model.weights)
-            optimizer.apply_gradients(zip(gradients, rnn_model.weights))
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))

             if (epoch + 1) % 10 == 0:
                 print("epoch %d, loss %f" % (epoch, loss))
diff --git a/tests/models/test_auto_naming.py b/tests/models/test_auto_naming.py
index 143dba963..fb8f03720 100644
--- a/tests/models/test_auto_naming.py
+++ b/tests/models/test_auto_naming.py
@@ -231,7 +231,7 @@ def test_layerlist(self):
                 tl.layers.Dense(n_units=3, name='dense1')]
             )(inputs)
             model = tl.models.Model(inputs=inputs, outputs=layer1, name='layerlistmodel')
-            print([w.name for w in model.weights])
+            print([w.name for w in model.all_weights])
             test_flag = False
         except Exception as e:
             print(e)
@@ -258,7 +258,7 @@ def forward(self, x):
             model_layer = tl.layers.ModelLayer(inner_model())(inputs)
             model = tl.models.Model(inputs=inputs, outputs=model_layer, name='modellayermodel')
             print(model)
-            print([w.name for w in model.weights])
+            print([w.name for w in model.all_weights])
             test_flag = False
         except Exception as e:
             print(e)
@@ -273,7 +273,7 @@ def test_layerlist(self):
                 tl.layers.Dense(n_units=3, name='dense1')]
             )(inputs)
             model = tl.models.Model(inputs=inputs, outputs=layer1, name='layerlistmodel')
-            print([w.name for w in model.weights])
+            print([w.name for w in model.all_weights])
             self.fail("Fail to detect duplicate name in layerlist")
         except Exception as e:
             print(e)
diff --git a/tests/models/test_model_core.py b/tests/models/test_model_core.py
index 1c3732819..caf3044b2 100644
--- a/tests/models/test_model_core.py
+++ b/tests/models/test_model_core.py
@@ -70,7 +70,7 @@ def test_dynamic_basic(self):
         # test empty model before calling
         self.assertEqual(model_basic.is_train, None)
-        self.assertEqual(model_basic._weights, None)
+        self.assertEqual(model_basic._all_weights, None)
         self.assertEqual(model_basic._inputs, None)
         self.assertEqual(model_basic._outputs, None)
         self.assertEqual(model_basic._model_layer, None)
@@ -80,10 +80,10 @@ def test_dynamic_basic(self):
         # test layer and weights access
         all_layers = model_basic.all_layers
         self.assertEqual(len(model_basic.all_layers), 7)
-        self.assertEqual(model_basic._weights, None)
+        self.assertEqual(model_basic._all_weights, None)

-        self.assertIsNotNone(model_basic.weights)
-        print([w.name for w in model_basic.weights])
+        self.assertIsNotNone(model_basic.all_weights)
+        print([w.name for w in model_basic.all_weights])

         # test model mode
         model_basic.train()
@@ -139,7 +139,7 @@ def test_static_basic(self):
         # test empty model before calling
         self.assertEqual(model_basic.is_train, None)
-        self.assertEqual(model_basic._weights, None)
+        self.assertEqual(model_basic._all_weights, None)
         self.assertIsNotNone(model_basic._inputs)
         self.assertIsNotNone(model_basic._outputs)
         self.assertEqual(model_basic._model_layer, None)
@@ -149,10 +149,10 @@ def test_static_basic(self):
         # test layer and weights access
         all_layers = model_basic.all_layers
         self.assertEqual(len(model_basic.all_layers), 8)
-        self.assertEqual(model_basic._weights, None)
+        self.assertEqual(model_basic._all_weights, None)

-        self.assertIsNotNone(model_basic.weights)
-        print([w.name for w in model_basic.weights])
+        self.assertIsNotNone(model_basic.all_weights)
+        print([w.name for w in model_basic.all_weights])

         # test model mode
         model_basic.train()
@@ -283,7 +283,7 @@ def forward(self, x):
                     return x

             model = ill_model()
-            weights = model.weights
+            weights = model.all_weights
         except Exception as e:
             self.assertIsInstance(e, AttributeError)
             print(e)
@@ -360,7 +360,7 @@ def forward(self, x):
                 return x

         model = my_model()
-        weights = model.weights
+        weights = model.all_weights
         self.assertGreater(len(weights), 2)
         print(len(weights))
@@ -381,7 +381,7 @@ def test_get_layer(self):
         model_basic = basic_static_model()
         self.assertIsInstance(model_basic.get_layer('conv2'), tl.layers.Conv2d)
         self.assertIsInstance(model_basic.get_layer(index=2), tl.layers.MaxPool2d)
-        print([w.name for w in model_basic.get_layer(index=-1).weights])
+        print([w.name for w in model_basic.get_layer(index=-1).all_weights])
         try:
             model_basic.get_layer('abc')
         except Exception as e:
diff --git a/tests/models/test_model_save.py b/tests/models/test_model_save.py
index 1029a5622..ba224ee25 100644
--- a/tests/models/test_model_save.py
+++ b/tests/models/test_model_save.py
@@ -92,38 +92,38 @@ def normal_save(self, model_basic):
         # hdf5
         print('testing hdf5 saving...')
-        modify_val = np.zeros_like(model_basic.weights[-2].numpy())
-        ori_val = model_basic.weights[-2].numpy()
+        modify_val = np.zeros_like(model_basic.all_weights[-2].numpy())
+        ori_val = model_basic.all_weights[-2].numpy()
         model_basic.save_weights("./model_basic.h5")
-        model_basic.weights[-2].assign(modify_val)
+        model_basic.all_weights[-2].assign(modify_val)
         model_basic.load_weights("./model_basic.h5")
-        self.assertLess(np.max(np.abs(ori_val - model_basic.weights[-2].numpy())), 1e-7)
+        self.assertLess(np.max(np.abs(ori_val - model_basic.all_weights[-2].numpy())), 1e-7)

-        model_basic.weights[-2].assign(modify_val)
+        model_basic.all_weights[-2].assign(modify_val)
         model_basic.load_weights("./model_basic.h5", format="hdf5")
-        self.assertLess(np.max(np.abs(ori_val - model_basic.weights[-2].numpy())), 1e-7)
+        self.assertLess(np.max(np.abs(ori_val - model_basic.all_weights[-2].numpy())), 1e-7)

-        model_basic.weights[-2].assign(modify_val)
+        model_basic.all_weights[-2].assign(modify_val)
         model_basic.load_weights("./model_basic.h5", format="hdf5", in_order=False)
-        self.assertLess(np.max(np.abs(ori_val - model_basic.weights[-2].numpy())), 1e-7)
+        self.assertLess(np.max(np.abs(ori_val - model_basic.all_weights[-2].numpy())), 1e-7)

         # npz
         print('testing npz saving...')
         model_basic.save_weights("./model_basic.npz", format='npz')
-        model_basic.weights[-2].assign(modify_val)
+        model_basic.all_weights[-2].assign(modify_val)
         model_basic.load_weights("./model_basic.npz")

-        model_basic.weights[-2].assign(modify_val)
+        model_basic.all_weights[-2].assign(modify_val)
         model_basic.load_weights("./model_basic.npz", format='npz')
         model_basic.save_weights("./model_basic.npz")
-        self.assertLess(np.max(np.abs(ori_val - model_basic.weights[-2].numpy())), 1e-7)
+        self.assertLess(np.max(np.abs(ori_val - model_basic.all_weights[-2].numpy())), 1e-7)

         # npz_dict
         print('testing npz_dict saving...')
         model_basic.save_weights("./model_basic.npz", format='npz_dict')
-        model_basic.weights[-2].assign(modify_val)
+        model_basic.all_weights[-2].assign(modify_val)
         model_basic.load_weights("./model_basic.npz", format='npz_dict')
-        self.assertLess(np.max(np.abs(ori_val - model_basic.weights[-2].numpy())), 1e-7)
+        self.assertLess(np.max(np.abs(ori_val - model_basic.all_weights[-2].numpy())), 1e-7)

         # ckpt
         try:
@@ -163,12 +163,12 @@ def test_skip(self):
         print("testing dynamic skip load...")
         self.dynamic_basic.save_weights("./model_basic.h5")
-        ori_weights = self.dynamic_basic_skip.weights
+        ori_weights = self.dynamic_basic_skip.all_weights
         ori_val = ori_weights[1].numpy()
         modify_val = np.zeros_like(ori_val)
-        self.dynamic_basic_skip.weights[1].assign(modify_val)
+        self.dynamic_basic_skip.all_weights[1].assign(modify_val)
         self.dynamic_basic_skip.load_weights("./model_basic.h5", skip=True)
-        self.assertLess(np.max(np.abs(ori_val - self.dynamic_basic_skip.weights[1].numpy())), 1e-7)
+        self.assertLess(np.max(np.abs(ori_val - self.dynamic_basic_skip.all_weights[1].numpy())), 1e-7)

         try:
             self.dynamic_basic_skip.load_weights("./model_basic.h5", in_order=False, skip=False)
@@ -177,12 +177,12 @@ def test_skip(self):
         print("testing static skip load...")
         self.static_basic.save_weights("./model_basic.h5")
-        ori_weights = self.static_basic_skip.weights
+        ori_weights = self.static_basic_skip.all_weights
         ori_val = ori_weights[1].numpy()
         modify_val = np.zeros_like(ori_val)
-        self.static_basic_skip.weights[1].assign(modify_val)
+        self.static_basic_skip.all_weights[1].assign(modify_val)
         self.static_basic_skip.load_weights("./model_basic.h5", skip=True)
-        self.assertLess(np.max(np.abs(ori_val - self.static_basic_skip.weights[1].numpy())), 1e-7)
+        self.assertLess(np.max(np.abs(ori_val - self.static_basic_skip.all_weights[1].numpy())), 1e-7)

         try:
             self.static_basic_skip.load_weights("./model_basic.h5", in_order=False, skip=False)
@@ -196,13 +196,13 @@ def test_nested_vgg(self):
         nested_vgg.save_weights("nested_vgg.h5")
         # modify vgg1 weight val
-        tar_weight1 = nested_vgg.vgg1.layers[0].weights[0]
+        tar_weight1 = nested_vgg.vgg1.layers[0].all_weights[0]
         print(tar_weight1.name)
         ori_val1 = tar_weight1.numpy()
         modify_val1 = np.zeros_like(ori_val1)
         tar_weight1.assign(modify_val1)
         # modify vgg2 weight val
-        tar_weight2 = nested_vgg.vgg2.layers[1].weights[0]
+        tar_weight2 = nested_vgg.vgg2.layers[1].all_weights[0]
         print(tar_weight2.name)
         ori_val2 = tar_weight2.numpy()
         modify_val2 = np.zeros_like(ori_val2)
@@ -236,12 +236,12 @@ def forward(self, *inputs, **kwargs):
         print([x.name for x in net.all_layers])
         # modify vgg1 weight val
-        tar_weight1 = net.inner.vgg1.layers[0].weights[0]
+        tar_weight1 = net.inner.vgg1.layers[0].all_weights[0]
         ori_val1 = tar_weight1.numpy()
         modify_val1 = np.zeros_like(ori_val1)
         tar_weight1.assign(modify_val1)
         # modify vgg2 weight val
-        tar_weight2 = net.inner.vgg2.layers[1].weights[0]
+        tar_weight2 = net.inner.vgg2.layers[1].all_weights[0]
         ori_val2 = tar_weight2.numpy()
         modify_val2 = np.zeros_like(ori_val2)
         tar_weight2.assign(modify_val2)
@@ -264,7 +264,7 @@ def test_layerlist(self):
         model = tl.models.Model(inputs=inputs, outputs=layer1, name='layerlistmodel')
         model.save_weights("layerlist.h5")
-        tar_weight = model.get_layer(index=-1)[0].weights[0]
+        tar_weight = model.get_layer(index=-1)[0].all_weights[0]
         print(tar_weight.name)
         ori_val = tar_weight.numpy()
         modify_val = np.zeros_like(ori_val)
diff --git a/tests/models/test_model_save_graph.py b/tests/models/test_model_save_graph.py
index adbed8c83..95229938b 100644
--- a/tests/models/test_model_save_graph.py
+++ b/tests/models/test_model_save_graph.py
@@ -73,7 +73,7 @@ def test_save(self):
         print(MLP)
         n_epoch = 3
         batch_size = 500
-        train_weights = MLP.weights
+        train_weights = MLP.trainable_weights
         optimizer = tf.optimizers.Adam(lr=0.0001)

         for epoch in range(n_epoch):  ## iterate the dataset n_epoch times
@@ -118,7 +118,7 @@ def test_save(self):
         n_epoch = 3
         batch_size = 500
-        train_weights = MLP.weights
+        train_weights = MLP.trainable_weights
         optimizer = tf.optimizers.Adam(lr=0.0001)
         X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 784))
         val_loss, val_acc, n_iter = 0, 0, 0
@@ -302,13 +302,13 @@ def test_lambda_layer_keras_model(self):
         self.assertEqual((output2 == output4).all(), True)
         self.assertEqual(M2.config, M4.config)

-        ori_weights = M4.weights
+        ori_weights = M4.all_weights
         ori_val = ori_weights[1].numpy()
         modify_val = np.zeros_like(ori_val) + 10
-        M4.weights[1].assign(modify_val)
+        M4.all_weights[1].assign(modify_val)
         M4 = Model.load('M2_keras.hdf5')
-        self.assertLess(np.max(np.abs(ori_val - M4.weights[1].numpy())), 1e-7)
+        self.assertLess(np.max(np.abs(ori_val - M4.all_weights[1].numpy())), 1e-7)

     def test_lambda_layer_keras_layer(self):
         input_shape = [100, 5]
@@ -331,13 +331,13 @@ def test_lambda_layer_keras_layer(self):
         self.assertEqual((output1 == output3).all(), True)
         self.assertEqual(M1.config, M3.config)

-        ori_weights = M3.weights
+        ori_weights = M3.all_weights
         ori_val = ori_weights[1].numpy()
         modify_val = np.zeros_like(ori_val) + 10
-        M3.weights[1].assign(modify_val)
+        M3.all_weights[1].assign(modify_val)
         M3 = Model.load('M1_keras.hdf5')
-        self.assertLess(np.max(np.abs(ori_val - M3.weights[1].numpy())), 1e-7)
+        self.assertLess(np.max(np.abs(ori_val - M3.all_weights[1].numpy())), 1e-7)


 class ElementWise_lambda_test(CustomTestCase):
diff --git a/tests/pending/test_mnist_simple.py b/tests/pending/test_mnist_simple.py
index 79d3dc8dd..5fe68c97b 100644
--- a/tests/pending/test_mnist_simple.py
+++ b/tests/pending/test_mnist_simple.py
@@ -44,7 +44,7 @@ def setUpClass(cls):
         # y_op = tf.argmax(tf.nn.softmax(y), 1)

         # define the optimizer
-        train_params = cls.network.all_params
+        train_params = cls.network.trainable_weights
         cls.train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cls.cost, var_list=train_params)

     @classmethod
diff --git a/tests/performance_test/vgg/tl2-autograph.py b/tests/performance_test/vgg/tl2-autograph.py
index a08688843..bd84bde8c 100644
--- a/tests/performance_test/vgg/tl2-autograph.py
+++ b/tests/performance_test/vgg/tl2-autograph.py
@@ -23,7 +23,7 @@
 # training setting
 num_iter = NUM_ITERS
 batch_size = BATCH_SIZE
-train_weights = vgg.weights
+train_weights = vgg.trainable_weights
 optimizer = tf.optimizers.Adam(learning_rate=LERANING_RATE)
 loss_object = tl.cost.cross_entropy
diff --git a/tests/performance_test/vgg/tl2-eager.py b/tests/performance_test/vgg/tl2-eager.py
index f44a54ae2..401297a8e 100644
--- a/tests/performance_test/vgg/tl2-eager.py
+++ b/tests/performance_test/vgg/tl2-eager.py
@@ -23,7 +23,7 @@
 # training setting
 num_iter = NUM_ITERS
 batch_size = BATCH_SIZE
-train_weights = vgg.weights
+train_weights = vgg.trainable_weights
 optimizer = tf.optimizers.Adam(learning_rate=LERANING_RATE)
 loss_object = tl.cost.cross_entropy
diff --git a/tests/test_initializers.py b/tests/test_initializers.py
index 7998c98ac..df86fd834 100644
--- a/tests/test_initializers.py
+++ b/tests/test_initializers.py
@@ -30,38 +30,38 @@ def init_dense(self, w_init):

     def test_zeros(self):
         dense = self.init_dense(tl.initializers.zeros())
-        self.assertEqual(np.sum(dense.weights[0].numpy() - np.zeros(shape=self.w_shape)), self.eps)
+        self.assertEqual(np.sum(dense.all_weights[0].numpy() - np.zeros(shape=self.w_shape)), self.eps)
         nn = dense(self.ni)

     def test_ones(self):
         dense = self.init_dense(tl.initializers.ones())
-        self.assertEqual(np.sum(dense.weights[0].numpy() - np.ones(shape=self.w_shape)), self.eps)
+        self.assertEqual(np.sum(dense.all_weights[0].numpy() - np.ones(shape=self.w_shape)), self.eps)
         nn = dense(self.ni)

     def test_constant(self):
         dense = self.init_dense(tl.initializers.constant(value=5.0))
-        self.assertEqual(np.sum(dense.weights[0].numpy() - np.ones(shape=self.w_shape) * 5.0), self.eps)
+        self.assertEqual(np.sum(dense.all_weights[0].numpy() - np.ones(shape=self.w_shape) * 5.0), self.eps)
         nn = dense(self.ni)

         # test with numpy arr
         arr = np.random.uniform(size=self.w_shape).astype(np.float32)
         dense = self.init_dense(tl.initializers.constant(value=arr))
-        self.assertEqual(np.sum(dense.weights[0].numpy() - arr), self.eps)
+        self.assertEqual(np.sum(dense.all_weights[0].numpy() - arr), self.eps)
         nn = dense(self.ni)

     def test_RandomUniform(self):
         dense = self.init_dense(tl.initializers.random_uniform(minval=-0.1, maxval=0.1, seed=1234))
-        print(dense.weights[0].numpy())
+        print(dense.all_weights[0].numpy())
         nn = dense(self.ni)

     def test_RandomNormal(self):
         dense = self.init_dense(tl.initializers.random_normal(mean=0.0, stddev=0.1))
-        print(dense.weights[0].numpy())
+        print(dense.all_weights[0].numpy())
         nn = dense(self.ni)

     def test_TruncatedNormal(self):
         dense = self.init_dense(tl.initializers.truncated_normal(mean=0.0, stddev=0.1))
-        print(dense.weights[0].numpy())
+        print(dense.all_weights[0].numpy())
         nn = dense(self.ni)

     def test_deconv2d_bilinear_upsampling_initializer(self):