@@ -27,7 +27,8 @@ class ConditionalGradient(tf.keras.optimizers.Optimizer):

     This optimizer helps handle constraints well.

-    Currently only supports frobenius norm constraint.
+    Currently only supports frobenius norm constraint or nuclear norm
+    constraint.
     See https://arxiv.org/pdf/1803.06453.pdf

     ```
@@ -42,6 +43,13 @@ class ConditionalGradient(tf.keras.optimizers.Optimizer):
         gradient is 0.

     In this implementation, `epsilon` defaults to $10^{-7}$.
+
+    For the nuclear norm constraint, the update formula is as follows:
+
+    ```
+    variable -= (1-learning_rate) * (variable
+            + lambda_ * top_singular_vector(gradient))
+    ```
     """

     @typechecked
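For reference, a minimal usage sketch of the new option (assuming this change ships in `tensorflow_addons` as `tfa.optimizers.ConditionalGradient`; the values below are arbitrary):

```
import tensorflow as tf
import tensorflow_addons as tfa

# ord="nuclear" selects the nuclear-norm update sketched above;
# the default ord="fro" keeps the original Frobenius-norm behavior.
opt = tfa.optimizers.ConditionalGradient(
    learning_rate=0.1, lambda_=0.01, ord="nuclear"
)

var = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(var ** 2)
grads = tape.gradient(loss, [var])
opt.apply_gradients(zip(grads, [var]))
```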
@@ -50,6 +58,7 @@ def __init__(
         learning_rate: Union[FloatTensorLike, Callable],
         lambda_: Union[FloatTensorLike, Callable] = 0.01,
         epsilon: FloatTensorLike = 1e-7,
+        ord: str = "fro",
         use_locking: bool = False,
         name: str = "ConditionalGradient",
         **kwargs
@@ -64,6 +73,8 @@ def __init__(
             epsilon: A `Tensor` or a floating point value. A small constant
                 for numerical stability when handling the case of norm of
                 gradient to be zero.
+            ord: Order of the norm. Supported values are `'fro'`
+                and `'nuclear'`. Default is `'fro'`, which is the Frobenius norm.
             use_locking: If `True`, use locks for update operations.
             name: Optional name prefix for the operations created when
                 applying gradients. Defaults to 'ConditionalGradient'.
@@ -78,13 +89,21 @@ def __init__(
         self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
         self._set_hyper("lambda_", lambda_)
         self.epsilon = epsilon or tf.keras.backend.epsilon()
+        supported_norms = ["fro", "nuclear"]
+        if ord not in supported_norms:
+            raise ValueError(
+                "'ord' must be a supported matrix norm in %s, got '%s' instead"
+                % (supported_norms, ord)
+            )
+        self.ord = ord
         self._set_hyper("use_locking", use_locking)

     def get_config(self):
         config = {
             "learning_rate": self._serialize_hyperparameter("learning_rate"),
             "lambda_": self._serialize_hyperparameter("lambda_"),
             "epsilon": self.epsilon,
+            "ord": self.ord,
             "use_locking": self._serialize_hyperparameter("use_locking"),
         }
         base_config = super().get_config()
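A small sketch of the behavior added here (assuming the patched class is importable as `tfa.optimizers.ConditionalGradient`): unsupported norms are rejected in `__init__`, and `ord` round-trips through `get_config`/`from_config` like the other hyperparameters.

```
import tensorflow_addons as tfa

# Construction with an unsupported norm raises the ValueError above.
try:
    tfa.optimizers.ConditionalGradient(learning_rate=0.1, ord="l1")
except ValueError as err:
    print(err)

# `ord` is serialized alongside the other hyperparameters.
opt = tfa.optimizers.ConditionalGradient(learning_rate=0.1, ord="nuclear")
restored = tfa.optimizers.ConditionalGradient.from_config(opt.get_config())
assert restored.ord == "nuclear"
```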
@@ -106,23 +125,64 @@ def _prepare_local(self, var_device, var_dtype, apply_state):
             self.epsilon, var_dtype
         )

-    def _resource_apply_dense(self, grad, var, apply_state=None):
-        def frobenius_norm(m):
-            return tf.math.reduce_sum(m ** 2) ** 0.5
+    @staticmethod
+    def _frobenius_norm(m):
+        return tf.reduce_sum(m ** 2) ** 0.5
+
+    @staticmethod
+    def _top_singular_vector(m):
+        # handle the case where m is a tensor of rank 0 or rank 1.
+        # Example:
+        #   scalar (rank 0) a, shape [] => [[a]], shape [1, 1]
+        #   vector (rank 1) [a, b], shape [2] => [[a, b]], shape [1, 2]
+        original_rank = tf.rank(m)
+        shape = tf.shape(m)
+        first_pad = tf.cast(tf.less(original_rank, 2), dtype=tf.int32)
+        second_pad = tf.cast(tf.equal(original_rank, 0), dtype=tf.int32)
+        new_shape = tf.concat(
+            [
+                tf.ones(shape=first_pad, dtype=tf.int32),
+                tf.ones(shape=second_pad, dtype=tf.int32),
+                shape,
+            ],
+            axis=0,
+        )
+        n = tf.reshape(m, new_shape)
+        st, ut, vt = tf.linalg.svd(n, full_matrices=False)
+        n_size = tf.shape(n)
+        ut = tf.reshape(ut[:, 0], [n_size[0], 1])
+        vt = tf.reshape(vt[:, 0], [n_size[1], 1])
+        st = tf.matmul(ut, tf.transpose(vt))
+        # when we return the top singular vector, we have to remove the
+        # dimension we have added on
+        st_shape = tf.shape(st)
+        begin = tf.cast(tf.less(original_rank, 2), dtype=tf.int32)
+        end = 2 - tf.cast(tf.equal(original_rank, 0), dtype=tf.int32)
+        new_shape = st_shape[begin:end]
+        return tf.reshape(st, new_shape)

+    def _resource_apply_dense(self, grad, var, apply_state=None):
         var_device, var_dtype = var.device, var.dtype.base_dtype
         coefficients = (apply_state or {}).get(
             (var_device, var_dtype)
         ) or self._fallback_apply_state(var_device, var_dtype)
-        norm = tf.convert_to_tensor(
-            frobenius_norm(grad), name="norm", dtype=var.dtype.base_dtype
-        )
         lr = coefficients["learning_rate"]
         lambda_ = coefficients["lambda_"]
         epsilon = coefficients["epsilon"]
-        var_update_tensor = tf.math.multiply(var, lr) - (1 - lr) * lambda_ * grad / (
-            norm + epsilon
-        )
+        if self.ord == "fro":
+            norm = tf.convert_to_tensor(
+                self._frobenius_norm(grad), name="norm", dtype=var.dtype.base_dtype
+            )
+            s = grad / (norm + epsilon)
+        else:
+            top_singular_vector = tf.convert_to_tensor(
+                self._top_singular_vector(grad),
+                name="top_singular_vector",
+                dtype=var.dtype.base_dtype,
+            )
+            s = top_singular_vector
+
+        var_update_tensor = tf.math.multiply(var, lr) - (1 - lr) * lambda_ * s
         var_update_kwargs = {
             "resource": var.handle,
             "value": var_update_tensor,
@@ -131,23 +191,28 @@ def frobenius_norm(m):
         return tf.group(var_update_op)

     def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
-        def frobenius_norm(m):
-            return tf.reduce_sum(m ** 2) ** 0.5
-
         var_device, var_dtype = var.device, var.dtype.base_dtype
         coefficients = (apply_state or {}).get(
             (var_device, var_dtype)
         ) or self._fallback_apply_state(var_device, var_dtype)
-        norm = tf.convert_to_tensor(
-            frobenius_norm(grad), name="norm", dtype=var.dtype.base_dtype
-        )
         lr = coefficients["learning_rate"]
         lambda_ = coefficients["lambda_"]
         epsilon = coefficients["epsilon"]
         var_slice = tf.gather(var, indices)
-        var_update_value = tf.math.multiply(var_slice, lr) - (
-            1 - lr
-        ) * lambda_ * grad / (norm + epsilon)
+        if self.ord == "fro":
+            norm = tf.convert_to_tensor(
+                self._frobenius_norm(grad), name="norm", dtype=var.dtype.base_dtype
+            )
+            s = grad / (norm + epsilon)
+        else:
+            top_singular_vector = tf.convert_to_tensor(
+                self._top_singular_vector(grad),
+                name="top_singular_vector",
+                dtype=var.dtype.base_dtype,
+            )
+            s = top_singular_vector
+
+        var_update_value = tf.math.multiply(var_slice, lr) - (1 - lr) * lambda_ * s
         var_update_kwargs = {
             "resource": var.handle,
             "indices": indices,