-
Notifications
You must be signed in to change notification settings - Fork 24.3k
[C++ API] Make Sequential ref-counted #9151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -191,7 +191,7 @@ TEST_CASE("module/clone") { | |||
buffer = register_buffer("buf", torch::ones({2, 2})); | |||
} | |||
|
|||
Linear l1, l2, l3; | |||
Linear l1{nullptr}, l2{nullptr}, l3{nullptr}; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
/// Constructs the `ModuleHolder` with a contained module, forwarding all | ||
/// arguments to its constructor. | ||
template <typename... Ts> | ||
explicit ModuleHolder(Ts&&... ts) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
class Name : public torch::nn::ModuleHolder<Impl> { \ | ||
public: \ | ||
using torch::nn::ModuleHolder<Impl>::ModuleHolder; \ | ||
using torch::nn::ModuleHolder<Impl>::operator->; \ |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
@@ -315,6 +322,16 @@ template <typename ModuleType> | |||
AnyModule::AnyModule(const ModuleHolder<ModuleType>& module_holder) | |||
: AnyModule(module_holder.get()) {} | |||
|
|||
inline AnyModule::AnyModule(const AnyModule& other) | |||
: content_(other.content_ ? other.content_->clone() : nullptr) {} |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
@pytorchbot retest this please |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Not a complete review)
test/cpp/api/sequential.cpp
Outdated
REQUIRE(sequential[i].get() == modules[i].get()); | ||
REQUIRE(sequential.ptr<M>(i).get() == modules[i].get()); | ||
REQUIRE(sequential->ptr(i).get() == modules[i].get()); | ||
REQUIRE((*sequential)[i].get() == modules[i].get()); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@goldsborough has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`. This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool. One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall. ezyang ebetica apaszke Pull Request resolved: pytorch#9151 Reviewed By: ezyang Differential Revision: D8809298 Pulled By: goldsborough fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
In the C++ API,
Sequential
currently was not refcounted itself, but storedshared_ptr<AnyModule>
to get the reference semantics. This is unfortunate because most modules in the API are accessed via->
, e.g.Linear l(1, 2); l->forward(...);
.Sequential
was different in that it had value semantics itself, thus was accessed via.
.This PR makes
Sequential
storeAnyModule
(without extra indirection), and uses the same pImpl mechanism we use for all other modules to makeSequential
have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside ofSequential
, which is cool.One thing I had to change was that the
ModuleHolder
with which the whole pImpl thing is implemented previously did some tricks to makeLinear(3, 4)
actually constructLinear(LinearOptions(3, 4))
. This doesn't work well withSequential
since it takes a variadic parameter pack. Instead, I madeModuleHolder
forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.@ezyang @ebetica @apaszke