Skip to content

Commit 83db609

Browse files
sijiacpytorchmergebot
authored andcommitted
[inductor] fix the cudagraph tree test (pytorch#132043)
Summary: There are two kinds of exceptions: Case #1: ``` static input data pointer changed. input name: primals_2. data pointer changed from 140315748992000 to 140315748993536. input stack trace: File "/dev/shm/uid-30083/c0899c70-seed-nspid4026535598_cgpid16622182-ns-4026535192/caffe2/test/inductor/test_cudagraph_trees.py", line 1826, in forward return self.static_tensor + x + self.goo(x) File "/dev/shm/uid-30083/c0899c70-seed-nspid4026535598_cgpid16622182-ns-4026535192/caffe2/test/inductor/test_cudagraph_trees.py", line 1816, in forward return self.linear(x) input name: primals_3. data pointer changed from 140315748990976 to 140315748993024. input stack trace: File "/dev/shm/uid-30083/c0899c70-seed-nspid4026535598_cgpid16622182-ns-4026535192/caffe2/test/inductor/test_cudagraph_trees.py", line 1825, in forward self.static_tensor.add_(torch.ones((2, 2), device="cuda")) ``` Case #2: ``` static input data pointer changed. input name: primals_2. data pointer changed from 139852509086720 to 139852509088256. input stack trace: None input name: primals_3. data pointer changed from 139852509085696 to 139852509087744. input stack trace: File "/dev/shm/uid-30083/f61ee184-seed-nspid4026560782_cgpid769179-ns-4026560865/caffe2/test/inductor/test_cudagraph_trees.py", line 1825, in forward self.static_tensor.add_(torch.ones((2, 2), device="cuda")) ``` The current impl only covered the case #2 Test Plan: https://www.internalfb.com/intern/testinfra/testrun/15481123762274476 Differential Revision: D60340212 Pull Request resolved: pytorch#132043 Approved by: https://github.com/BoyuanFeng
1 parent 36e8289 commit 83db609

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

test/inductor/test_cudagraph_trees.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1839,9 +1839,9 @@ def forward(self, x) -> torch.Tensor:
18391839
with self.assertRaisesRegex(
18401840
Exception,
18411841
r"static input data pointer changed.\n"
1842-
r"input name: primals_2. data pointer changed from .* to .*. input stack trace: None\n"
1842+
r"input name: primals_2. data pointer changed from .* to .*. input stack trace:(?s).*"
18431843
r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*,"
1844-
r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n\n",
1844+
r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n",
18451845
):
18461846
self.curr_node().run(
18471847
[foo.goo.linear.weight, foo.goo.linear.bias, foo.static_tensor, inp]

0 commit comments

Comments
 (0)