Skip to content

Commit dfd2a9b

Browse files
SebastianAment authored and facebook-github-bot committed
utils.sigmoid with log and fat options (#1938)
Summary: Pull Request resolved: #1938 This commit introduces `utils.sigmoid` with `log` and `fat` options, thereby enabling easy access and discovery of the `logexpit` and `(log)_fatmoid` functions that exhibit better numerical behavior - when applicable - than the canonical sigmoid. Reviewed By: Balandat Differential Revision: D47519695 fbshipit-source-id: 5389026ac9f6d95c4d70d69c61c61e8aa00cb87e
1 parent 50bcf95 commit dfd2a9b

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

botorch/utils/objective.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from __future__ import annotations
1212

13+
import warnings
14+
1315
from typing import Callable, List, Optional, Union
1416

1517
import torch
@@ -166,6 +168,7 @@ def compute_smoothed_feasibility_indicator(
166168
return is_feasible if log else is_feasible.exp()
167169

168170

171+
# TODO: deprecate this function
169172
def soft_eval_constraint(lhs: Tensor, eta: float = 1e-3) -> Tensor:
170173
r"""Element-wise evaluation of a constraint in a 'soft' fashion
171174
@@ -181,6 +184,11 @@ def soft_eval_constraint(lhs: Tensor, eta: float = 1e-3) -> Tensor:
181184
For each element `x`, `value(x) -> 0` as `x` becomes positive, and
182185
`value(x) -> 1` as x becomes negative.
183186
"""
187+
warnings.warn(
188+
"`soft_eval_constraint` is deprecated. Please consider `botorch.utils.sigmoid` "
189+
+ "with its `fat` and `log` options to compute feasibility indicators.",
190+
DeprecationWarning,
191+
)
184192
if eta <= 0:
185193
raise ValueError("eta must be positive.")
186194
return torch.sigmoid(-lhs / eta)

botorch/utils/safe_math.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,20 @@ def fatmoid(X: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
257257
def cauchy(x: Tensor) -> Tensor:
258258
"""Computes a Lorentzian, i.e. an un-normalized Cauchy density function."""
259259
return 1 / (1 + x.square())
260+
261+
262+
def sigmoid(X: Tensor, log: bool = False, fat: bool = False) -> Tensor:
263+
"""A sigmoid function with an optional fat tail and evaluation in log space for
264+
better numerical behavior. Notably, the fat-tailed sigmoid can be used to remedy
265+
numerical underflow problems in the value and gradient of the canonical sigmoid.
266+
267+
Args:
268+
X: The Tensor on which to evaluate the sigmoid.
269+
log: Toggles the evaluation of the log sigmoid.
270+
fat: Toggles the evaluation of the fat-tailed sigmoid.
271+
272+
Returns:
273+
A Tensor of (log-)sigmoid values.
274+
"""
275+
Y = log_fatmoid(X) if fat else logexpit(X)
276+
return Y if log else Y.exp()

test/utils/test_safe_math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
log_fatmoid,
2525
log_fatplus,
2626
log_softplus,
27+
logexpit,
2728
logmeanexp,
29+
sigmoid,
2830
smooth_amax,
2931
)
3032
from botorch.utils.testing import BotorchTestCase
@@ -395,5 +397,15 @@ def test_smooth_non_linearities(self):
395397
)
396398
self.assertFalse((log_feas_vals.exp() > 1 / 2).item())
397399

400+
# testing sigmoid wrapper function
401+
X = torch.randn(3, 4, 5, **tkwargs)
402+
sigmoid_X = torch.sigmoid(X)
403+
self.assertAllClose(sigmoid(X), sigmoid_X)
404+
self.assertAllClose(sigmoid(X, log=True), logexpit(X))
405+
self.assertAllClose(sigmoid(X, log=True).exp(), sigmoid_X)
406+
fatmoid_X = fatmoid(X)
407+
self.assertAllClose(sigmoid(X, fat=True), fatmoid_X)
408+
self.assertAllClose(sigmoid(X, log=True, fat=True).exp(), fatmoid_X)
409+
398410
with self.assertRaisesRegex(UnsupportedError, "Only dtypes"):
399411
log_softplus(torch.randn(2, dtype=torch.float16))

0 commit comments

Comments (0)