-
Notifications
You must be signed in to change notification settings - Fork 24.4k
Allow forward functions with single output to return Variable #23803
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
@@ -56,10 +56,13 @@ TORCH_API variable_list _wrap_outputs( | |||
// auto y = MyFunction::apply(6, x); | |||
// Example backward call: | |||
// y[0].sum().backward(); | |||
template<typename X, typename... Args> | |||
using forward_t = decltype(X::forward(nullptr, std::declval<Args>()...)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You probably didn't mean to attach the docblock above to forward_t
. (Speaking of which, would be nice to have a doc for forward_t
too!)
@smessmer could you please help review the C++ template metaprogramming in this diff? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems to work, but I don't know too much about how the template metaprogramming works here.
…ble" Summary: Custom `forward()` can return a `Variable` in case of single outputs instead of returning a `variable_list` of size 1. Test Plan: Modified tests involving single output forward functions.
…ble" Summary: Custom `forward()` can return a `Variable` in case of single outputs instead of returning a `variable_list` of size 1. Test Plan: Modified tests involving single output forward functions. Differential Revision: [D16673857](https://our.internmc.facebook.com/intern/diff/D16673857)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
metaprogramming looks good
@@ -182,7 +198,8 @@ variable_list Function<T>::apply(Args&&... args) { | |||
node->input_info_.emplace_back(var); | |||
} | |||
|
|||
variable_list outputs; | |||
typedef forward_t<X, Args...> forward_return_t; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: using
is more readable than typedef
@@ -159,9 +169,15 @@ template <typename... Args> | |||
void extract_vars(std::vector<bool> &is_var, variable_list& list, Args&& ... args) { | |||
} | |||
|
|||
template <typename T> | |||
typename std::enable_if<std::is_same<T, variable_list>::value, T&>::type to_output_type(variable_list& output_list) { return output_list; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we have guts::enable_if_t
to avoid the typename
…ble" Summary: Custom `forward()` can return a `Variable` in case of single outputs instead of returning a `variable_list` of size 1. Test Plan: Modified tests involving single output forward functions. Differential Revision: [D16673857](https://our.internmc.facebook.com/intern/diff/D16673857)
…ble" Summary: Custom `forward()` can return a `Variable` in case of single outputs instead of returning a `variable_list` of size 1. Test Plan: Modified tests involving single output forward functions. Differential Revision: [D16673857](https://our.internmc.facebook.com/intern/diff/D16673857)
…pytorch#23803)" This reverts commit 81ba2df.
Stack from ghstack:
Summary:
Custom
forward()
can return aVariable
in case of single outputs instead of returning avariable_list
of size 1.Test Plan: Modified tests involving single output forward functions.
Differential Revision: D16673857