Skip to content

Commit 35df825

Browse files
authored
Merge pull request #1026 from warshallrho/master
[fix bug] copy original model's trainable_weights and nontrainable_weights when initializing ModelLayer
2 parents ef75b07 + a8c352f commit 35df825

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,10 @@ To release a new version, please update the changelog as followed:
106106

107107
### Fixed
108108
- Fix `tf.models.Model._construct_graph` for list of outputs, e.g. STN case (PR #1010)
109-
- Enable better `in_channels` exception raise. (pR #1015)
109+
- Enable better `in_channels` exception raise. (PR #1015)
110110
- Set allow_pickle=True in np.load() (#PR 1021)
111111
- Remove `private_method` decorator (#PR 1025)
112+
- Copy original model's `trainable_weights` and `nontrainable_weights` when initializing `ModelLayer` (#PR 1026)
112113

113114
### Removed
114115

@@ -118,7 +119,7 @@ To release a new version, please update the changelog as followed:
118119

119120
- @zsdonghao
120121
- @ChrisWu1997: #1010 #1015 #1025
121-
- @warshallrho: #1017 #1021
122+
- @warshallrho: #1017 #1021 #1026
122123
- @ArnoldLIULJ: #1023
123124
- @JingqingZ: #1023
124125

tensorlayer/layers/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,8 @@ def __init__(self, model, name=None):
512512

513513
# Layer weight state
514514
self._all_weights = model.all_weights
515+
self._trainable_weights = model.trainable_weights
516+
self._nontrainable_weights = model.nontrainable_weights
515517

516518
# Layer training state
517519
self.is_train = True

0 commit comments

Comments
 (0)