Skip to content

[C++ API] Make Sequential ref-counted #9151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

goldsborough
Copy link
Contributor

In the C++ API, Sequential currently was not refcounted itself, but stored shared_ptr<AnyModule> to get the reference semantics. This is unfortunate because most modules in the API are accessed via ->, e.g. Linear l(1, 2); l->forward(...);. Sequential was different in that it had value semantics itself, thus was accessed via ..

This PR makes Sequential store AnyModule (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make Sequential have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of Sequential, which is cool.

One thing I had to change was that the ModuleHolder with which the whole pImpl thing is implemented previously did some tricks to make Linear(3, 4) actually construct Linear(LinearOptions(3, 4)). This doesn't work well with Sequential since it takes a variadic parameter pack. Instead, I made ModuleHolder forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.

@ezyang @ebetica @apaszke

@@ -191,7 +191,7 @@ TEST_CASE("module/clone") {
buffer = register_buffer("buf", torch::ones({2, 2}));
}

Linear l1, l2, l3;
Linear l1{nullptr}, l2{nullptr}, l3{nullptr};

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

/// Constructs the `ModuleHolder` with a contained module, forwarding all
/// arguments to its constructor.
template <typename... Ts>
explicit ModuleHolder(Ts&&... ts)

This comment was marked as off-topic.

class Name : public torch::nn::ModuleHolder<Impl> { \
public: \
using torch::nn::ModuleHolder<Impl>::ModuleHolder; \
using torch::nn::ModuleHolder<Impl>::operator->; \

This comment was marked as off-topic.

@@ -315,6 +322,16 @@ template <typename ModuleType>
AnyModule::AnyModule(const ModuleHolder<ModuleType>& module_holder)
: AnyModule(module_holder.get()) {}

inline AnyModule::AnyModule(const AnyModule& other)
: content_(other.content_ ? other.content_->clone() : nullptr) {}

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

@goldsborough
Copy link
Contributor Author

@pytorchbot retest this please

Copy link
Contributor

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Not a complete review)

REQUIRE(sequential[i].get() == modules[i].get());
REQUIRE(sequential.ptr<M>(i).get() == modules[i].get());
REQUIRE(sequential->ptr(i).get() == modules[i].get());
REQUIRE((*sequential)[i].get() == modules[i].get());

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

@goldsborough
Copy link
Contributor Author

goldsborough commented Jul 9, 2018

Can we get closure on this PR? @ezyang @apaszke did you have more concerns?

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@goldsborough has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018
Summary:
In the C++ API, `Sequential` currently was not refcounted itself, but stored `shared_ptr<AnyModule>` to get the reference semantics. This is unfortunate because most modules in the API are accessed via `->`, e.g. `Linear l(1, 2); l->forward(...);`. `Sequential` was different in that it had value semantics itself, thus was accessed via `.`.

This PR makes `Sequential` store `AnyModule` (without extra indirection), and uses the same pImpl mechanism we use for all other modules to make `Sequential` have reference semantics itself. This makes it consistent with the rest of the library. It also removes one level of indirection inside of `Sequential`, which is cool.

One thing I had to change was that the `ModuleHolder` with which the whole pImpl thing is implemented previously did some tricks to make `Linear(3, 4)` actually construct `Linear(LinearOptions(3, 4))`. This doesn't work well with `Sequential` since it takes a variadic parameter pack. Instead, I made `ModuleHolder` forward all arguments to the underlying module, and then further pushed the trick to forward parameters to modules' options types into the actual Modules. This adds one constructor per Module in the library. This is not something user modules have to do (unless they want this nice forwarding themselves). It makes the code simpler overall.

ezyang ebetica apaszke
Pull Request resolved: pytorch#9151

Reviewed By: ezyang

Differential Revision: D8809298

Pulled By: goldsborough

fbshipit-source-id: da68452c3de912fbc67af330ba93b5220de6909f
@ezyang ezyang added the merged label Jun 26, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants