@@ -124,3 +124,61 @@ def _create_slots(self, var_list):
124
124
) # pylint: disable=protected-access
125
125
for var in var_list :
126
126
self .add_slot (var , "average" , var .read_value ())
127
+
128
def shadow_copy(self, model_weights):
    """Create a zero-initialized `average` slot for every model weight.

    For each variable in `model_weights` a slot named `"average"` is added
    with the `"zeros"` initializer. The slots are then collected into
    `self._average_weights`, and the weights themselves are stored in
    `self._model_weights`, so the two lists are index-aligned.

    Args:
        model_weights: Iterable of the model's variables. Any iterable is
            accepted, including single-pass ones such as generators.
    """
    # The weights are walked twice below (slot creation, then slot lookup)
    # and kept afterwards, so snapshot them into a list first. Without this
    # a generator would be exhausted after the first loop, leaving
    # `_average_weights` empty and `_model_weights` unusable.
    model_weights = list(model_weights)
    for var in model_weights:
        self.add_slot(var, "average", initializer="zeros")
    self._average_weights = [
        self.get_slot(var, "average") for var in model_weights
    ]
    self._model_weights = model_weights
134
+
135
@property
def has_shadow_copy(self):
    """Whether this optimizer has created shadow variables.

    Returns:
        `True` once `self._model_weights` holds a value other than `None`,
        `False` otherwise -- including when the attribute has not been
        assigned at all.
    """
    # `getattr` with a default keeps this query safe on an instance where
    # `_model_weights` was never set, instead of raising AttributeError.
    return getattr(self, "_model_weights", None) is not None
139
+
140
def swap_weights(self):
    """Exchange the model weights with the stored average weights.

    Convenience entry point for evaluating the averaged weights at test
    time: the values held in `self._average_weights` are loaded into the
    model while the model's previous values are kept in their place.
    Calling this a second time swaps them back.

    Returns:
        The result of running `self._swap_weights` through the current
        strategy's `run`.

    Raises:
        ValueError: If not called from a cross-replica context.
    """
    # Reject the unsupported context up front so the happy path reads flat.
    if not tf.distribute.in_cross_replica_context():
        raise ValueError(
            "Swapping weights must occur under a tf.distribute.Strategy"
        )

    current_strategy = tf.distribute.get_strategy()
    return current_strategy.run(self._swap_weights, args=())
155
+
156
@tf.function
def _swap_weights(self):
    """Swap every (average, model) weight pair in place.

    Obtains the current replica context and hands the pairwise swap to
    `merge_call`. The swap uses no temporary variable; each pair is
    exchanged with three in-place arithmetic updates applied through
    `strategy.extended.update`.

    Returns:
        Whatever `merge_call` returns for the swap function.
    """
    # NOTE(review): the exchange is arithmetic (add / subtract), so for
    # floating-point variables a double swap may differ from the original
    # values by rounding error -- confirm this is acceptable.

    def add_other(target, other):
        # target <- target + other
        target.assign_add(other)
        return target

    def store_difference(target, other):
        # target <- other - target
        target.assign(other - target)
        return target

    def subtract_other(target, other):
        # target <- target - other
        target.assign_sub(other)
        return target

    def swap(strategy, a, b):
        """Exchange the values of each pair drawn from `a` and `b`."""
        extended = strategy.extended
        for first, second in zip(a, b):
            # The order of these three updates is what makes the swap work.
            extended.update(first, add_other, args=(second,))  # first = first + second
            extended.update(second, store_difference, args=(first,))  # second = first - second
            extended.update(first, subtract_other, args=(second,))  # first = first - second

    replica_ctx = tf.distribute.get_replica_context()
    weight_pairs = (self._average_weights, self._model_weights)
    return replica_ctx.merge_call(swap, args=weight_pairs)
0 commit comments