Example in torchscript documentation does not work #24429

Closed
gandroz opened this issue Aug 15, 2019 · 2 comments
gandroz commented Aug 15, 2019

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Go to the TorchScript documentation and locate the code example.
  2. Copy and paste it into a venv with a fresh install of PyTorch 1.2:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight

example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)
n = Net()
# the following two calls are equivalent
module = torch.jit.trace_module(n, example_forward_input)
module = torch.jit.trace_module(n.forward, example_forward_input)

Error message

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-2-bd56084c5306> in <module>
     14 n = Net()
     15 # the following two calls are equivalent
---> 16 module = torch.jit.trace_module(n, example_forward_input)
     17 module = torch.jit.trace_module(n.forward, example_forward_input)

~/miniconda3/envs/ai/lib/python3.7/site-packages/torch/jit/__init__.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, _force_outplace, _module_class, _compilation_unit)
    894 
    895     if not isinstance(inputs, dict):
--> 896         raise AttributeError("expected a dictionary of (method_name, input) pairs")
    897 
    898     module = make_module(mod, _module_class, _compilation_unit)

AttributeError: expected a dictionary of (method_name, input) pairs

Expected behavior

Doc example to work :)

Environment

PyTorch version: 1.2.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti

Nvidia driver version: 430.14
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.0

Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] torch==1.2.0
[pip] torchsummary==1.5.1
[pip] torchvision==0.4.0a0+6b959ee
[conda] Could not collect

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Aug 15, 2019

gandroz commented Aug 15, 2019

As suggested by the error message, a fix could be:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight

example_forward_input = torch.rand(1, 1, 3, 3)
n = Net()
# trace_module expects a dict mapping method names to example inputs
inputs = {'forward': example_forward_input}
module = torch.jit.trace_module(n, inputs)

or

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight

example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)
n = Net()
# trace both methods by giving each one its own example input
inputs = {'forward': example_forward_input, 'weighted_kernel_sum': example_weight}
module = torch.jit.trace_module(n, inputs)

@driazati driazati self-assigned this Aug 15, 2019
driazati (Contributor) commented

This is fixed in the master docs, but we still need to push that to the stable docs so it's what people actually see. We're also in the process of adding our doc examples to our CI so this won't happen in the future.

facebook-github-bot pushed a commit that referenced this issue Aug 21, 2019
Summary:
Stacked PRs
 * #24445 - [jit] Misc doc updates #2
 * **#24435 - [jit] Add docs to CI**

This integrates the [doctest](http://www.sphinx-doc.org/en/master/usage/extensions/doctest.html) module into `jit.rst` so that we can run our code examples as unit tests. They're added to `test_jit.py` under the `TestDocs` class (which takes about 30s to run). This should help prevent things like #24429 from happening in the future. They can be run manually by doing `cd docs && make doctest`.

* The test setup requires a hack since `doctest` defines everything in the `builtins` module which upsets `inspect`
* There are several places where the code wasn't testable (i.e. it threw an exception on purpose). This may be resolvable, but I'd prefer to leave that for a follow-up. For now there are `TODO` comments littered around.
Pull Request resolved: #24435

Pulled By: driazati

Differential Revision: D16840882

fbshipit-source-id: c4b26e7c374cd224a5a4a2d523163d7b997280ed
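
To illustrate the idea behind running documentation examples as tests, here is a minimal sketch using Python's standard-library `doctest` module. It is a stand-in for the Sphinx doctest setup described in the commit message, not a reproduction of it, and `scale` and `run_doc_examples` are hypothetical names made up for this example.

```python
import doctest


def scale(values, factor):
    """Multiply every element of ``values`` by ``factor``.

    >>> scale([1, 2, 3], 2)
    [2, 4, 6]
    >>> scale([], 5)
    []
    """
    return [v * factor for v in values]


def run_doc_examples(func):
    """Run the examples in func's docstring and return (failed, attempted)."""
    finder = doctest.DocTestFinder()
    runner = doctest.DocTestRunner(verbose=False)
    failed = attempted = 0
    # Give the examples a namespace in which the function itself is visible.
    globs = {func.__name__: func}
    for test in finder.find(func, name=func.__name__, globs=globs):
        result = runner.run(test)
        failed += result.failed
        attempted += result.attempted
    return failed, attempted
```

A CI job can then fail the build whenever `failed` is nonzero, which is how an example that has drifted from the real API gets caught before it is published.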
@suo suo closed this as completed Oct 3, 2019