Skip to content

Commit 571c658

Browse files
ompi/coll/accelerator/cuda: implement reduce_local
Signed-off-by: Akshay Venkatesh <[email protected]> bot:notacherrypick
1 parent 57f2404 commit 571c658

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed

ompi/mca/coll/accelerator/coll_accelerator.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/*
2+
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
23
* Copyright (c) 2014 The University of Tennessee and The University
34
* of Tennessee Research Foundation. All rights
45
* reserved.
@@ -45,6 +46,11 @@ mca_coll_accelerator_allreduce(const void *sbuf, void *rbuf, size_t count,
4546
struct ompi_communicator_t *comm,
4647
mca_coll_base_module_t *module);
4748

49+
/* Local-reduction entry point of the accelerator coll component.
 * Unlike mca_coll_accelerator_reduce it takes no root and no communicator:
 * only the two buffers, the element count, the datatype, the op and the
 * module. */
int mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count,
50+
struct ompi_datatype_t *dtype,
51+
struct ompi_op_t *op,
52+
mca_coll_base_module_t *module);
53+
4854
int mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
4955
struct ompi_datatype_t *dtype,
5056
struct ompi_op_t *op,

ompi/mca/coll/accelerator/coll_accelerator_module.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/*
2+
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
23
* Copyright (c) 2014-2017 The University of Tennessee and The University
34
* of Tennessee Research Foundation. All rights
45
* reserved.
@@ -94,6 +95,7 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
9495

9596
accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce;
9697
accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce;
98+
accelerator_module->super.coll_reduce_local = mca_coll_accelerator_reduce_local;
9799
accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block;
98100
if (!OMPI_COMM_IS_INTER(comm)) {
99101
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
@@ -141,6 +143,7 @@ mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
141143

142144
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
143145
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
146+
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local);
144147
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
145148
if (!OMPI_COMM_IS_INTER(comm)) {
146149
/* MPI does not define scan/exscan on intercommunicators */
@@ -159,6 +162,7 @@ mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
159162

160163
ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce);
161164
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce);
165+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_local);
162166
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block);
163167
if (!OMPI_COMM_IS_INTER(comm))
164168
{

ompi/mca/coll/accelerator/coll_accelerator_reduce.c

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/*
2+
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
23
* Copyright (c) 2004-2023 The University of Tennessee and The University
34
* of Tennessee Research Foundation. All rights
45
* reserved.
@@ -35,7 +36,7 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
3536
mca_coll_base_module_t *module)
3637
{
3738
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
38-
int rank = ompi_comm_rank(comm);
39+
int rank = (comm == NULL) ? -1 : ompi_comm_rank(comm);
3940
ptrdiff_t gap;
4041
char *rbuf1 = NULL, *sbuf1 = NULL, *rbuf2 = NULL;
4142
size_t bufsize;
@@ -70,11 +71,18 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
7071
rbuf2 = rbuf; /* save away original buffer */
7172
rbuf = rbuf1 - gap;
7273
}
73-
rc = s->c_coll.coll_reduce((void *) sbuf, rbuf, count,
74-
dtype, op, root, comm,
75-
s->c_coll.coll_reduce_module);
74+
75+
if ((comm == NULL) && (root == -1)) {
76+
ompi_op_reduce(op, (void *)sbuf, rbuf, count, dtype);
77+
rc = OMPI_SUCCESS;
78+
} else {
79+
rc = s->c_coll.coll_reduce((void *) sbuf, rbuf, count,
80+
dtype, op, root, comm,
81+
s->c_coll.coll_reduce_module);
82+
}
7683

7784
if (NULL != sbuf1) {
85+
sbuf = sbuf2;
7886
free(sbuf1);
7987
}
8088
if (NULL != rbuf1) {
@@ -84,3 +92,13 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
8492
}
8593
return rc;
8694
}
95+
96+
/*
 * Local reduction for the accelerator coll component.
 *
 * Thin wrapper: forwards every argument to mca_coll_accelerator_reduce()
 * with root = -1 and comm = NULL. In that function a NULL comm makes rank
 * -1 (ompi_comm_rank() is not called), and the (comm == NULL, root == -1)
 * combination selects ompi_op_reduce() on the buffers instead of the
 * underlying coll_reduce.
 *
 * Returns the value returned by mca_coll_accelerator_reduce().
 */
int
97+
mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count,
98+
struct ompi_datatype_t *dtype,
99+
struct ompi_op_t *op,
100+
mca_coll_base_module_t *module)
101+
{
102+
/* root = -1 and comm = NULL mark the call as a local reduction */
return mca_coll_accelerator_reduce(sbuf, rbuf, count, dtype, op, -1, NULL,
103+
module);
104+
}

0 commit comments

Comments
 (0)