Skip to content

Commit 6dbeab8

Browse files
coll/accelerator: duplicate reduce code for reduce_local
Signed-off-by: Akshay Venkatesh <[email protected]>
1 parent 4617d96 commit 6dbeab8

File tree

1 file changed

+54
-13
lines changed

1 file changed

+54
-13
lines changed

ompi/mca/coll/accelerator/coll_accelerator_reduce.c

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
2+
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
33
* Copyright (c) 2004-2023 The University of Tennessee and The University
44
* of Tennessee Research Foundation. All rights
55
* reserved.
@@ -36,7 +36,7 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
3636
mca_coll_base_module_t *module)
3737
{
3838
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
39-
int rank = (comm == NULL) ? -1 : ompi_comm_rank(comm);
39+
int rank = ompi_comm_rank(comm);
4040
ptrdiff_t gap;
4141
char *rbuf1 = NULL, *sbuf1 = NULL, *rbuf2 = NULL;
4242
size_t bufsize;
@@ -71,15 +71,9 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
7171
rbuf2 = rbuf; /* save away original buffer */
7272
rbuf = rbuf1 - gap;
7373
}
74-
75-
if ((comm == NULL) && (root == -1)) {
76-
ompi_op_reduce(op, (void *)sbuf, rbuf, count, dtype);
77-
rc = OMPI_SUCCESS;
78-
} else {
79-
rc = s->c_coll.coll_reduce((void *) sbuf, rbuf, count,
80-
dtype, op, root, comm,
81-
s->c_coll.coll_reduce_module);
82-
}
74+
rc = s->c_coll.coll_reduce((void *) sbuf, rbuf, count,
75+
dtype, op, root, comm,
76+
s->c_coll.coll_reduce_module);
8377

8478
if (NULL != sbuf1) {
8579
free(sbuf1);
@@ -98,6 +92,53 @@ mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count,
9892
struct ompi_op_t *op,
9993
mca_coll_base_module_t *module)
10094
{
101-
return mca_coll_accelerator_reduce(sbuf, rbuf, count, dtype, op, -1, NULL,
102-
module);
95+
ptrdiff_t gap;
96+
char *rbuf1 = NULL, *sbuf1 = NULL, *rbuf2 = NULL;
97+
size_t bufsize;
98+
int rc;
99+
100+
bufsize = opal_datatype_span(&dtype->super, count, &gap);
101+
102+
rc = mca_coll_accelerator_check_buf((void *)sbuf);
103+
if (rc < 0) {
104+
return rc;
105+
}
106+
107+
if ((MPI_IN_PLACE != sbuf) && (rc > 0)) {
108+
sbuf1 = (char*)malloc(bufsize);
109+
if (NULL == sbuf1) {
110+
return OMPI_ERR_OUT_OF_RESOURCE;
111+
}
112+
mca_coll_accelerator_memcpy(sbuf1, sbuf, bufsize);
113+
sbuf = sbuf1 - gap;
114+
}
115+
116+
rc = mca_coll_accelerator_check_buf(rbuf);
117+
if (rc < 0) {
118+
return rc;
119+
}
120+
121+
if (rc > 0) {
122+
rbuf1 = (char*)malloc(bufsize);
123+
if (NULL == rbuf1) {
124+
if (NULL != sbuf1) free(sbuf1);
125+
return OMPI_ERR_OUT_OF_RESOURCE;
126+
}
127+
mca_coll_accelerator_memcpy(rbuf1, rbuf, bufsize);
128+
rbuf2 = rbuf; /* save away original buffer */
129+
rbuf = rbuf1 - gap;
130+
}
131+
132+
ompi_op_reduce(op, (void *)sbuf, rbuf, count, dtype);
133+
rc = OMPI_SUCCESS;
134+
135+
if (NULL != sbuf1) {
136+
free(sbuf1);
137+
}
138+
if (NULL != rbuf1) {
139+
rbuf = rbuf2;
140+
mca_coll_accelerator_memcpy(rbuf, rbuf1, bufsize);
141+
free(rbuf1);
142+
}
143+
return rc;
103144
}

0 commit comments

Comments
 (0)