Skip to content

Commit a5a4322

Browse files
VirrageSapaszke
authored and committed
Implement data channel groups (#25)
1 parent 23feafb commit a5a4322

14 files changed

+467
-97
lines changed

torch/csrc/distributed/Module.cpp

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <Python.h>
22

3+
#include <memory>
34
#include <unordered_map>
5+
#include <vector>
46

57
#include "THDP.h"
68

@@ -10,6 +12,7 @@ static std::unordered_map<std::string, THDChannelType> name2channel_type = {
1012
};
1113

1214
static std::unordered_map<PyObject*, THDReduceOp> obj2reduceop;
15+
static std::unordered_map<PyObject*, THDGroup> obj2group;
1316

1417
static THPObjectPtr _ensureBytes(PyObject *obj)
1518
{
@@ -83,6 +86,18 @@ static THDReduceOp _getReduceOp(PyObject *obj)
8386
return it->second;
8487
}
8588

89+
// Resolve a Python object into a THD group handle.
// Accepts either one of the sentinels registered in obj2group
// (torch.distributed.group.*) or a plain Python int group id.
static THDGroup _getGroup(PyObject *obj)
{
  auto it = obj2group.find(obj);
  if (it != obj2group.end())
    return it->second;

  // Not a registered sentinel — fall back to interpreting it as an int id.
  if (!THPUtils_checkLong(obj))
    throw std::runtime_error("group should be an int or one of the values "
        "from torch.distributed.group");
  return THPUtils_unpackLong(obj);
}
100+
86101
PyObject* THDPModule_send(PyObject *_unused, PyObject *args)
87102
{
88103
HANDLE_TH_ERRORS
@@ -118,53 +133,97 @@ PyObject* THDPModule_recv(PyObject *_unused, PyObject *args)
118133
// Python binding for all-reduce: expects (tensor in_out, reduce_op op, group gr).
PyObject* THDPModule_allReduce(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  // Exactly three positional arguments, with a tensor in slot 0.
  bool args_ok = PyTuple_GET_SIZE(args) == 3 &&
      THPModule_isTensor(PyTuple_GET_ITEM(args, 0));
  if (!args_ok) {
    THPUtils_invalidArguments(args, "all_reduce", 1, "(tensor in_out, reduce_op op, group gr)");
    return NULL;
  }

  // Same resolution order as before: group, then op, then tensor descriptor.
  THDGroup group_id = _getGroup(PyTuple_GET_ITEM(args, 2));
  THDReduceOp reduce_op = _getReduceOp(PyTuple_GET_ITEM(args, 1));
  THDPTensorDesc tensor_desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  THDAllReduce(tensor_desc, reduce_op, group_id);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
132148

133149
// Python binding for reduce-to-one: expects
// (tensor reduced, int dst_rank, reduce_op op, group gr).
PyObject* THDPModule_reduce(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  // Four positional arguments: tensor in slot 0, integer rank in slot 1.
  bool args_ok = PyTuple_GET_SIZE(args) == 4 &&
      THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) &&
      THPUtils_checkLong(PyTuple_GET_ITEM(args, 1));
  if (!args_ok) {
    THPUtils_invalidArguments(args, "reduce", 1,
        "(tensor reduced, int dst_rank, reduce_op op, group gr)");
    return NULL;
  }

  THDGroup group_id = _getGroup(PyTuple_GET_ITEM(args, 3));
  THDReduceOp reduce_op = _getReduceOp(PyTuple_GET_ITEM(args, 2));
  THDPTensorDesc tensor_desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
  THDReduce(tensor_desc, reduce_op, dst_rank, group_id);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
150167

151168
// Python binding for broadcast: expects (tensor src_dst, int src_rank, group gr).
PyObject* THDPModule_broadcast(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  // Three positional arguments: tensor in slot 0, integer rank in slot 1.
  bool args_ok = PyTuple_GET_SIZE(args) == 3 &&
      THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) &&
      THPUtils_checkLong(PyTuple_GET_ITEM(args, 1));
  if (!args_ok) {
    THPUtils_invalidArguments(args, "broadcast", 1,
        "(tensor src_dst, int src_rank, group gr)");
    return NULL;
  }

  THDGroup group_id = _getGroup(PyTuple_GET_ITEM(args, 2));
  THDPTensorDesc tensor_desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
  int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
  THDBroadcast(tensor_desc, src_rank, group_id);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
166185

167-
PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *reduce_op_obj) {
186+
// Python binding for group creation: expects (list[int] ranks) and returns
// the new group's integer id as produced by THDNewGroup.
PyObject* THDPModule_newGroup(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  // Validate the args tuple BEFORE touching its contents: PyTuple_GET_ITEM
  // performs no bounds checking, so reading item 0 of an empty tuple is
  // undefined behavior. The original read the item first.
  if (PyTuple_GET_SIZE(args) != 1) {
    THPUtils_invalidArguments(args, "newGroup", 1, "(list[int] ranks)");
    return NULL;
  }

  PyObject* sequence = PyTuple_GET_ITEM(args, 0);
  if (!PySequence_Check(sequence)) {
    THPUtils_invalidArguments(args, "newGroup", 1, "(list[int] ranks)");
    return NULL;
  }

  Py_ssize_t length = PySequence_Length(sequence);
  THPUtils_assert(length >= 0, "couldn't obtain the length of %s",
      THPUtils_typename(sequence));

  std::vector<int> ranks;
  ranks.reserve(static_cast<std::size_t>(length));
  // Iterate with Py_ssize_t to match the CPython sequence API.
  for (Py_ssize_t i = 0; i < length; ++i) {
    // PySequence_ITEM returns a NEW reference; the original leaked it (twice
    // per element). Hold it in a THPObjectPtr so it is always released.
    THPObjectPtr item = PySequence_ITEM(sequence, i);
    THPUtils_assert(item, "couldn't get an item from the rank sequence");
    if (!THPUtils_checkLong(item.get())) {
      THPUtils_invalidArguments(args, "newGroup", 1, "(list[int] ranks)");
      return NULL;
    }

    int rank = THPUtils_unpackLong(item.get());
    // Reject duplicate ranks; a quadratic scan is fine for the small rank
    // lists passed to group construction.
    for (std::size_t j = 0; j < ranks.size(); ++j)
      THPUtils_assert(rank != ranks[j], "ranks should be unique");
    ranks.push_back(rank);
  }

  // NOTE(review): PyInt_FromLong is a Python-2-only API — confirm a py2/py3
  // compat shim (e.g. PyLong_FromLong alias) is in scope for py3 builds.
  return PyInt_FromLong(THDNewGroup(ranks.data(), ranks.size()));
  END_HANDLE_TH_ERRORS
}
217+
218+
PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *args) {
219+
if (PyTuple_GET_SIZE(args) != 2) {
220+
THPUtils_invalidArguments(args, "initExtension", 1, "(reduce_op obj, group obj)");
221+
return NULL;
222+
}
223+
224+
PyObject* reduce_op_obj = PyTuple_GET_ITEM(args, 0);
225+
PyObject* group_obj = PyTuple_GET_ITEM(args, 1);
226+
168227
THPObjectPtr reduce_op;
169228
#define REGISTER_REDUCE_OP(NAME) \
170229
reduce_op = PyObject_GetAttrString(reduce_op_obj, #NAME); \
@@ -175,11 +234,19 @@ PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *reduce_op_obj) {
175234
REGISTER_REDUCE_OP(MIN);
176235
REGISTER_REDUCE_OP(MAX);
177236
#undef REGISTER_REDUCE_OP
237+
238+
THPObjectPtr group;
239+
#define REGISTER_GROUP(NAME) \
240+
group = PyObject_GetAttrString(group_obj, #NAME); \
241+
THPUtils_assert(group, "Missing object for group " #NAME); \
242+
obj2group.emplace(group.get(), THDGroup##NAME);
243+
REGISTER_GROUP(WORLD);
244+
#undef REGISTER_GROUP
178245
Py_RETURN_TRUE;
179246
}
180247

181248
static struct PyMethodDef _THDPModule_methods[] = {
182-
{"_dist_init_extension", (PyCFunction)THDPModule_initExtension, METH_O, NULL},
249+
{"_dist_init_extension", (PyCFunction)THDPModule_initExtension, METH_VARARGS, NULL},
183250
{"_dist_init_process_group", (PyCFunction)THDPModule_initProcessGroup, METH_O, NULL},
184251
{"_dist_get_rank", (PyCFunction)THDPModule_getRank, METH_NOARGS, NULL},
185252
{"_dist_get_num_processes", (PyCFunction)THDPModule_getNumProcesses, METH_NOARGS, NULL},
@@ -188,6 +255,7 @@ static struct PyMethodDef _THDPModule_methods[] = {
188255
{"_dist_all_reduce", (PyCFunction)THDPModule_allReduce, METH_VARARGS, NULL},
189256
{"_dist_reduce", (PyCFunction)THDPModule_reduce, METH_VARARGS, NULL},
190257
{"_dist_broadcast", (PyCFunction)THDPModule_broadcast, METH_VARARGS, NULL},
258+
{"_dist_new_group", (PyCFunction)THDPModule_newGroup, METH_VARARGS, NULL},
191259
{NULL}
192260
};
193261

torch/distributed/__init__.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ class reduce_op(object):
1818
MAX = object()
1919
MIN = object()
2020

21+
class group(object):
    """Namespace of process-group sentinels for the distributed API.

    Each attribute is a unique ``object()`` used as an opaque handle; the C
    extension registers these sentinels at init time and maps them to THD
    group ids (``WORLD`` -> ``THDGroupWORLD``).
    """
    WORLD = object()
23+
2124
def get_rank():
    """Return this process's rank as reported by the C extension.

    Requires the distributed extension to be initialized first.
    """
    return torch._C._dist_get_rank()
2326

@@ -34,16 +37,19 @@ def recv(tensor, src_rank):
3437
return torch._C._dist_recv(tensor, src_rank)
3538

3639

37-
def broadcast(tensor, src_rank):
38-
return torch._C._dist_broadcast(tensor, src_rank)
40+
def broadcast(tensor, src_rank, group=group.WORLD):
    """Broadcast ``tensor`` from process ``src_rank`` to all processes in ``group``.

    ``group`` may be a ``group`` sentinel (e.g. ``group.WORLD``) or an int
    group id. Note: the parameter shadows the module-level ``group`` class;
    the default is bound at definition time, so this is safe.
    """
    return torch._C._dist_broadcast(tensor, src_rank, group)
42+
3943

44+
def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
    """Reduce ``tensor`` across all processes in ``group`` using ``op``.

    ``op`` is a ``reduce_op`` sentinel (default ``SUM``); ``group`` is a
    ``group`` sentinel or int group id. Presumably the result is written back
    into ``tensor`` on every process — confirm against THDAllReduce.
    """
    return torch._C._dist_all_reduce(tensor, op, group)
4046

41-
def all_reduce(tensor, op=reduce_op.SUM):
42-
return torch._C._dist_all_reduce(tensor, op)
4347

48+
def reduce(tensor, dst_rank, op=reduce_op.SUM, group=group.WORLD):
    """Reduce ``tensor`` across ``group`` using ``op``, delivering to ``dst_rank``.

    ``op`` is a ``reduce_op`` sentinel (default ``SUM``); ``group`` is a
    ``group`` sentinel or int group id.
    """
    return torch._C._dist_reduce(tensor, dst_rank, op, group)
4450

45-
def reduce(tensor, dst_rank, op=reduce_op.SUM):
46-
return torch._C._dist_reduce(tensor, dst_rank, op)
4751

52+
def new_group(ranks):
    """Create a new process group from the given sequence of unique ranks.

    Returns an int group id (usable wherever a ``group`` argument is
    accepted); the C extension rejects duplicate ranks.
    """
    return torch._C._dist_new_group(ranks)
4854

49-
assert torch._C._dist_init_extension(reduce_op)
55+
assert torch._C._dist_init_extension(reduce_op, group)

torch/lib/THD/base/DataChannel.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ enum THDReduceOp {
66
THDReduceSUM,
77
THDReducePRODUCT,
88
};
9+
10+
/* Opaque integer handle identifying a process group. */
typedef int THDGroup;

/* Default world group id. 'static const' instead of a mutable 'static':
   every translation unit including this header otherwise gets its own
   writable copy; all visible uses (default arguments, by-value passing)
   only read the value. */
static const THDGroup THDGroupWORLD = 0;

torch/lib/THD/base/DataChannel.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@ struct DataChannel {
1616
virtual int getRank() = 0;
1717
virtual int getNumProcesses() = 0;
1818

19-
virtual void allReduce(Tensor& data, THDReduceOp operation) = 0;
20-
virtual void reduce(Tensor& data, THDReduceOp operation, int dst_rank) = 0;
21-
virtual void broadcast(Tensor& data, int src_rank) = 0;
19+
virtual void allReduce(Tensor& data, THDReduceOp operation, THDGroup group_id = THDGroupWORLD) = 0;
20+
virtual void reduce(Tensor& data, THDReduceOp operation, int dst_rank,
21+
THDGroup group_id = THDGroupWORLD) = 0;
22+
virtual void broadcast(Tensor& data, int src_rank, THDGroup group_id = THDGroupWORLD) = 0;
2223
virtual void send(Tensor& data, int dst_rank) = 0;
2324
virtual void receive(Tensor& data, int src_rank) = 0;
2425

26+
virtual THDGroup newGroup(std::vector<int> ranks) = 0;
27+
2528
static DataChannel* newChannel(THDChannelType type);
2629
};
2730

0 commit comments

Comments
 (0)