Skip to content

Commit 5aa96fd

Browse files
jon-chuangpytorchmergebot
authored andcommitted
[dynamo] list index: add more list types to testing, support namedtuple, improve error handling (pytorch#110919)
Follow up: pytorch#110817 Minor improvements as discussed in prev PR Pull Request resolved: pytorch#110919 Approved by: https://github.com/ezyang
1 parent 9606cda commit 5aa96fd

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

test/dynamo/test_repros.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3068,18 +3068,31 @@ def f(x, y):
30683068
)
30693069

30703070
def test_list_index(self):
3071-
for index in ([], [2], [0, 3]):
3071+
for i, list_type in enumerate(
3072+
(
3073+
list,
3074+
tuple,
3075+
torch.Size,
3076+
collections.deque,
3077+
namedtuple("FourElems", "one two three four", defaults=[0, 0, 0, 0]),
3078+
)
3079+
):
3080+
torch._dynamo.reset()
3081+
for index in ([], [2], [0, 3]):
30723082

3073-
def f(t):
3074-
xs = ["bar", "foo", "baz", "buzz"]
3075-
res = xs.index("baz", *index)
3076-
return t + res
3083+
def f(t):
3084+
if i == 4: # namedtuple
3085+
xs = list_type(1, 2, 3, 4)
3086+
else:
3087+
xs = list_type([1, 2, 3, 4])
3088+
res = xs.index(3, *index)
3089+
return t + res
30773090

3078-
res = torch._dynamo.optimize(backend="eager", nopython=True)(f)(
3079-
torch.zeros(1)
3080-
)
3091+
res = torch._dynamo.optimize(backend="eager", nopython=True)(f)(
3092+
torch.zeros(1)
3093+
)
30813094

3082-
self.assertEqual(res, torch.tensor([2.0]))
3095+
self.assertEqual(res, torch.tensor([2.0]))
30833096

30843097
def test_list_index_not_found(self):
30853098
def f(t):

torch/_dynamo/variables/lists.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,6 @@ def call_method(
155155
elif name == "index":
156156
from .builder import SourcelessBuilder
157157

158-
assert len(kwargs) == 0
159-
assert len(args) > 0 and len(args) <= 3
160158
return tx.inline_user_function_return(
161159
SourcelessBuilder()(tx, polyfill.index), [self] + list(args), kwargs
162160
)
@@ -653,7 +651,7 @@ def check_and_create_method():
653651
if name not in fields:
654652
method = check_and_create_method()
655653
if not method:
656-
unimplemented(f"NamedTupleVariable.{name}")
654+
super().var_getattr(tx, name)
657655
return method
658656
return self.items[fields.index(name)].add_options(self)
659657

0 commit comments

Comments
 (0)