
Commit 4e8c998

Improve with providing Nuclear Norm Constraint (#1107)
* Improve with providing Nuclear Norm Constraint
* Update based on feedback
* Update Based on Feedback
* Solve Conflict
* Solve Conflict
* change format and fix conflict

Co-authored-by: gabrieldemarmiesse <[email protected]>
1 parent e632eee commit 4e8c998

File tree

2 files changed: +1106 -107 lines changed


tensorflow_addons/optimizers/conditional_gradient.py

Lines changed: 84 additions & 19 deletions
@@ -27,7 +27,8 @@ class ConditionalGradient(tf.keras.optimizers.Optimizer):
 
     This optimizer helps handle constraints well.
 
-    Currently only supports frobenius norm constraint.
+    Currently only supports frobenius norm constraint or nuclear norm
+    constraint.
     See https://arxiv.org/pdf/1803.06453.pdf
 
     ```
@@ -42,6 +43,13 @@ class ConditionalGradient(tf.keras.optimizers.Optimizer):
     gradient is 0.
 
     In this implementation, `epsilon` defaults to $10^{-7}$.
+
+    For nuclear norm constraint, the formula is as follows:
+
+    ```
+    variable -= (1-learning_rate) * (variable
+        + lambda_ * top_singular_vector(gradient))
+    ```
     """
 
     @typechecked
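The docstring change above states the nuclear-norm update rule. As a rough usage sketch, not part of this diff and assuming the `tensorflow_addons` package is installed and exposes `tfa.optimizers.ConditionalGradient`, the new `ord` argument is what switches the optimizer from the default Frobenius constraint to the nuclear-norm constraint:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Hypothetical usage sketch: `ord="nuclear"` selects the nuclear-norm
# constraint; omitting it keeps the previous Frobenius-norm behaviour.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation="softmax", input_shape=(784,))]
)
model.compile(
    optimizer=tfa.optimizers.ConditionalGradient(
        learning_rate=0.99, lambda_=0.01, ord="nuclear"
    ),
    loss="sparse_categorical_crossentropy",
)
```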
@@ -50,6 +58,7 @@ def __init__(
         learning_rate: Union[FloatTensorLike, Callable],
         lambda_: Union[FloatTensorLike, Callable] = 0.01,
         epsilon: FloatTensorLike = 1e-7,
+        ord: str = "fro",
         use_locking: bool = False,
         name: str = "ConditionalGradient",
         **kwargs
@@ -64,6 +73,8 @@ def __init__(
             epsilon: A `Tensor` or a floating point value. A small constant
                 for numerical stability when handling the case of norm of
                 gradient to be zero.
+            ord: Order of the norm. Supported values are `'fro'`
+                and `'nuclear'`. Default is `'fro'`, which is frobenius norm.
             use_locking: If `True`, use locks for update operations.
             name: Optional name prefix for the operations created when
                 applying gradients. Defaults to 'ConditionalGradient'.
@@ -78,13 +89,21 @@ def __init__(
         self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
         self._set_hyper("lambda_", lambda_)
         self.epsilon = epsilon or tf.keras.backend.epsilon()
+        supported_norms = ["fro", "nuclear"]
+        if ord not in supported_norms:
+            raise ValueError(
+                "'ord' must be a supported matrix norm in %s, got '%s' instead"
+                % (supported_norms, ord)
+            )
+        self.ord = ord
         self._set_hyper("use_locking", use_locking)
 
     def get_config(self):
         config = {
             "learning_rate": self._serialize_hyperparameter("learning_rate"),
             "lambda_": self._serialize_hyperparameter("lambda_"),
             "epsilon": self.epsilon,
+            "ord": self.ord,
             "use_locking": self._serialize_hyperparameter("use_locking"),
         }
         base_config = super().get_config()
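A small, hedged sketch of how the new validation and serialization behave, again assuming the `tfa.optimizers.ConditionalGradient` export: an unsupported `ord` raises the `ValueError` added above, and `get_config()` now round-trips the setting:

```python
import tensorflow_addons as tfa

# The new `ord` hyperparameter is stored on the optimizer and serialized.
opt = tfa.optimizers.ConditionalGradient(learning_rate=0.99, ord="nuclear")
assert opt.get_config()["ord"] == "nuclear"

# Anything outside ["fro", "nuclear"] is rejected up front.
try:
    tfa.optimizers.ConditionalGradient(learning_rate=0.99, ord="l2")
except ValueError as err:
    print(err)
```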
@@ -106,23 +125,64 @@ def _prepare_local(self, var_device, var_dtype, apply_state):
             self.epsilon, var_dtype
         )
 
-    def _resource_apply_dense(self, grad, var, apply_state=None):
-        def frobenius_norm(m):
-            return tf.math.reduce_sum(m ** 2) ** 0.5
+    @staticmethod
+    def _frobenius_norm(m):
+        return tf.reduce_sum(m ** 2) ** 0.5
+
+    @staticmethod
+    def _top_singular_vector(m):
+        # handle the case where m is a tensor of rank 0 or rank 1.
+        # Example:
+        #   scalar (rank 0) a, shape []=> [[a]], shape [1,1]
+        #   vector (rank 1) [a,b], shape [2] => [[a,b]], shape [1,2]
+        original_rank = tf.rank(m)
+        shape = tf.shape(m)
+        first_pad = tf.cast(tf.less(original_rank, 2), dtype=tf.int32)
+        second_pad = tf.cast(tf.equal(original_rank, 0), dtype=tf.int32)
+        new_shape = tf.concat(
+            [
+                tf.ones(shape=first_pad, dtype=tf.int32),
+                tf.ones(shape=second_pad, dtype=tf.int32),
+                shape,
+            ],
+            axis=0,
+        )
+        n = tf.reshape(m, new_shape)
+        st, ut, vt = tf.linalg.svd(n, full_matrices=False)
+        n_size = tf.shape(n)
+        ut = tf.reshape(ut[:, 0], [n_size[0], 1])
+        vt = tf.reshape(vt[:, 0], [n_size[1], 1])
+        st = tf.matmul(ut, tf.transpose(vt))
+        # when we return the top singular vector, we have to remove the
+        # dimension we have added on
+        st_shape = tf.shape(st)
+        begin = tf.cast(tf.less(original_rank, 2), dtype=tf.int32)
+        end = 2 - tf.cast(tf.equal(original_rank, 0), dtype=tf.int32)
+        new_shape = st_shape[begin:end]
+        return tf.reshape(st, new_shape)
 
+    def _resource_apply_dense(self, grad, var, apply_state=None):
         var_device, var_dtype = var.device, var.dtype.base_dtype
         coefficients = (apply_state or {}).get(
             (var_device, var_dtype)
         ) or self._fallback_apply_state(var_device, var_dtype)
-        norm = tf.convert_to_tensor(
-            frobenius_norm(grad), name="norm", dtype=var.dtype.base_dtype
-        )
         lr = coefficients["learning_rate"]
         lambda_ = coefficients["lambda_"]
         epsilon = coefficients["epsilon"]
-        var_update_tensor = tf.math.multiply(var, lr) - (1 - lr) * lambda_ * grad / (
-            norm + epsilon
-        )
+        if self.ord == "fro":
+            norm = tf.convert_to_tensor(
+                self._frobenius_norm(grad), name="norm", dtype=var.dtype.base_dtype
+            )
+            s = grad / (norm + epsilon)
+        else:
+            top_singular_vector = tf.convert_to_tensor(
+                self._top_singular_vector(grad),
+                name="top_singular_vector",
+                dtype=var.dtype.base_dtype,
+            )
+            s = top_singular_vector
+
+        var_update_tensor = tf.math.multiply(var, lr) - (1 - lr) * lambda_ * s
         var_update_kwargs = {
             "resource": var.handle,
             "value": var_update_tensor,
@@ -131,23 +191,28 @@ def frobenius_norm(m):
         return tf.group(var_update_op)
 
     def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
-        def frobenius_norm(m):
-            return tf.reduce_sum(m ** 2) ** 0.5
-
         var_device, var_dtype = var.device, var.dtype.base_dtype
         coefficients = (apply_state or {}).get(
             (var_device, var_dtype)
         ) or self._fallback_apply_state(var_device, var_dtype)
-        norm = tf.convert_to_tensor(
-            frobenius_norm(grad), name="norm", dtype=var.dtype.base_dtype
-        )
         lr = coefficients["learning_rate"]
         lambda_ = coefficients["lambda_"]
         epsilon = coefficients["epsilon"]
         var_slice = tf.gather(var, indices)
-        var_update_value = tf.math.multiply(var_slice, lr) - (
-            1 - lr
-        ) * lambda_ * grad / (norm + epsilon)
+        if self.ord == "fro":
+            norm = tf.convert_to_tensor(
+                self._frobenius_norm(grad), name="norm", dtype=var.dtype.base_dtype
+            )
+            s = grad / (norm + epsilon)
+        else:
+            top_singular_vector = tf.convert_to_tensor(
+                self._top_singular_vector(grad),
+                name="top_singular_vector",
+                dtype=var.dtype.base_dtype,
+            )
+            s = top_singular_vector
+
+        var_update_value = tf.math.multiply(var_slice, lr) - (1 - lr) * lambda_ * s
         var_update_kwargs = {
             "resource": var.handle,
             "indices": indices,
