[JIT] Add Support for NoneLiteral. #8925


Closed. wanchaol wants to merge 7 commits from the nameconstant branch.

Conversation

@wanchaol (Collaborator) commented Jun 27, 2018

Currently we don't support the None type yet: things like None in statements or passing None as an argument (e.g. to clamp) are not possible.

This PR adds support for NoneLiteral and NoneType in script. A follow-up PR will enable the #8502 tests and address clamp's gradient problem in ATen native when None is passed to clamp.
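As a minimal sketch (not taken from the PR itself) of what this support should enable, consider passing None as a bound to clamp inside a scripted function; the function name here is hypothetical:

```python
import torch

# Hypothetical example of the behavior this PR targets: None passed as
# the min argument to clamp, meaning "no lower bound".
@torch.jit.script
def clamp_from_above(x):
    return torch.clamp(x, None, 0.5)
```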

@wanchaol wanchaol requested a review from jamesr66a June 27, 2018 00:31
@wanchaol wanchaol changed the title fix frontend NameConstant to match python3 ast syntax [JIT] Fix frontend NameConstant to match python3 ast syntax Jun 27, 2018
@wanchaol wanchaol force-pushed the nameconstant branch 2 times, most recently from 275d043 to 3b492bd Compare June 27, 2018 00:44
```python
    return FalseLiteral(r)
return Var(Ident(r, "None"))
```

@facebook-github-bot (Contributor) left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang (Contributor) commented Jun 27, 2018

@pytorchbot retest this please

@facebook-github-bot (Contributor) left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

```python
    return FalseLiteral(r)
elif expr.value is None:
    return Var(Ident(r, "None"))
```
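For readers without the full diff, here is a hedged reconstruction of the handler the hunk above comes from, with minimal stand-in types so it runs standalone; the names are assumptions based on the excerpt, not verified against the PR:

```python
from collections import namedtuple

# Minimal stand-ins so the sketch runs standalone; the real classes live
# in the script frontend's tree views (names assumed, not verified).
TrueLiteral = namedtuple("TrueLiteral", "range")
FalseLiteral = namedtuple("FalseLiteral", "range")
Ident = namedtuple("Ident", "range name")
Var = namedtuple("Var", "ident")

def build_name_constant(r, value):
    # r: a source range; value: the Python 3 ast.NameConstant's value.
    if value is True:
        return TrueLiteral(r)
    elif value is False:
        return FalseLiteral(r)
    elif value is None:
        # This revision's approach: surface None as a named variable.
        return Var(Ident(r, "None"))
    raise ValueError("unexpected NameConstant value")
```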

@ezyang (Contributor) commented Jun 28, 2018

Needs a test. Caffe2 test failure is unrelated.

@ezyang (Contributor) commented Jun 28, 2018

Test comment.

@wanchaol (Collaborator, Author) commented

@ezyang For the test, are you saying we should cover the case of a None variable?

@ezyang (Contributor) commented Jun 29, 2018

Yes. Basically, something that failed before this PR and will pass after it.
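In that spirit, a hedged sketch of the kind of regression test being asked for; the test name and assertion are illustrative, not taken from the PR:

```python
import torch

# Illustrative regression test: scripting a call that passes None to
# clamp failed before this PR and should compile and run after it.
def test_none_literal_in_script():
    @torch.jit.script
    def clamp_upper(x):
        return torch.clamp(x, None, 0)

    x = torch.randn(3, 3)
    assert torch.equal(clamp_upper(x), torch.clamp(x, max=0))
```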

@facebook-github-bot (Contributor) left a comment

@wanchaol has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@jamesr66a (Collaborator) commented

We currently have this hack for emitting a NoneType return value from Print: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/script/compiler.cpp#L637

Can we consolidate this? Can we make it so emitSimpleExpr in the compiler emits one of these?

I think, as it is right now, we're going to end up with a variable named None, which is probably bad.
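A toy illustration of the concern, in plain Python rather than the JIT: when None is resolved as a name, its meaning depends on whatever the environment happens to bind, whereas a true literal is fixed by the language.

```python
# Toy illustration only: resolving None through a name table means the
# binding must be seeded and can be silently changed, unlike a literal.
bindings = {"None": None}

def resolve(name):
    return bindings[name]

assert resolve("None") is None   # works only because it was seeded
bindings["None"] = 0             # a rebinding changes the meaning
assert resolve("None") == 0
```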

@zdevito (Contributor) left a comment

I think we should treat None as a literal in both Python 2 and Python 3, the way we handle True and False. As it stands, this PR doesn't seem like an incremental step toward that end solution.
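A hedged sketch of that suggestion: intercept the name before ordinary variable resolution and emit a dedicated literal node, the way True and False are handled on the Python 2 path. The function and the NoneLiteral type are assumed names for illustration, with stand-in types so the sketch runs standalone.

```python
from collections import namedtuple

# Stand-ins for the frontend tree views (assumed names, not verified).
TrueLiteral = namedtuple("TrueLiteral", "range")
FalseLiteral = namedtuple("FalseLiteral", "range")
NoneLiteral = namedtuple("NoneLiteral", "range")
Ident = namedtuple("Ident", "range name")
Var = namedtuple("Var", "ident")

def build_name(r, name):
    # Intercept True/False/None before ordinary variable resolution.
    if name == "True":
        return TrueLiteral(r)
    elif name == "False":
        return FalseLiteral(r)
    elif name == "None":
        return NoneLiteral(r)   # a real literal node, not a Var lookup
    return Var(Ident(r, name))
```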

```python
    return FalseLiteral(r)
elif expr.value is None:
    return Var(Ident(r, "None"))
```

@wanchaol wanchaol changed the title [JIT] Fix frontend NameConstant to match python3 ast syntax [WIP JIT] Don't review yet. Add Support for NoneLiteral and Implement Clamp as Aten Native. Jul 6, 2018
@wanchaol wanchaol force-pushed the nameconstant branch 3 times, most recently from 6fe9d91 to 082a684 Compare July 12, 2018 23:46
@wanchaol wanchaol changed the title [WIP JIT] Don't review yet. Add Support for NoneLiteral and Implement Clamp as Aten Native. [JIT] Add Support for NoneLiteral and enable clamp tests. Jul 13, 2018
@wanchaol wanchaol added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jul 13, 2018
@wanchaol wanchaol force-pushed the nameconstant branch 2 times, most recently from 8e91318 to bc5e4a5 Compare July 13, 2018 22:27
@wanchaol wanchaol changed the title [JIT] Add Support for NoneLiteral and enable clamp tests. [JIT] Add Support for NoneLiteral. Jul 13, 2018
@facebook-github-bot (Contributor) left a comment

@wanchaol has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@wanchaol (Collaborator, Author) commented

Superseded by #9596.

facebook-github-bot pushed a commit that referenced this pull request Jul 28, 2018
…JIT (#9596)

Summary:
Supersedes #8925

This PR fixes #8502: it fixes the gradient problem for clamp when None is passed to the function, and adds support for NoneLiteral and NoneType in script to enable the clamp tests. Now we can handle corner cases like:

```python
@torch.jit.script
def func():
    x = torch.randn(3, 3, requires_grad=True)
    y = torch.clamp(x, None, 0) # max = 0
    y = torch.clamp(x, min=None, max=0)
```

In both the JIT and ATen, we use Scalar(NAN) as a sentinel value when None is passed to clamp; this is the current way we support the None type in the JIT and solve the gradient problem when the user explicitly passes None into clamp.

On the JIT side, we create a tensor(NAN) and an undefined Tensor when we encounter None while matching the function schema; later, the interpreter translates it to Scalar(NAN) if needed.
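As an illustrative sketch (not the actual ATen implementation) of the sentinel scheme described above, a missing bound can be encoded as NaN and a NaN bound ignored:

```python
import math
import torch

# Illustration of the NaN-sentinel idea: None becomes NaN, and a NaN
# bound is treated as "no bound".
def clamp_with_sentinel(x, min_val=None, max_val=None):
    min_s = float("nan") if min_val is None else float(min_val)
    max_s = float("nan") if max_val is None else float(max_val)
    out = x
    if not math.isnan(min_s):
        out = torch.max(out, torch.full_like(out, min_s))
    if not math.isnan(max_s):
        out = torch.min(out, torch.full_like(out, max_s))
    return out
```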

Ideally we wouldn't need clamp_min and clamp_max in ATen native/autograd and could support only clamp after this change, but since a bunch of other operators (e.g. Activation.cpp, Loss.cpp) use clamp_min in several places, we keep those functions available; all Python invocations, however, will only call clamp instead of clamp_min/clamp_max (with clamp calling the underlying th_max/th_min).

zdevito jamesr66a
Pull Request resolved: #9596

Reviewed By: zdevito

Differential Revision: D8940839

Pulled By: wanchaol

fbshipit-source-id: c543a867b82e0ab8c99384773b173fdde2605d28
jramseyer pushed a commit to jramseyer/pytorch that referenced this pull request Jul 30, 2018
…JIT (pytorch#9596) (commit message identical to the one above)
goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018
…JIT (pytorch#9596) (commit message identical to the one above)
Labels
oncall: jit (Add this issue/PR to JIT oncall triage queue)

Projects
None yet

6 participants