1
1
#include < Python.h>
2
2
3
+ #include < memory>
3
4
#include < unordered_map>
5
+ #include < vector>
4
6
5
7
#include " THDP.h"
6
8
@@ -10,6 +12,7 @@ static std::unordered_map<std::string, THDChannelType> name2channel_type = {
10
12
};
11
13
12
14
static std::unordered_map<PyObject*, THDReduceOp> obj2reduceop;
15
+ static std::unordered_map<PyObject*, THDGroup> obj2group;
13
16
14
17
static THPObjectPtr _ensureBytes (PyObject *obj)
15
18
{
@@ -83,6 +86,18 @@ static THDReduceOp _getReduceOp(PyObject *obj)
83
86
return it->second ;
84
87
}
85
88
89
+ static THDGroup _getGroup (PyObject *obj)
90
+ {
91
+ auto it = obj2group.find (obj);
92
+ if (it == obj2group.end ()) {
93
+ if (!THPUtils_checkLong (obj))
94
+ throw std::runtime_error (" group should be an int or one of the values "
95
+ " from torch.distributed.group" );
96
+ return THPUtils_unpackLong (obj);
97
+ }
98
+ return it->second ;
99
+ }
100
+
86
101
PyObject* THDPModule_send (PyObject *_unused, PyObject *args)
87
102
{
88
103
HANDLE_TH_ERRORS
@@ -118,53 +133,97 @@ PyObject* THDPModule_recv(PyObject *_unused, PyObject *args)
118
133
PyObject* THDPModule_allReduce (PyObject *_unused, PyObject *args)
119
134
{
120
135
HANDLE_TH_ERRORS
121
- if (PyTuple_GET_SIZE (args) != 2 || !THPModule_isTensor (PyTuple_GET_ITEM (args, 0 ))) {
122
- THPUtils_invalidArguments (args, " all_reduce" , 1 , " (tensor in_out, reduce_op op)" );
136
+ if (PyTuple_GET_SIZE (args) != 3 || !THPModule_isTensor (PyTuple_GET_ITEM (args, 0 ))) {
137
+ THPUtils_invalidArguments (args, " all_reduce" , 1 , " (tensor in_out, reduce_op op, group gr )" );
123
138
return NULL ;
124
139
}
125
140
141
+ THDGroup group = _getGroup (PyTuple_GET_ITEM (args, 2 ));
126
142
THDReduceOp op = _getReduceOp (PyTuple_GET_ITEM (args, 1 ));
127
143
THDPTensorDesc desc = _makeDescriptor (PyTuple_GET_ITEM (args, 0 ));
128
- THDAllReduce (desc, op);
144
+ THDAllReduce (desc, op, group );
129
145
Py_RETURN_NONE;
130
146
END_HANDLE_TH_ERRORS
131
147
}
132
148
133
149
PyObject* THDPModule_reduce (PyObject *_unused, PyObject *args)
134
150
{
135
151
HANDLE_TH_ERRORS
136
- if (PyTuple_GET_SIZE (args) != 3 || !THPModule_isTensor (PyTuple_GET_ITEM (args, 0 )) ||
152
+ if (PyTuple_GET_SIZE (args) != 4 || !THPModule_isTensor (PyTuple_GET_ITEM (args, 0 )) ||
137
153
!THPUtils_checkLong (PyTuple_GET_ITEM (args, 1 ))) {
138
154
THPUtils_invalidArguments (args, " reduce" , 1 ,
139
- " (tensor reduced, int dst_rank, reduce_op op)" );
155
+ " (tensor reduced, int dst_rank, reduce_op op, group gr )" );
140
156
return NULL ;
141
157
}
142
158
159
+ THDGroup group = _getGroup (PyTuple_GET_ITEM (args, 3 ));
143
160
THDReduceOp op = _getReduceOp (PyTuple_GET_ITEM (args, 2 ));
144
161
THDPTensorDesc desc = _makeDescriptor (PyTuple_GET_ITEM (args, 0 ));
145
162
int dst_rank = THPUtils_unpackLong (PyTuple_GET_ITEM (args, 1 ));
146
- THDReduce (desc, op, dst_rank);
163
+ THDReduce (desc, op, dst_rank, group );
147
164
Py_RETURN_NONE;
148
165
END_HANDLE_TH_ERRORS
149
166
}
150
167
151
168
PyObject* THDPModule_broadcast (PyObject *_unused, PyObject *args)
152
169
{
153
170
HANDLE_TH_ERRORS
154
- if (PyTuple_GET_SIZE (args) != 2 || !THPModule_isTensor (PyTuple_GET_ITEM (args, 0 )) ||
171
+ if (PyTuple_GET_SIZE (args) != 3 || !THPModule_isTensor (PyTuple_GET_ITEM (args, 0 )) ||
155
172
!THPUtils_checkLong (PyTuple_GET_ITEM (args, 1 ))) {
156
- THPUtils_invalidArguments (args, " broadcast" , 1 , " (tensor src_dst, int src_rank)" );
173
+ THPUtils_invalidArguments (args, " broadcast" , 1 ,
174
+ " (tensor src_dst, int src_rank, group gr)" );
157
175
return NULL ;
158
176
}
159
177
178
+ THDGroup group = _getGroup (PyTuple_GET_ITEM (args, 2 ));
160
179
THDPTensorDesc desc = _makeDescriptor (PyTuple_GET_ITEM (args, 0 ));
161
180
int src_rank = THPUtils_unpackLong (PyTuple_GET_ITEM (args, 1 ));
162
- THDBroadcast (desc, src_rank);
181
+ THDBroadcast (desc, src_rank, group );
163
182
Py_RETURN_NONE;
164
183
END_HANDLE_TH_ERRORS
165
184
}
166
185
167
- PyObject* THDPModule_initExtension (PyObject *_unused, PyObject *reduce_op_obj) {
186
+ PyObject* THDPModule_newGroup (PyObject *_unused, PyObject *args)
187
+ {
188
+ HANDLE_TH_ERRORS
189
+ PyObject* sequence = PyTuple_GET_ITEM (args, 0 );
190
+ Py_ssize_t tmp_length;
191
+ std::vector<int > ranks;
192
+
193
+ if (PyTuple_GET_SIZE (args) != 1 || !PySequence_Check (sequence))
194
+ goto invalid_arguments;
195
+
196
+ tmp_length = PySequence_Length (sequence);
197
+ THPUtils_assert (tmp_length >= 0 , " couldn't obtain the length of %s" ,
198
+ THPUtils_typename (sequence));
199
+
200
+ ranks.reserve (static_cast <std::size_t >(tmp_length));
201
+ for (std::size_t i = 0 ; i < ranks.capacity (); ++i) {
202
+ if (!THPUtils_checkLong (PySequence_ITEM (sequence, i)))
203
+ goto invalid_arguments;
204
+
205
+ ranks.push_back (THPUtils_unpackLong (PySequence_ITEM (sequence, i)));
206
+ for (std::size_t j = 0 ; j < i; ++j)
207
+ THPUtils_assert (ranks[i] != ranks[j], " ranks should be unique" );
208
+ }
209
+
210
+ return PyInt_FromLong (THDNewGroup (ranks.data (), ranks.size ()));
211
+
212
+ invalid_arguments:
213
+ THPUtils_invalidArguments (args, " newGroup" , 1 , " (list[int] ranks)" );
214
+ return NULL ;
215
+ END_HANDLE_TH_ERRORS
216
+ }
217
+
218
+ PyObject* THDPModule_initExtension (PyObject *_unused, PyObject *args) {
219
+ if (PyTuple_GET_SIZE (args) != 2 ) {
220
+ THPUtils_invalidArguments (args, " initExtension" , 1 , " (reduce_op obj, group obj)" );
221
+ return NULL ;
222
+ }
223
+
224
+ PyObject* reduce_op_obj = PyTuple_GET_ITEM (args, 0 );
225
+ PyObject* group_obj = PyTuple_GET_ITEM (args, 1 );
226
+
168
227
THPObjectPtr reduce_op;
169
228
#define REGISTER_REDUCE_OP (NAME ) \
170
229
reduce_op = PyObject_GetAttrString (reduce_op_obj, #NAME); \
@@ -175,11 +234,19 @@ PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *reduce_op_obj) {
175
234
REGISTER_REDUCE_OP (MIN);
176
235
REGISTER_REDUCE_OP (MAX);
177
236
#undef REGISTER_REDUCE_OP
237
+
238
+ THPObjectPtr group;
239
+ #define REGISTER_GROUP (NAME ) \
240
+ group = PyObject_GetAttrString (group_obj, #NAME); \
241
+ THPUtils_assert (group, " Missing object for group " #NAME); \
242
+ obj2group.emplace (group.get (), THDGroup##NAME);
243
+ REGISTER_GROUP (WORLD);
244
+ #undef REGISTER_GROUP
178
245
Py_RETURN_TRUE;
179
246
}
180
247
181
248
static struct PyMethodDef _THDPModule_methods[] = {
182
- {" _dist_init_extension" , (PyCFunction)THDPModule_initExtension, METH_O , NULL },
249
+ {" _dist_init_extension" , (PyCFunction)THDPModule_initExtension, METH_VARARGS , NULL },
183
250
{" _dist_init_process_group" , (PyCFunction)THDPModule_initProcessGroup, METH_O, NULL },
184
251
{" _dist_get_rank" , (PyCFunction)THDPModule_getRank, METH_NOARGS, NULL },
185
252
{" _dist_get_num_processes" , (PyCFunction)THDPModule_getNumProcesses, METH_NOARGS, NULL },
@@ -188,6 +255,7 @@ static struct PyMethodDef _THDPModule_methods[] = {
188
255
{" _dist_all_reduce" , (PyCFunction)THDPModule_allReduce, METH_VARARGS, NULL },
189
256
{" _dist_reduce" , (PyCFunction)THDPModule_reduce, METH_VARARGS, NULL },
190
257
{" _dist_broadcast" , (PyCFunction)THDPModule_broadcast, METH_VARARGS, NULL },
258
+ {" _dist_new_group" , (PyCFunction)THDPModule_newGroup, METH_VARARGS, NULL },
191
259
{NULL }
192
260
};
193
261
0 commit comments