[C++ API] Functional DataParallel #9234
Conversation
Build failure looks legit.
Branch updated: 4579c1d to dc894fb
I'm trying to use the new …
Mostly LGTM. I have some comments that might help clean up the code. Would be good to fix the `std::terminate` in case of an exception before merging.
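The `std::terminate` concern refers to an exception escaping a worker thread in `parallel_apply`; the later commit "Rethrow exception in parallel_apply" addresses it. As a rough sketch of the general pattern rather than the PR's actual code (`run_all` and its task list are hypothetical), the first exception can be captured in a `std::exception_ptr` and rethrown on the calling thread after all workers have joined:

```cpp
#include <exception>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical helper: run every task on its own thread, but surface the
// first exception on the caller's thread instead of letting it escape a
// worker and call std::terminate.
void run_all(std::vector<std::function<void()>> tasks) {
  std::exception_ptr first_error;
  std::mutex mutex;
  std::vector<std::thread> workers;
  for (size_t i = 0; i < tasks.size(); ++i) {
    workers.emplace_back([&tasks, &first_error, &mutex, i] {
      try {
        tasks[i]();
      } catch (...) {
        std::lock_guard<std::mutex> lock(mutex);
        if (!first_error) {
          first_error = std::current_exception();  // remember the first failure
        }
      }
    });
  }
  for (auto& worker : workers) {
    worker.join();
  }
  if (first_error) {
    std::rethrow_exception(first_error);  // rethrow where the caller can catch it
  }
}
```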
```cpp
  for (auto& variable : variables) {
    set_history(variable, grad_fn);
  }
}
```
"and return a vector."); | ||
} | ||
|
||
std::vector<at::Tensor> tensors; |
```cpp
  }
  grad_fn = std::make_shared<Scatter>(
      source_devices,
      input_sizes,
```
```cpp
  return {variable};
#else
  AT_ERROR("Gather is only supported in CUDA environments");
#endif
```
```cpp
  });
  std::vector<at::Tensor> tensors;
  tensors =
      torch::cuda::scatter(input, device_indices, chunk_sizes_, dim_, streams_);
```
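For context on the call under review: `torch::cuda::scatter` splits `input` into chunks along `dim` (optionally with explicit chunk sizes and per-device streams) and places one chunk on each target device. The sketch below is only a naive illustration of that behaviour using plain tensor ops; `naive_scatter` is a hypothetical name, not an API added by this PR:

```cpp
#include <torch/torch.h>
#include <vector>

// Illustration only: split `input` along `dim` into one chunk per device and
// copy each chunk onto its device. The real torch::cuda::scatter additionally
// handles custom chunk sizes and CUDA streams.
std::vector<torch::Tensor> naive_scatter(
    const torch::Tensor& input,
    const std::vector<torch::Device>& devices,
    int64_t dim = 0) {
  auto chunks = input.chunk(static_cast<int64_t>(devices.size()), dim);
  std::vector<torch::Tensor> scattered;
  for (size_t i = 0; i < chunks.size(); ++i) {
    scattered.push_back(chunks[i].to(devices[i]));  // copy chunk i to device i
  }
  return scattered;
}
```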
test/cpp/api/parallel.cpp (outdated)
```cpp
auto a = torch::ones(5, torch::requires_grad(true).device({torch::kCUDA, 0}));
auto b = torch::ones(5, torch::requires_grad(true).device({torch::kCUDA, 1}));
auto output = gather.apply({a, b});
```
```cpp
REQUIRE(b.grad().defined());
REQUIRE(b.grad().device() == torch::Device(torch::kCUDA, 1));
REQUIRE(b.grad().sum().toCInt() == 5);
```
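The two test snippets above exercise the differentiable `Gather` node: the forward pass moves every input to the output device and concatenates them, and the backward pass routes each gradient slice back to the device the corresponding input came from, which is what the `REQUIRE`s check for `b`. A naive illustration of the forward semantics only (`naive_gather` is a hypothetical name, not the autograd node itself):

```cpp
#include <torch/torch.h>
#include <vector>

// Illustration only: copy every input onto `output_device` and concatenate
// along `dim`. The real Gather node additionally records the source devices
// so the backward pass can scatter the gradient back to them.
torch::Tensor naive_gather(
    const std::vector<torch::Tensor>& inputs,
    torch::Device output_device,
    int64_t dim = 0) {
  std::vector<torch::Tensor> moved;
  for (const auto& tensor : inputs) {
    moved.push_back(tensor.to(output_device));
  }
  return torch::cat(moved, dim);
}
```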
test/cpp/api/parallel.cpp (outdated)
```cpp
TEST_CASE("Parallel/Replicate", "[cuda]") {
  Linear linear(3, 4);
  auto replicas = parallel::replicate(
      linear, {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
```
```cpp
    replica2_parameters[i]->data().data<float>() !=
        original_parameters[i]->data().data<float>());
  }
}
```
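The assertions above verify that `replicate` deep-copies parameter storage rather than aliasing the original module's tensors. Below is a hedged sketch of the same idea that checks device placement instead of raw data pointers; it assumes the module-holder overload of `torch::nn::parallel::replicate` and the header layout of current libtorch, which may differ slightly from the code at the time of this PR:

```cpp
#include <torch/torch.h>
#include <torch/nn/parallel/data_parallel.h>
#include <vector>

void check_replica_devices() {
  torch::nn::Linear linear(3, 4);
  std::vector<torch::Device> devices = {
      torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)};
  // One independent copy of the module per device.
  auto replicas = torch::nn::parallel::replicate(linear, devices);
  // Each replica's parameters live on its own device and do not share
  // storage with the original module's parameters.
  AT_ASSERT(replicas[0]->weight.device() == devices[0]);
  AT_ASSERT(replicas[1]->weight.device() == devices[1]);
}
```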
test/cpp/api/parallel.cpp (outdated)
```cpp
    linear,
    input,
    /*devices=*/at::nullopt,
    /*output_device=*/torch::Device(torch::kCUDA, 1));
```
Branch updated: 7a255ed to 1f6e95d
@apaszke I think I addressed all your comments now. Thanks for the review!
@goldsborough has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Branch updated: fb26da6 to ecdc3cd
@goldsborough has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
- Throw runtime error in non-CUDA environments
- Remove THCStream forward declarations
- Add retain() before converting from THCStream to CUDAStream
- Conditionally compile with OpenMP in libtorch
- Improve move-efficiency of comm.cpp and add multi-gpu guard
- Fix single-device case of data_parallel
- Include functional.h in python_comm.cpp
- Rethrow exception in parallel_apply
- Clarify data-parallel documentation
Branch updated: ecdc3cd to 64527df
@pytorchbot retest this please
@pytorchbot retest this please
@pytorchbot retest this please
@goldsborough is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This PR adds the functional version of `DataParallel` (i.e. `data_parallel`) to the C++ frontend. For this, I had to: 1. Add "differentiable" versions of scatter and gather, which perform their inverse operation in the backward pass, to C++. I've added them under `torch/csrc/autograd/functions/comm.{h,cpp}`. I had to move some utilities from `VariableType.cpp` into `torch/csrc/autograd/functions/utils.h`, and changed them a bit to fix the `const_cast`s for which there were `TODO`s, 2. Implement the `replicate`, `parallel_apply` and the combining `data_parallel` functions in C++. `replicate` is implemented based on our existing `clone()` interface, along with the ability to set the current device via `at::OptionsGuard` (so nice). `parallel_apply` is implemented using `at::parallel_for` (CC cpuhrsch) and [follows the code from PyTorch](https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/parallel_apply.py). Added lots of tests for these things. apaszke ezyang ebetica colesbury Pull Request resolved: pytorch#9234 Differential Revision: D8865182 Pulled By: goldsborough fbshipit-source-id: 4f1fecf2b3f3bc1540c071dfb2d23dd45de433e4
This PR adds the functional version of `DataParallel` (i.e. `data_parallel`) to the C++ frontend. For this, I had to:

1. Add "differentiable" versions of scatter and gather, which perform their inverse operation in the backward pass, to C++. I've added them under `torch/csrc/autograd/functions/comm.{h,cpp}`. I had to move some utilities from `VariableType.cpp` into `torch/csrc/autograd/functions/utils.h`, and changed them a bit to fix the `const_cast`s for which there were `TODO`s.
2. Implement the `replicate`, `parallel_apply` and the combining `data_parallel` functions in C++. `replicate` is implemented based on our existing `clone()` interface, along with the ability to set the current device via `at::OptionsGuard` (so nice). `parallel_apply` is implemented using `at::parallel_for` (CC @cpuhrsch) and [follows the code from PyTorch](https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/parallel_apply.py).

Added lots of tests for these things.

@apaszke @ezyang @ebetica @colesbury
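To tie the pieces together, here is a rough end-to-end usage sketch modelled on this PR's tests in `test/cpp/api/parallel.cpp`. Namespaces are written out in full, and the include path and tensor-options spellings follow current libtorch, so treat the exact names as assumptions rather than the PR's verbatim API:

```cpp
#include <torch/torch.h>
#include <torch/nn/parallel/data_parallel.h>

int main() {
  torch::nn::Linear linear(3, 4);
  auto input =
      torch::ones({8, 3}, torch::device(torch::Device(torch::kCUDA, 0)));

  // Scatter `input` across the visible GPUs, run one replica of `linear` on
  // each chunk in parallel, and gather the outputs on CUDA device 1.
  auto output = torch::nn::parallel::data_parallel(
      linear,
      input,
      /*devices=*/at::nullopt,  // default: all visible CUDA devices
      /*output_device=*/torch::Device(torch::kCUDA, 1));

  AT_ASSERT(output.device() == torch::Device(torch::kCUDA, 1));
  return 0;
}
```

The same result can be assembled manually from `replicate`, `parallel_apply`, and the differentiable `Scatter`/`Gather` nodes, which is essentially what `data_parallel` does internally according to the description above.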