fix: Add support for truncate_long_and_double in FX #1865
Conversation
ref_output,
rtol=1e-04,
atol=1e-04,
check_dtype=False,
The output data type will be different, since TRT cannot output int64 types.
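For context, a minimal sketch of the comparison this enables, assuming the eager reference returns `int64` while the TRT engine returns `int32` (the tensor values here are illustrative, not taken from the test suite):

```python
import torch

# Illustrative only: the eager reference keeps its original int64 dtype, while
# the TRT path returns int32, so only the values are compared (dtypes ignored).
ref_output = torch.arange(4, dtype=torch.int64)
trt_output = ref_output.to(torch.int32)  # stand-in for the engine's int32 result

# Equivalent in spirit to passing check_dtype=False above
assert torch.equal(trt_output.to(ref_output.dtype), ref_output)
```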
Title changed: torch.int64 inputs in FX → truncate_long_and_double in FX
Force-pushed from 65bc360 to ad8aecf
elif dtype == torch.int64:
    if truncate_long_and_double:
        _LOGGER.warn(
            "Detected Int64 Input, Casting to Int32 for TRT Engine Compatibility"
        )
        return trt.int32
    else:
        raise TypeError(
            "Detected Int64 Input which is not supported by tensorrt, enable compilation "
            + "option truncate_long_and_double=True to cast input to Int32 for TRT Engine"
        )
Similar to the TorchScript path, allow the truncate_long_and_double argument to automatically cast inputs as needed by TRT engines, while informing the user. This is primarily helpful for intermediate inputs (not user-provided) that happen to be long-type tensors, such as indices for embeddings.
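For reference, a hedged usage sketch of what this enables. The truncate_long_and_double flag already exists on the TorchScript path; whether the FX frontend honors it is exactly what this PR proposes, so treat the call below as illustrative rather than a confirmed API:

```python
import torch
import torch_tensorrt

# Toy module whose forward input is an int64 index tensor (purely illustrative)
class EmbeddingModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(100, 16)

    def forward(self, idx):
        return self.emb(idx)

model = EmbeddingModel().eval().cuda()
inputs = [torch.randint(0, 100, (1, 128), dtype=torch.int64, device="cuda")]

# With truncate_long_and_double=True, int64 (and float64) inputs would be
# downcast to int32/float32 instead of raising a TypeError as in the diff above.
trt_model = torch_tensorrt.compile(model, inputs=inputs, truncate_long_and_double=True)
```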
@gs-olive is this PR still needed?
Yes, this PR is still needed to support T5 in the
Can we create a separate PR for Dynamo so we can land the feature there at least?
- Add utility capabilities for accepting `int64` inputs to TRTModules to support multiple use cases (a minimal casting sketch follows this list)
- Supported cases include situations where internal tensors in split modules are `int64` (generally used for indexing torch Tensors)
- This also supports cases where the user wants to pass `long` tensors as `forward` inputs
- Add test cases to verify functionality and accuracy
- Enable tests for `TRTModuleNext`, which are now fully supported on `main`
- Add support and testing for `double` type inputs
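As a rough illustration of the casting behavior described in the commit message above (this is not the actual TRTModule code, and the helper name is invented):

```python
import torch

def _cast_trt_unsupported(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: TRT engine bindings accept neither int64 nor float64,
    # so long/double tensors are downcast before being bound to the engine.
    if t.dtype == torch.int64:
        return t.to(torch.int32)
    if t.dtype == torch.float64:
        return t.to(torch.float32)
    return t

# e.g. applied to every forward input before execution
inputs = [torch.ones(2, 2, dtype=torch.int64), torch.ones(2, 2, dtype=torch.float64)]
casted = [_cast_trt_unsupported(t) for t in inputs]
```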
@gs-olive can you create separate PRs for each backend? They will be easier to merge that way.
Title changed: truncate_long_and_double in FX → truncate_long_and_double in Dynamo
Title changed: truncate_long_and_double in Dynamo → truncate_long in Dynamo
Title changed: truncate_long in Dynamo → truncate_long_and_double in FX
Closed in favor of the more robust #2021 (no need to manually downcast; the FX graph/Dynamo utilities handle this for us automatically).
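To make that distinction concrete, here is a very rough sketch of graph-level downcasting, i.e. rewriting long placeholders in an FX graph so converters never see int64 tensors. This is not the code from #2021, just an illustration of the general approach, and the function name is invented:

```python
import torch
import torch.fx as fx

def downcast_long_placeholders(gm: fx.GraphModule, example_inputs):
    # Pair each placeholder with its example input to learn its dtype
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    for node, example in zip(placeholders, example_inputs):
        if isinstance(example, torch.Tensor) and example.dtype == torch.int64:
            with gm.graph.inserting_after(node):
                cast = gm.graph.call_method("to", args=(node, torch.int32))
            # Route all existing users through the cast, then restore the cast's
            # own input back to the original placeholder
            node.replace_all_uses_with(cast)
            cast.update_arg(0, node)
    gm.recompile()
    return gm
```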
Description
Fixes #1864
Addresses #1740