Skip to content

Commit 8b3daf1

Browse files
masnesralpytorchmergebot
authored andcommitted
Add FloatTrueDiv and ToFloat to SYMPY_INTERP (pytorch#128418)
Summary: I admit I'm not 100% sure what I'm doing here. I'm hitting a bug in the FX graph cache when we try to evaluate a guards expression. We're creating guards that look like this: ``` Ne(CeilToInt(FloatTrueDiv(ToFloat(8*L['t0']) - 4.0, 8.0))*CeilToInt(FloatTrueDiv(ToFloat(8*L['t1']) - 4.0, 8.0)), CeilToInt(FloatTrueDiv(ToFloat(8*L['t1']) - 4.0, 8.0))) and ... ``` It looks like we have a facility to define these operators in the SYMPY_INTERP map and we're just missing FloatTrueDiv and ToFloat. What's surprsing to me is that we're only hitting this problem with the FX graph enabled. We can create such guards, but we've never actually evaluated any? Test Plan: `TORCHINDUCTOR_FX_GRAPH_CACHE=1 python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --only detectron2_fcos_r_50_fpn` Pull Request resolved: pytorch#128418 Approved by: https://github.com/ezyang
1 parent a421699 commit 8b3daf1

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

torch/fx/experimental/symbolic_shapes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,6 +1393,8 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt:
13931393
'RoundDecimal': builtins.round,
13941394
'TruncToInt': math.trunc,
13951395
'IntTrueDiv': operator.truediv,
1396+
'FloatTrueDiv': operator.truediv,
1397+
'ToFloat': builtins.float,
13961398
}
13971399

13981400

0 commit comments

Comments
 (0)