
Make torch.nn.functional.multilabel_soft_margin_loss more stable #9141


Closed
vadimkantorov opened this issue Jul 3, 2018 · 0 comments
Labels
todo Not as important as medium or high priority tasks, but we will work on these.

vadimkantorov commented Jul 3, 2018

Currently it calls sigmoid and then binary_cross_entropy (which does log internally):
https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L1796-L1797

Is it not possible to compute it as torch.mean(-(y * F.logsigmoid(x) + (1 - y) * F.logsigmoid(-x)))?

y * log(1 / (1 + exp(-x))) + (1 - y) * log(1 - 1 / (1 + exp(-x))) =
y * logsigmoid(x) + (1 - y) * log((1 + exp(-x) - 1) / (1 + exp(-x))) =
y * logsigmoid(x) + (1 - y) * log(exp(-x) / (1 + exp(-x))) =
y * logsigmoid(x) + (1 - y) * log(1 / (1 + exp(x))) =
y * logsigmoid(x) + (1 - y) * logsigmoid(-x)
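
A minimal sketch of the proposed computation (the function name here is made up for illustration; x holds raw logits and y holds {0, 1} multi-hot targets):

```
import torch
import torch.nn.functional as F

def stable_multilabel_soft_margin(x, y):
    # Proposed stable form: stay in log-space via logsigmoid instead of
    # computing sigmoid first and taking the log afterwards.
    return torch.mean(-(y * F.logsigmoid(x) + (1 - y) * F.logsigmoid(-x)))

# Why this matters: sigmoid saturates in float32, so log(sigmoid(x))
# underflows to log(0) = -inf, while logsigmoid stays finite.
x = torch.tensor([-100.0])
print(torch.log(torch.sigmoid(x)))  # tensor([-inf])
print(F.logsigmoid(x))              # tensor([-100.])
```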
@zou3519 zou3519 added the todo Not as important as medium or high priority tasks, but we will work on these. label Jul 9, 2018
@weiyangfb weiyangfb self-assigned this Jul 26, 2018
goodlux pushed a commit to goodlux/pytorch that referenced this issue Aug 15, 2018
… shape=(N, C) to (N,) (pytorch#9965)

Summary:
- fixes pytorch#9141, pytorch#9301
- use logsigmoid in multilabel_soft_margin_loss to make it more stable (NOT fixing the legacy MultiLabelSoftMarginCriterion)
- return shape (N,) instead of (N, C) to match the behavior of MultiMarginLoss
- Note that with this PR, the following behavior is expected:
```
loss = F.multilabel_soft_margin_loss(outputs, labels, reduction='none')
loss_mean = F.multilabel_soft_margin_loss(outputs, labels, reduction='elementwise_mean')
loss_sum = F.multilabel_soft_margin_loss(outputs, labels, reduction='sum')

loss.sum() == loss_sum  # True
loss.mean() == loss_mean  # True
```
Pull Request resolved: pytorch#9965

Differential Revision: D9038402

Pulled By: weiyangfb

fbshipit-source-id: 0fa94c7b3cd370ea62bd6333f1a0e9bd0b8ccbb9
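
As a concrete illustration of the new return shape (a sketch against a post-merge PyTorch build; the sizes N=4 and C=3 are arbitrary):

```
import torch
import torch.nn.functional as F

outputs = torch.randn(4, 3)                   # N=4 samples, C=3 labels
labels = torch.randint(0, 2, (4, 3)).float()  # multi-hot targets

# With reduction='none' the loss is averaged over C within each sample,
# so the result has shape (N,) rather than (N, C).
loss = F.multilabel_soft_margin_loss(outputs, labels, reduction='none')
print(loss.shape)  # torch.Size([4])
```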