Closed
The system works well for traced models, but little has been done for TorchScript models that contain branching. One common case we should be able to handle is branching on `None` arguments, which produces graphs like this:
```
%50 : Function = prim::Constant[name="linear"]()
%53 : bool = prim::Constant[value=0]() # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:7
%54 : None = prim::Constant() # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:40
%55 : int = prim::Constant[value=2]() # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:22
%56 : int = prim::Constant[value=1]() # :0:0
%57 : int = aten::dim(%input1.1) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:7
%58 : bool = aten::eq(%57, %55) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:7
%59 : bool = prim::If(%58) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:7
  block0():
    %60 : bool = aten::__isnot__(%94, %54) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:28
    -> (%60)
  block1():
    -> (%53)
%input2.1 : Tensor = prim::If(%59) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:4
  block0():
    %bias0.4 : Tensor = prim::unchecked_cast(%94)
    %101 : Tensor = aten::linear(%input1.1, %95, %bias0.4)
    -> (%101)
  block1():
    %106 : Tensor? = prim::Constant()
    %107 : Tensor = aten::linear(%input1.1, %95, %106)
    %67 : bool = aten::__isnot__(%94, %54) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1373:11
    %output0.6 : Tensor = prim::If(%67) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1373:8
      block0():
        %bias1.4 : Tensor = prim::unchecked_cast(%94)
        %output0.7 : Tensor = aten::add_(%107, %bias1.4, %56) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1374:12
        -> (%output0.7)
      block1():
        -> (%107)
    -> (%output0.6)
```
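For reference, a graph of this shape can be reproduced by scripting a module whose bias is `Optional[Tensor]`; `F.linear` (in the PyTorch version whose `functional.py` is shown above) then branches internally on whether the bias is `None`. This is a minimal sketch, not code from the repo; the module name and shapes are made up for illustration:

```python
from typing import Optional

import torch
import torch.nn.functional as F


class LinearWithOptionalBias(torch.nn.Module):
    """Hypothetical minimal module: an Optional[Tensor] bias is what
    produces the aten::__isnot__ / prim::If branching shown above."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        # Annotated Optional so TorchScript keeps the None check as a branch.
        self.bias: Optional[torch.Tensor] = None

    def forward(self, input1: torch.Tensor) -> torch.Tensor:
        return F.linear(input1, self.weight, self.bias)


scripted = torch.jit.script(LinearWithOptionalBias(4, 3))
# In the PyTorch version referenced above, the inlined graph contains
# prim::If blocks branching on the None bias:
print(scripted.graph)
```

Since `%94` (the bias) is a module attribute rather than a runtime input, both `prim::If` conditions are effectively constant, so in principle the converter could fold them and lower only the taken branch.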