@@ -445,12 +445,17 @@ def run_aot_eager(self, f, orig_args, _dynamic=False):
445
445
graph = "\n " .join (log_stream .getvalue ().strip ().split ("\n " )[4 :]).strip ()
446
446
return [aot_eager_args , result , graph ]
447
447
448
- def run_inductor (self , f , orig_args , _dynamic = False ):
448
+ def run_inductor (
449
+ self ,
450
+ f ,
451
+ orig_args ,
452
+ _dynamic = False ,
453
+ log_module = "torch._inductor.compile_fx" ,
454
+ log_function = "post_grad_graphs" ,
455
+ ):
449
456
compiled_args = pytree .tree_map_only (torch .Tensor , torch .clone , orig_args )
450
457
451
- log_stream , ctx = logs_to_string (
452
- "torch._inductor.compile_fx" , "post_grad_graphs"
453
- )
458
+ log_stream , ctx = logs_to_string (log_module , log_function )
454
459
result = None
455
460
with ctx ():
456
461
result = torch .compile (
@@ -1733,6 +1738,41 @@ def f(x, w):
1733
1738
y = f (x , w )
1734
1739
self .assertEqual (y , x .sin ())
1735
1740
1741
+ @torch ._inductor .config .patch (enable_auto_functionalized_v2 = True )
1742
+ def test_scheduling_with_multiple_mutates (self ):
1743
+ with torch .library ._scoped_library ("mylib" , "FRAGMENT" ) as lib :
1744
+ torch .library .define (
1745
+ "mylib::foo" ,
1746
+ "(Tensor! x, Tensor! y, Tensor z) -> ()" ,
1747
+ tags = torch .Tag .pt2_compliant_tag ,
1748
+ lib = lib ,
1749
+ )
1750
+
1751
+ @torch .library .impl ("mylib::foo" , "cpu" , lib = lib )
1752
+ @torch ._dynamo .disable
1753
+ def foo (x , y , z ):
1754
+ pass
1755
+
1756
+ def func (x , w ):
1757
+ a = torch .empty_like (x ) # buf0
1758
+ b = torch .empty_like (x ) # buf1
1759
+ torch .ops .mylib .foo (a , b , x ) # buf2, buf3, buf4
1760
+ c = torch .mm (a , w ) # buf5
1761
+ torch .ops .mylib .foo (c , b , x ) # buf6, buf7, buf8
1762
+ return c
1763
+
1764
+ input = torch .rand (2 , 2 )
1765
+ weight = torch .rand (2 , 2 )
1766
+ [inductor_args , output , graph_inductor ] = self .run_inductor (
1767
+ func ,
1768
+ [input , weight ],
1769
+ False ,
1770
+ "torch._inductor.scheduler" ,
1771
+ "compute_dependencies" ,
1772
+ )
1773
+ name_to_users = eval (graph_inductor )
1774
+ self .assertNotEqual (name_to_users ["buf1" ], name_to_users ["buf5" ])
1775
+
1736
1776
1737
1777
if __name__ == "__main__" :
1738
1778
from torch ._inductor .test_case import run_tests
0 commit comments