[WIP] [Do not merge] Reduce overhead of AOT Module #612
Conversation
functorch/_src/aot_autograd.py
Outdated
#########################################################
"""
(1) Create a new fn_for_tracing that lifts params as inputs (TODO: buffers)
(2) A new tracer - MyTracer (slight modification of PythonKeyTracer). This
    works with Proxy tensors instead of PythonTensors. Goal is to get a torch
    graph, and not torch.aten graph yet.
(3) This traced function is then passed on to aot_function.
(4) The params are read and flattened on every forward call, as they can change during training.
"""
#########################################################
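To make the quoted approach concrete, here is a minimal sketch of steps (1) and (4), assuming a simple Linear module; the function name and the use of torch.func.functional_call are illustrative stand-ins, not the PR's actual code:

import torch
import torch.utils._pytree as pytree

mod = torch.nn.Linear(3, 3)

# Collect and flatten the parameters; tree_flatten returns a flat list of
# Tensors plus a spec for rebuilding the name -> param mapping.
params = dict(mod.named_parameters())
params_flat, params_spec = pytree.tree_flatten(params)

def fn_with_params_as_args(*params_and_inputs):
    # Split the flat argument list back into parameters and real inputs,
    # then run the module functionally against the lifted parameters.
    n = len(params_flat)
    new_params = pytree.tree_unflatten(list(params_and_inputs[:n]), params_spec)
    return torch.func.functional_call(mod, new_params, params_and_inputs[n:])

# Per forward call (step 4): the current params are flattened and passed in
# ahead of the real inputs.
out = fn_with_params_as_args(*params_flat, torch.randn(2, 3))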
@@ -12,6 +12,8 @@
from .named_members_polyfill import _named_parameters, _named_buffers
from typing import Callable, List, Dict, Any, Tuple, Optional

import torchdynamo
perhaps:

try:
    from torchdynamo import disable as disable_dynamo
except ImportError:
    def disable_dynamo(x):
        return x

So we can still use this without dynamo.
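A quick usage note on the suggestion above: with the no-op fallback in place, the decorator can be applied unconditionally (the wrapped function here is a hypothetical stand-in for the per-call wrapper, not code from this PR):

@disable_dynamo
def _forward_with_flat_params(*args):
    # Whatever the AOT module does per call; dynamo skips tracing this
    # when torchdynamo is installed, and the fallback keeps it working
    # when it is not.
    ...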
functorch/_src/aot_autograd.py
Outdated
return fn_with_params_as_args

gm = torch.fx.symbolic_trace(mod)
Why do we need this? Seems not ideal, as this could be lossy.
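For context on the "lossy" concern: torch.fx.symbolic_trace runs the module with Proxy values, so data-dependent Python control flow in the model cannot be captured. A small illustrative example (not from this PR):

import torch
import torch.fx

class Gate(torch.nn.Module):
    def forward(self, x):
        # The branch depends on the value of x, which a Proxy cannot supply.
        if x.sum() > 0:
            return x + 1
        return x - 1

try:
    torch.fx.symbolic_trace(Gate())
except torch.fx.proxy.TraceError as e:
    print("symbolic_trace failed:", e)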
functorch/_src/aot_autograd.py
Outdated
else:
    params_flat, _ = pytree.tree_flatten(params)

params_flat = tuple(params_flat)
In the torchdynamo case (not in the general case), this list is a constant you don't need to recompute every time.
For training, the params will change after each update. Don't we have to flatten for each forward call?
The data will change, but the Tensor objects should be the same. This just holds pointers to a bunch of Tensors (that will mutate in-place)
Yeah, I realized this while driving today. These are just references, and the update happens in place.
Will remove it.
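A small check of the point settled above, that the flattened list holds references to the same Tensor objects, so in-place optimizer updates are visible without re-flattening (illustrative only, not part of the PR):

import torch
import torch.utils._pytree as pytree

lin = torch.nn.Linear(2, 2)
params_flat, _ = pytree.tree_flatten(dict(lin.named_parameters()))

# The flat list holds the same Tensor objects, not copies.
assert any(p is lin.weight for p in params_flat)

opt = torch.optim.SGD(lin.parameters(), lr=0.1)
lin(torch.randn(4, 2)).sum().backward()
before = lin.weight.detach().clone()
opt.step()  # updates lin.weight in place

w = next(p for p in params_flat if p is lin.weight)
assert not torch.equal(w, before)  # the flat list sees the in-place update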
functorch/_src/aot_autograd.py
Outdated
if compiled_f is None:
    fn_with_params_as_args = flattened_fn(gm, nargs)
    compiled_f = aot_function(fn_with_params_as_args, *top_args, **top_kwargs)
In the torchdynamo case (not in the general case), there should only ever be one compiled_f. You can just compute it ahead of time and don't need to wait until the first call.
Ok. Yeah, we can send example_inputs and compile it beforehand.
I think there is also a bunch of pytree overhead inside aot_function, plus an unneeded caching layer.
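A hedged sketch of the "compute it ahead of time" idea from this thread, assuming functorch's aot_function and its nop compiler; the wrapper class and names are illustrative, not the PR's code:

import torch
from functorch.compile import aot_function, nop

class AOTModuleEager(torch.nn.Module):
    def __init__(self, fn_with_params_as_args, params_flat):
        super().__init__()
        # Build the aot_function wrapper once at construction instead of
        # guarding on `if compiled_f is None` inside forward.
        self.compiled_f = aot_function(fn_with_params_as_args,
                                       fw_compiler=nop, bw_compiler=nop)
        self.params_flat = params_flat

    def forward(self, *args):
        # The flat parameter list is prepended to the real inputs.
        return self.compiled_f(*self.params_flat, *args)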
Closing. The following PRs cover this:
@jansel @Chillee
Inspired heavily by https://github.com/facebookresearch/torchdynamo/blob/ce0c84a62d3287e2afde22e0d823a8d1ae4758a8/torchdynamo/optimizations/python_key.py#L121