Skip to content

Commit ebba421

Browse files
d4l3kpytorchmergebot
authored andcommitted
torch/distributed: move WorkerInfo registration into libtorch instead of libtorch_python (pytorch#78028)
Summary: This moves torch::class_<WorkerInfo> into `rpc_agent.cpp` so it gets registered in libtorch instead of libtorch_python. This is intermediate work to getting torch::deploy to load an unmodified copy of libtorch. Current RPC is incompatible due to duplicate registrations. ``` unknown file: Failure C++ exception with description "Exception Caught inside torch::deploy embedded library: Custom class with name __torch__.torch.classes.dist_rpc.WorkerInfo is already registered. Ensure that registration with torch::class_ is only called once. Exception raised from registerCustomClass at ../aten/src/ATen/core/custom_class.cpp:61 (most recent call first): frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x3e (0x7f3bd9adb92e in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libc10.so) frame ROCm#1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x5c (0x7f3bd9ab7068 in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libc10.so) frame ROCm#2: torch::registerCustomClass(std::shared_ptr<c10::ClassType>) + 0x110 (0x7f3bc2258980 in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so) frame ROCm#3: torch::detail::class_base::class_base(std::string const&, std::string const&, std::string, std::type_info const&, std::type_info const&) + 0x3b9 (0x7f3bc225a419 in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so) frame ROCm#4: [0x7f3ba45cfea1] frame ROCm#5: <unknown function> + 0x1b5334 (0x5652bdab9334 in ./test_deploy) frame ROCm#6: <unknown function> + 0x1b4f3e (0x5652bdab8f3e in ./test_deploy) frame ROCm#7: <unknown function> + 0x1b519b (0x5652bdab919b in ./test_deploy) frame ROCm#8: loadSearchFile(char const*) + 0x23e (0x7f3ba62f37f8 in /tmp/torch_deploy9ATEFg) frame ROCm#9: deploy_set_self + 0x51 (0x7f3ba62f38f9 in /tmp/torch_deploy9ATEFg) frame ROCm#10: torch::deploy::Interpreter::Interpreter(torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>) + 0x274 (0x5652bdaaa790 in ./test_deploy) frame ROCm#11: void __gnu_cxx::new_allocator<torch::deploy::Interpreter>::construct<torch::deploy::Interpreter, torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(torch::deploy::Interpreter*, torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0x81 (0x5652bdaaf58b in ./test_deploy) frame ROCm#12: void std::allocator_traits<std::allocator<torch::deploy::Interpreter> >::construct<torch::deploy::Interpreter, torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(std::allocator<torch::deploy::Interpreter>&, torch::deploy::Interpreter*, torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0x4a (0x5652bdaae320 in ./test_deploy) frame ROCm#13: void std::vector<torch::deploy::Interpreter, std::allocator<torch::deploy::Interpreter> >::_M_realloc_insert<torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(__gnu_cxx::__normal_iterator<torch::deploy::Interpreter*, std::vector<torch::deploy::Interpreter, std::allocator<torch::deploy::Interpreter> > >, torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0xee (0x5652bdaae4a0 in ./test_deploy) frame ROCm#14: void std::vector<torch::deploy::Interpreter, std::allocator<torch::deploy::Interpreter> >::emplace_back<torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0xb6 (0x5652bdaad258 in ./test_deploy) frame ROCm#15: torch::deploy::InterpreterManager::InterpreterManager(unsigned long, std::shared_ptr<torch::deploy::Environment>) + 0x123 (0x5652bdaa83b1 in ./test_deploy) frame ROCm#16: TorchpyTest_InitTwice_Test::TestBody() + 0x65 (0x5652bda075a9 in ./test_deploy) frame ROCm#17: void testing::internal::HandleSehExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x65 (0x5652bda944b7 in ./test_deploy) frame ROCm#18: void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x5a (0x5652bda8cfe7 in ./test_deploy) frame ROCm#19: testing::Test::Run() + 0x100 (0x5652bda68622 in ./test_deploy) frame ROCm#20: testing::TestInfo::Run() + 0x10f (0x5652bda68fb3 in ./test_deploy) frame ROCm#21: testing::TestSuite::Run() + 0x121 (0x5652bda6980d in ./test_deploy) frame ROCm#22: testing::internal::UnitTestImpl::RunAllTests() + 0x38e (0x5652bda756e6 in ./test_deploy) frame ROCm#23: bool testing::internal::HandleSehExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x65 (0x5652bda9586b in ./test_deploy) frame ROCm#24: bool testing::internal::HandleExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x5a (0x5652bda8e0f7 in ./test_deploy) frame ROCm#25: testing::UnitTest::Run() + 0xc9 (0x5652bda73fd1 in ./test_deploy) frame ROCm#26: RUN_ALL_TESTS() + 0x11 (0x5652bda169fa in ./test_deploy) frame ROCm#27: main + 0x27 (0x5652bda10ce2 in ./test_deploy) frame ROCm#28: <unknown function> + 0x2d310 (0x7f3bc0431310 in /usr/lib/libc.so.6) frame ROCm#29: __libc_start_main + 0x81 (0x7f3bc04313c1 in /usr/lib/libc.so.6) frame ROCm#30: _start + 0x25 (0x5652bda063b5 in ./test_deploy) ``` Test Plan: CI Differential Revision: D36564258 Pull Request resolved: pytorch#78028 Approved by: https://github.com/rohan-varma
1 parent 8412f20 commit ebba421

File tree

3 files changed

+49
-28
lines changed

3 files changed

+49
-28
lines changed

torch/csrc/distributed/rpc/rpc_agent.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,50 @@ namespace torch {
55
namespace distributed {
66
namespace rpc {
77

8+
namespace {
9+
// WorkerInfo needs to be registered exactly once. Since the op registration
10+
// happens in libtorch_python we wrap the class registration in a helper to make
11+
// sure that if there's multiple copies of Python such as used in torch::deploy
12+
// we only ever register it once.
13+
static std::once_flag workerInfoFlag;
14+
static c10::optional<torch::class_<WorkerInfo>> workerInfo;
15+
} // namespace
16+
17+
RegisterWorkerInfoOnce::RegisterWorkerInfoOnce() {
18+
std::call_once(workerInfoFlag, []() {
19+
workerInfo = torch::class_<WorkerInfo>("dist_rpc", "WorkerInfo")
20+
.def(torch::init<std::string, int64_t>());
21+
});
22+
}
23+
824
constexpr size_t WorkerInfo::MAX_NAME_LEN;
925

26+
WorkerInfo::WorkerInfo(std::string name, int64_t id)
27+
: WorkerInfo(std::move(name), (worker_id_t)id) {
28+
TORCH_CHECK(
29+
id <= std::numeric_limits<worker_id_t>::max(),
30+
"RPC worker id ",
31+
id,
32+
" out of bound of int16_t.");
33+
}
34+
35+
WorkerInfo::WorkerInfo(std::string name, worker_id_t id)
36+
: name_(std::move(name)), id_(id) {
37+
bool validSize = name_.length() < MAX_NAME_LEN && name_.length() > 0;
38+
bool validChar =
39+
std::find_if(name_.begin(), name_.end(), [](char c) {
40+
return !(std::isalnum(c) || c == '-' || c == '_' || c == ':');
41+
}) == name_.end();
42+
TORCH_CHECK(
43+
validSize && validChar,
44+
"Worker name must match ^[A-Za-z0-9-_:]*$, "
45+
"and must be non-empty and shorter than ",
46+
MAX_NAME_LEN,
47+
" chars, "
48+
"but got ",
49+
name_);
50+
}
51+
1052
// Large Time Duration for waiting on the condition variable until the map is
1153
// population. Cannot use
1254
// std::chrono::time_point<std::chrono::steady_clock>::max() due to a known

torch/csrc/distributed/rpc/rpc_agent.h

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -52,31 +52,9 @@ struct RpcBackendOptions {
5252

5353
// A globally unique ID to identify an RpcAgent
5454
struct TORCH_API WorkerInfo : torch::CustomClassHolder {
55-
WorkerInfo(std::string name, int64_t id)
56-
: WorkerInfo(std::move(name), (worker_id_t)id) {
57-
TORCH_CHECK(
58-
id <= std::numeric_limits<worker_id_t>::max(),
59-
"RPC worker id ",
60-
id,
61-
" out of bound of int16_t.");
62-
}
55+
WorkerInfo(std::string name, int64_t id);
6356

64-
WorkerInfo(std::string name, worker_id_t id)
65-
: name_(std::move(name)), id_(id) {
66-
bool validSize = name_.length() < MAX_NAME_LEN && name_.length() > 0;
67-
bool validChar =
68-
std::find_if(name_.begin(), name_.end(), [](char c) {
69-
return !(std::isalnum(c) || c == '-' || c == '_' || c == ':');
70-
}) == name_.end();
71-
TORCH_CHECK(
72-
validSize && validChar,
73-
"Worker name must match ^[A-Za-z0-9-_:]*$, "
74-
"and must be non-empty and shorter than ",
75-
MAX_NAME_LEN,
76-
" chars, "
77-
"but got ",
78-
name_);
79-
}
57+
WorkerInfo(std::string name, worker_id_t id);
8058

8159
bool operator==(const WorkerInfo& rhs) {
8260
return (id_ == rhs.id_) && (name_ == rhs.name_);
@@ -88,6 +66,10 @@ struct TORCH_API WorkerInfo : torch::CustomClassHolder {
8866
const worker_id_t id_;
8967
};
9068

69+
struct TORCH_API RegisterWorkerInfoOnce {
70+
RegisterWorkerInfoOnce();
71+
};
72+
9173
TORCH_API std::ostream& operator<<(
9274
std::ostream& os,
9375
const WorkerInfo& workerInfo);

torch/csrc/jit/runtime/register_distributed_ops.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@ namespace torch {
2222
namespace jit {
2323

2424
namespace {
25-
26-
static auto workerInfo =
27-
torch::class_<dist_rpc::WorkerInfo>("dist_rpc", "WorkerInfo")
28-
.def(torch::init<std::string, int64_t>());
25+
distributed::rpc::RegisterWorkerInfoOnce workerInfo{};
2926

3027
// prepare the rpc input arguments and call the C++ impls
3128
void prepare_and_call_rpc_op(

0 commit comments

Comments
 (0)