
Start to handle branching in simple cases #16

Closed
@narendasan

Description

The system works well for traced models, but little work has been done on TorchScript models that contain branching. Some common cases should be straightforward to handle, for example branching on whether an optional argument is None, as in graphs like this one (from `torch/nn/functional.py`):

  %50 : Function = prim::Constant[name="linear"]()
  %53 : bool = prim::Constant[value=0]() # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:7
  %54 : None = prim::Constant() # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:40
  %55 : int = prim::Constant[value=2]() # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:22
  %56 : int = prim::Constant[value=1]() # :0:0
  %57 : int = aten::dim(%input1.1) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:7
  %58 : bool = aten::eq(%57, %55) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:7
  %59 : bool = prim::If(%58) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:7
    block0():
      %60 : bool = aten::__isnot__(%94, %54) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:28
      -> (%60)
    block1():
      -> (%53)
  %input2.1 : Tensor = prim::If(%59) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1368:4
    block0():
      %bias0.4 : Tensor = prim::unchecked_cast(%94)
      %101 : Tensor = aten::linear(%input1.1, %95, %bias0.4)
      -> (%101)
    block1():
      %106 : Tensor? = prim::Constant()
      %107 : Tensor = aten::linear(%input1.1, %95, %106)
      %67 : bool = aten::__isnot__(%94, %54) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1373:11
      %output0.6 : Tensor = prim::If(%67) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1373:8
        block0():
          %bias1.4 : Tensor = prim::unchecked_cast(%94)
          %output0.7 : Tensor = aten::add_(%107, %bias1.4, %56) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1374:12
          -> (%output0.7)
        block1():
          -> (%107)
      -> (%output0.6)
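As a rough illustration of the simplest case, here is a toy pass in plain Python. It is not the TorchScript or TRTorch API; `Node`, `Block`, `fold_none_branches`, `linear_graph` and `NOT_NONE` are hypothetical names. The sketch assumes the bias value `%94` is statically known to be either None or a non-None Tensor. Under that assumption it folds `aten::__isnot__(%94, %54)` to a constant and splices the taken block of each `prim::If` into its parent. It also replaces an If whose two branches return the same bool. That rule is what lets `%59` fold even though `aten::dim(%input1.1)` is unknown. Only the node kinds that appear in the graph above are modelled.

```python
from dataclasses import dataclass, field
from typing import Dict, List

UNKNOWN = object()   # value is not statically known
NOT_NONE = object()  # value is known to be a non-None Tensor (contents unknown)


@dataclass
class Node:
    kind: str                 # e.g. "prim::Constant", "aten::__isnot__", "prim::If"
    inputs: List[str]
    output: str
    value: object = None      # payload for prim::Constant (None means a None constant)
    blocks: List["Block"] = field(default_factory=list)  # [block0, block1] for prim::If


@dataclass
class Block:
    nodes: List[Node]
    ret: str                  # single block output, as in the graph above


def _fold(block: Block, env: Dict[str, object], alias: Dict[str, str]) -> Block:
    def resolve(name: str) -> str:
        while name in alias:  # follow If outputs that were replaced by a branch result
            name = alias[name]
        return name

    def known(name: str) -> object:
        return env.get(resolve(name), UNKNOWN)

    out: List[Node] = []
    for node in block.nodes:
        ins = [resolve(i) for i in node.inputs]
        if node.kind == "prim::Constant":
            env[node.output] = node.value
        elif node.kind == "aten::__isnot__":
            a, b = known(ins[0]), known(ins[1])
            if a is not UNKNOWN and b is not UNKNOWN:
                env[node.output] = a is not b  # only meaningful for None checks
        elif node.kind == "prim::If":
            cond = known(ins[0])
            if isinstance(cond, bool):
                # Static condition: splice the taken block into the parent block.
                taken = _fold(node.blocks[0 if cond else 1], env, alias)
                out.extend(taken.nodes)
                alias[node.output] = taken.ret
                continue
            blocks = [_fold(b, env, alias) for b in node.blocks]
            r0, r1 = known(blocks[0].ret), known(blocks[1].ret)
            if isinstance(r0, bool) and r0 is r1:
                # Both branches yield the same bool. Assumes the branch bodies are
                # side-effect free, so the If is replaced by that constant.
                env[node.output] = r0
                continue
            out.append(Node(node.kind, ins, node.output, node.value, blocks))
            continue
        out.append(Node(node.kind, ins, node.output, node.value, node.blocks))
    return Block(out, resolve(block.ret))


def fold_none_branches(graph: Block, known_inputs: Dict[str, object]) -> Block:
    """Inline prim::If nodes whose condition is decidable from known None-ness."""
    return _fold(graph, dict(known_inputs), {})


def linear_graph() -> Block:
    """The graph from this issue, trimmed to the nodes that matter."""
    def const(out: str, value: object) -> Node:
        return Node("prim::Constant", [], out, value)

    def op(kind: str, ins: List[str], out: str, *blocks: Block) -> Node:
        return Node(kind, ins, out, None, list(blocks))

    add_bias = Block([op("prim::unchecked_cast", ["%94"], "%bias1.4"),
                      op("aten::add_", ["%107", "%bias1.4", "%56"], "%output0.7")],
                     "%output0.7")
    then_branch = Block([op("prim::unchecked_cast", ["%94"], "%bias0.4"),
                         op("aten::linear", ["%input1.1", "%95", "%bias0.4"], "%101")],
                        "%101")
    else_branch = Block([const("%106", None),
                         op("aten::linear", ["%input1.1", "%95", "%106"], "%107"),
                         op("aten::__isnot__", ["%94", "%54"], "%67"),
                         op("prim::If", ["%67"], "%output0.6", add_bias, Block([], "%107"))],
                        "%output0.6")
    return Block([
        const("%53", False), const("%54", None), const("%55", 2), const("%56", 1),
        op("aten::dim", ["%input1.1"], "%57"),
        op("aten::eq", ["%57", "%55"], "%58"),
        op("prim::If", ["%58"], "%59",
           Block([op("aten::__isnot__", ["%94", "%54"], "%60")], "%60"), Block([], "%53")),
        op("prim::If", ["%59"], "%input2.1", then_branch, else_branch),
    ], "%input2.1")
```

With `fold_none_branches(linear_graph(), {"%94": None})` every `prim::If` disappears and the graph output becomes `%107`, the plain `aten::linear(%input1.1, %95, %106)` call. With `{"%94": NOT_NONE}` the `dim() == 2` branch has to stay, but the inner None check in `block1` is inlined down to the `aten::add_`. Nodes that become dead (the `aten::dim`/`aten::eq` pair, unused constants) are left in place for a separate dead-code pass.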
