
[WIP] [Do not merge] Reduce overhead of AOT Module #612


Closed
wants to merge 4 commits

Conversation

anijain2305
Contributor

Comment on lines 559 to 582
#########################################################
"""
(1) Create a new fn_for_tracing that lifts params as inputs (TODO: buffers).
(2) A new tracer - MyTracer (a slight modification of PythonKeyTracer). It
works with Proxy tensors instead of PythonTensors. The goal is to get a torch
graph, not a torch.aten graph yet.
(3) This traced function is then passed on to aot_function.
(4) The params are read and flattened on every forward call, as they can change during training.
"""
#########################################################
Contributor Author

@jansel @Chillee This is the flow.
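
For readers following along, here is a minimal, self-contained sketch of steps (1) and (3): lift the flattened parameters as extra arguments and hand the lifted function to aot_function. It is illustrative only; the PR uses its own tracer (MyTracer) and flattened_fn, whereas this sketch leans on torch.func.functional_call and the functorch nop compiler as stand-ins.

import torch
import torch.utils._pytree as pytree
from functorch.compile import aot_function, nop

mod = torch.nn.Linear(4, 4)

# (1) Lift parameters as explicit inputs (buffers are still a TODO in the PR).
params = dict(mod.named_parameters())
params_flat, params_spec = pytree.tree_flatten(params)
num_params = len(params_flat)

def fn_with_params_as_args(*args):
    # The first num_params args are the flattened parameters;
    # the rest are the real inputs to the module.
    param_args, actual_args = args[:num_params], args[num_params:]
    live_params = pytree.tree_unflatten(list(param_args), params_spec)
    return torch.func.functional_call(mod, live_params, actual_args)

# (3) The lifted function is what gets handed to aot_function.
compiled_f = aot_function(fn_with_params_as_args, nop)
out = compiled_f(*params_flat, torch.randn(2, 4))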

@@ -12,6 +12,8 @@
from .named_members_polyfill import _named_parameters, _named_buffers
from typing import Callable, List, Dict, Any, Tuple, Optional

import torchdynamo
Contributor

perhaps:

try:
    from torchdynamo import disable as disable_dynamo
except ImportError:
    def disable_dynamo(x):
        return x

So we can still use this without dynamo.
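
A quick usage sketch of that fallback: with torchdynamo absent, the decorator degrades to the identity, so the same code path still runs in plain eager mode. run_compiled here is a hypothetical call site, not a function from this PR.

try:
    from torchdynamo import disable as disable_dynamo
except ImportError:
    # Without torchdynamo installed, the decorator is just the identity.
    def disable_dynamo(fn):
        return fn

@disable_dynamo
def run_compiled(compiled_f, params_flat, *args):
    # Keep dynamo from re-tracing a function that AOT Autograd already compiled.
    return compiled_f(*params_flat, *args)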

return fn_with_params_as_args


gm = torch.fx.symbolic_trace(mod)
Contributor

Why do we need this? Seems not ideal, as this could be lossy.

else:
    params_flat, _ = pytree.tree_flatten(params)

params_flat = tuple(params_flat)
Contributor

In the torchdynamo case (not in the general case), this list is a constant you don't need to recompute every time.

Contributor Author
anijain2305 Mar 23, 2022

For training, the params will change after each update. Don't we have to flatten for each forward call?

Contributor

The data will change, but the Tensor objects should be the same. This just holds pointers to a bunch of Tensors (that will mutate in-place)

Contributor Author

Yeah, I realized while driving today. These are just references. And the update happens in place.

Will remove it.
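
A tiny standalone check of that point: an optimizer step mutates parameter data in place, so a list flattened once keeps pointing at the same Tensor objects.

import torch

mod = torch.nn.Linear(4, 4)
params_flat = list(mod.parameters())
ids_before = [id(p) for p in params_flat]

opt = torch.optim.SGD(mod.parameters(), lr=0.1)
mod(torch.randn(2, 4)).sum().backward()
opt.step()

# Same Tensor objects, new data, so flattening once is enough.
assert ids_before == [id(p) for p in params_flat]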


if compiled_f is None:
    fn_with_params_as_args = flattened_fn(gm, nargs)
    compiled_f = aot_function(fn_with_params_as_args, *top_args, **top_kwargs)
Contributor

In the torchdynamo case (not in the general case), there should only ever be one compiled_f. You can just compute it ahead of time and don't need to wait until the first call.

Contributor Author

Ok. Yeah, we can send example_inputs and compile it beforehand.
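
For illustration, a hedged sketch of moving compilation to wrap time, assuming representative example_inputs are available then: aot_function compiles on its first call, so one warm-up call with example inputs shifts that cost off the first real forward pass. The toy fn and nop compiler below are stand-ins, not the PR's code.

import torch
from functorch.compile import aot_function, nop

def fn(x, w):
    return (x @ w).relu()

example_inputs = (torch.randn(2, 4), torch.randn(4, 4))

compiled_f = aot_function(fn, nop)
# Warm-up call at wrap time triggers tracing/compilation, so the first
# real forward pass no longer pays for it.
compiled_f(*example_inputs)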

Contributor

I think there is also a bunch of pytree overhead inside aot_function, plus an unneeded caching layer.

@anijain2305
Contributor Author

Closing. The following PRs are covering this.

anijain2305 closed this Apr 5, 2022