Skip to content

Commit e5ae0e6

Browse files
suofacebook-github-bot
authored andcommitted
[jit] Allow instance overrides of ignored methods (pytorch#61076)
Summary: Pull Request resolved: pytorch#61076 Previously we would always retrieve ignored methods from the type, which doesn't work when the user has overriden the ignored method for a specific instance. This PR changes things up so we retrieve the ignored method as a bound method from the object being scripted, unwrap it, then re-bind it to the scriptmodule. Test Plan: Imported from OSS Differential Revision: D29504421 Pulled By: suo fbshipit-source-id: 14649863ea69a8d2180dd2c4341ec9a826039de1
1 parent ccfdb30 commit e5ae0e6

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

test/jit/test_recursive_script.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import sys
3+
import types
34
import typing
45
import typing_extensions
56
from typing import List, Dict, Optional, Tuple
@@ -729,3 +730,23 @@ def forward(self, x):
729730
self.checkModule(mod, (torch.rand(2, 2),))
730731
mod.foo = None
731732
self.checkModule(mod, (torch.rand(2, 2),))
733+
734+
def test_override_instance_method_ignore(self):
735+
class M(torch.nn.Module):
736+
@torch.jit.ignore
737+
def i_am_ignored(self):
738+
return "old"
739+
740+
m = M()
741+
742+
# Override the ignored method by binding a new method to this instance.
743+
@torch.jit.ignore
744+
def i_am_ignored(self):
745+
return "new"
746+
747+
m.i_am_ignored = types.MethodType(i_am_ignored, m)
748+
self.assertEqual(m.i_am_ignored(), "new")
749+
750+
# ScriptModule should correctly reflect the override.
751+
s = torch.jit.script(m)
752+
self.assertEqual(s.i_am_ignored(), "new")

torch/jit/_recursive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def init_fn(script_module):
499499
continue
500500
item = getattr(nn_module, name, None)
501501
if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
502-
unbound_function = getattr(type(nn_module), name)
502+
unbound_function = getattr(nn_module, name).__func__
503503
bound_method = unbound_function.__get__(script_module)
504504
setattr(script_module, name, bound_method)
505505
elif concrete_type.is_ignored_attribute(name):

0 commit comments

Comments
 (0)