
Questions for normalize_ir() #185


Closed
frank-wei opened this issue Apr 28, 2022 · 4 comments

@frank-wei

frank-wei commented Apr 28, 2022

It looks like normalize_ir() behaves differently from what I expected.
For example, I was hoping functionalization would replace in-place ops with their standard (out-of-place) counterparts, e.g. relu_ with relu.
I ran a simple test program whose gm.graph is as follows:

    %x : torch.Tensor [#users=1] = placeholder[target=x]
    %relu : [#users=1] = call_function[target=torch.nn.functional.relu](args = (%x,), kwargs = {inplace: True})
    return (relu,)

After adding the normalizer in my_compiler, the resulting gm.graph does not change.

from typing import List
import torch
from torchdynamo.optimizations.normalize import normalize_ir  # assumed import path

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    gm = normalize_ir(gm, example_inputs)
    print("result gm=", gm.graph)
    return gm
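
For reference, here is a minimal sketch of the kind of driver that would produce a graph like the one above. The use of torchdynamo.optimize as a context manager and the function fn are assumptions for illustration, not the exact test program:

    import torch
    import torch.nn.functional as F
    import torchdynamo  # assumption: the standalone torchdynamo package from early 2022

    def fn(x):
        # In-place relu applied to the graph input, matching the gm.graph above.
        return F.relu(x, inplace=True)

    # my_compiler as defined above is passed in as the backend.
    with torchdynamo.optimize(my_compiler):
        fn(torch.randn(8))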

After more debugging, I found that n.meta["is_input_mutation"] is True for the relu node. Here is the relevant code.

So my questions are:

  1. Is there a way to get the expected functionalization behavior, i.e. changing the in-place op to the standard op, in this case?
  2. Just curious, what does is_input_mutation stand for, and in what situation would is_input_mutation be False for a relu node?
@frank-wei frank-wei changed the title clarification in normalize_ir() Questions for normalize_ir() Apr 28, 2022
@frank-wei frank-wei assigned frank-wei and jansel and unassigned jansel and frank-wei Apr 28, 2022
@anijain2305
Contributor

Hi Wei, in this case you are mutating the input itself. Suppose there is a user of the input x outside the FX graph; that user must see the updated/mutated value of x.

If the input were not mutated (as opposed to your example), the def and use of the mutated variable would be contained within the scope of the FX graph, and therefore we could rewrite the graph to get rid of the mutation. This is what normalize_ir does for the majority (though not all) of cases. Handling input mutation, however, requires a little more effort.

  1. If the input is not mutated, you would see normalize_ir remove a large number of mutations, including changing in-place ops to their out-of-place counterparts.

  2. Specifically for your example, no, we don't handle input mutation today. We could implement it by adding extra outputs to the FX graph, where these extra outputs are the mutated input values. That way the mutation is contained within the scope of the FX graph and can be removed; finally, we overwrite the original inputs with these extra outputs outside the scope of the FX graph (see the sketch below).
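
To illustrate the idea, here is a hand-written sketch; this is not actual torchdynamo code, and forward_inplace, forward_functional, and the explicit copy-back are purely illustrative:

    import torch

    # Original behavior: the graph mutates its input x in place.
    def forward_inplace(x):
        torch.relu_(x)
        return (x,)

    # Rewritten graph: the op becomes out-of-place, and the mutated value
    # is returned as an extra output.
    def forward_functional(x):
        x_updated = torch.relu(x)
        return (x_updated, x_updated)  # last output carries the new value of input x

    # Outside the graph, the runtime would write the mutation back to the real input.
    x = torch.randn(4)
    out, x_new = forward_functional(x)
    x.copy_(x_new)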

Somewhat related to this topic is the functionalization pass at the dispatcher level. Today, AOT Autograd relies on Dynamo's normalize_ir to remove mutation, but we plan to move over to the dispatcher-level functionalization pass soon. I am not sure whether this benefits your work, but I'm happy to discuss.

@frank-wei
Author

@anijain2305 thanks for your explanation; it addresses my concerns. Mutating the input is an edge case that we rarely see. I tried a case that does not mutate the input and saw the expected behavior.
For the functionalization pass, I'd like to see in-place ops changed to standard ones. Also, just curious: will the "dispatcher-level" pass you mentioned happen somewhere inside AOT Autograd?
One more question, perhaps related to the functionalization pass and AOT Autograd: will it help remove/transform the torch.ops.aten.copy_ in the case below? The reason I ask is that TRT does not support any in-place operations.

    def forward(self, x, y):
        y = y + 3
        y[:, 0] = x[:, 0]
        return y
  
gm.graph = 
   %x_1 : [#users=1] = placeholder[target=x_1]
   %y_1 : [#users=1] = placeholder[target=y_1]
   %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
   %add : [#users=2] = call_function[target=torch.ops.aten.add](args = (%y_1, %_tensor_constant0), kwargs = {})
   %slice_1 : [#users=1] = call_function[target=torch.ops.aten.slice](args = (%x_1, 0, 0, 9223372036854775807), kwargs = {})
   %select : [#users=1] = call_function[target=torch.ops.aten.select](args = (%slice_1, 1, 0), kwargs = {})
   %slice_2 : [#users=1] = call_function[target=torch.ops.aten.slice](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
   %select_1 : [#users=1] = call_function[target=torch.ops.aten.select](args = (%slice_2, 1, 0), kwargs = {})
   %copy_ : [#users=0] = call_function[target=torch.ops.aten.copy_](args = (%select_1, %select), kwargs = {})
   return add

@anijain2305
Contributor

anijain2305 commented Apr 28, 2022

  1. The dispatcher-level functionalization traces the original model and checks for mutation op-by-op at the dispatcher. If it sees a mutation, it removes it and keeps a record so that future uses are handled correctly. There is an excellent presentation from Brian Hirsh (author of functionalization) on this topic.

So, this is not AOT Autograd per se. In fact, in the case of AOT Autograd, we plan to first get the forward and backward graphs from the usual AOT tracing, and then call functionalization on top of these forward and backward graphs.

Therefore, I don't see any reason why we cannot use Functionalization for Dynamo-created Fx Graphs.

  2. Yes, it is supposed to remove copy_ from the graph. Handling mutation is hard, and backends differ in their support for it: nvfuser leaves mutating ops untouched (it still runs correctly but leaves performance on the table), while compilers like TVM (and likely TRT) fail outright if they see mutation.
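
As a rough illustration of what removing the copy_ means for your example, here is a hand-written functional equivalent; this is only illustrative and not the exact graph the pass would emit:

    import torch

    def forward_functional(x, y):
        y = y + 3
        # Instead of writing x[:, 0] into y in place (aten.copy_ on a select of a
        # slice), build a fresh tensor: column 0 from x, remaining columns from y.
        return torch.cat([x[:, :1], y[:, 1:]], dim=1)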

More details on the functionalization and AOT Autograd integration are here - #88
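
For completeness, a minimal sketch of invoking functionalization directly. It is written against the torch.func.functionalize and make_fx APIs available in current PyTorch (at the time, the equivalent was being added to functorch), so treat it as illustrative:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch.func import functionalize

    def f(x):
        y = x + 1
        y.relu_()  # in-place op on an intermediate
        return y

    # Tracing the functionalized function yields a graph with out-of-place
    # aten.relu and no mutation.
    gm = make_fx(functionalize(f))(torch.randn(3))
    print(gm.graph)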

@frank-wei
Author

@anijain2305 this PR looks great to me: https://github.com/pytorch/functorch/pull/703/files. I expect it will remove some mutations in Dynamo-created FX graphs, which will help us remove blockers in some models.
