
fix forward and backward for norm with negative infinity norm #12722


Closed
wants to merge 1 commit

Conversation

@Separius Separius (Contributor) commented Oct 16, 2018

I found a bug in norm() and fixed it (and added tests to make sure it's fixed)
here is how to reproduce it:

```python
import torch
x = torch.FloatTensor([[10, 12, 13], [4, 0, 12]])
print(torch.norm(x, -40, dim=0, keepdim=True)) #output is tensor([[ 4.0000,  0.0000, 11.9853]])
print(torch.norm(x, float('-inf'), dim=0, keepdim=True)) #output is tensor([[1., 1., 1.]]) which is wrong!
from numpy.linalg import norm as np_norm
x = x.numpy()
print(np_norm(x, ord=-40, axis=0)) #output is array([[4., 0., 11.985261]])
print(np_norm(x, ord=float('-inf'), axis=0)) #output is array([[4., 0., 12.0]])
```

it's related to #6817 and #6969
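For reference, the -inf "norm" along a dimension is the elementwise minimum of absolute values, which is what NumPy returns above. A minimal sketch of that expectation (just an illustration, not part of the PR's tests):

```python
import torch

x = torch.FloatTensor([[10, 12, 13], [4, 0, 12]])
# the -inf norm should reduce to the minimum of |x| along dim=0
expected = x.abs().min(dim=0, keepdim=True)[0]
print(expected)  # tensor([[ 4.,  0., 12.]]) -- matches NumPy, not the buggy [[1., 1., 1.]]
```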

@ssnl ssnl (Collaborator) left a comment

Nice! Code looks very reasonable. Thank you! Our CI seems to be having some problems. Could you rebase to master and force push? Sorry about it!

@Separius Separius (Contributor, Author) commented Oct 17, 2018

@ssnl I checked the CI log and there were two problems. One is with data_loader, which is probably unrelated to my code, but the other one was a JIT test that my code failed. It turns out there is a function in test_jit.py named get_constants() and I think it has a bug (an obvious one), but maybe (0.001%) it was intentional; can you check that too? I will rebase after that.
One more thing: is it better to squash all my commits into one, or are two better in this case? (jit test; norm)

Edit: ah, looks like it was intentional :)
RuntimeError: Only 'inf' can be cast to a float, but got '-inf' (operator() at /var/lib/jenkins/workspace/torch/csrc/jit/register_prim_ops.cpp:178)
I can look at register_prim_ops.cpp and change it to accept -inf, or I can simply exclude the -inf test in the JIT. The first one is obviously better, right?

Edit: I changed register_prim_ops.cpp to accept -inf too

@ssnl ssnl (Collaborator) left a comment

Oh, thanks! I must admit that I didn't look at the log and assumed it was the CI base commit problem (a rebase usually solves that). My bad! We recently switched from Jenkins to CircleCI, so there are still some rough edges around CI. Thanks for fixing the JIT issue!

@Separius (Contributor, Author)
rebased and squashed all commits into one, ready to merge! :) @ssnl
BTW: thanks for your fast replies

@facebook-github-bot facebook-github-bot (Contributor) left a comment

ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

```cpp
// excerpt from torch/csrc/jit/register_prim_ops.cpp: the string constant "-inf"
// is now parsed to negative infinity instead of falling through to the error branch
  push(stack, std::numeric_limits<double>::infinity());
else if (s->string() == "-inf")
  push(stack, -std::numeric_limits<double>::infinity());
else
  AT_ERROR(
      "Only 'inf' can be cast to a float, but got '",
```


@zou3519 zou3519 (Contributor) left a comment

jit changes look good, but please add a test for -inf to

pytorch/test/test_jit.py

Lines 2811 to 2816 in 7edfe11

```python
def test_inf(self):
    @torch.jit.script
    def foo(a):
        return a < float('inf')
    s = torch.rand(1)
    self.assertTrue(foo(s))
```
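
A companion test for the negative case could mirror test_inf; the name test_neg_inf and body below are a sketch for illustration, not necessarily the exact test added in the PR:

```python
def test_neg_inf(self):  # hypothetical companion to test_inf above
    @torch.jit.script
    def bar(a):
        # after the register_prim_ops.cpp change, '-inf' parses as a float constant
        return a > float('-inf')
    s = torch.rand(1)
    self.assertTrue(bar(s))
```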

@Separius (Contributor, Author)

@zou3519 fixed the error message and added a test for -inf

@facebook-github-bot facebook-github-bot (Contributor) left a comment

soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 18, 2018
Summary:
I found a bug in norm() and fixed it (and added tests to make sure it's fixed)
here is how to reproduce it:
```python
import torch
x = torch.FloatTensor([[10, 12, 13], [4, 0, 12]])
print(torch.norm(x, -40, dim=0, keepdim=True)) #output is tensor([[ 4.0000,  0.0000, 11.9853]])
print(torch.norm(x, float('-inf'), dim=0, keepdim=True)) #output is tensor([[1., 1., 1.]]) which is wrong!
from numpy.linalg import norm as np_norm
x = x.numpy()
print(np_norm(x, ord=-40, axis=0)) #output is array([[4., 0., 11.985261]])
print(np_norm(x, ord=float('-inf'), axis=0)) #output is array([[4., 0., 12.0]])
```
it's related to [#6817](pytorch/pytorch#6817) and [#6969](pytorch/pytorch#6969)
Pull Request resolved: pytorch/pytorch#12722

Differential Revision: D10427687

Pulled By: soumith

fbshipit-source-id: 936a7491d1e2625410513ee9c39f8c910e8e6803