1
1
/*
2
- * Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
2
+ * Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
3
3
* Copyright (c) 2004-2023 The University of Tennessee and The University
4
4
* of Tennessee Research Foundation. All rights
5
5
* reserved.
@@ -36,7 +36,7 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
36
36
mca_coll_base_module_t * module )
37
37
{
38
38
mca_coll_accelerator_module_t * s = (mca_coll_accelerator_module_t * ) module ;
39
- int rank = ( comm == NULL ) ? -1 : ompi_comm_rank (comm );
39
+ int rank = ompi_comm_rank (comm );
40
40
ptrdiff_t gap ;
41
41
char * rbuf1 = NULL , * sbuf1 = NULL , * rbuf2 = NULL ;
42
42
size_t bufsize ;
@@ -71,15 +71,9 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
71
71
rbuf2 = rbuf ; /* save away original buffer */
72
72
rbuf = rbuf1 - gap ;
73
73
}
74
-
75
- if ((comm == NULL ) && (root == -1 )) {
76
- ompi_op_reduce (op , (void * )sbuf , rbuf , count , dtype );
77
- rc = OMPI_SUCCESS ;
78
- } else {
79
- rc = s -> c_coll .coll_reduce ((void * ) sbuf , rbuf , count ,
80
- dtype , op , root , comm ,
81
- s -> c_coll .coll_reduce_module );
82
- }
74
+ rc = s -> c_coll .coll_reduce ((void * ) sbuf , rbuf , count ,
75
+ dtype , op , root , comm ,
76
+ s -> c_coll .coll_reduce_module );
83
77
84
78
if (NULL != sbuf1 ) {
85
79
free (sbuf1 );
@@ -98,6 +92,53 @@ mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count,
98
92
struct ompi_op_t * op ,
99
93
mca_coll_base_module_t * module )
100
94
{
101
- return mca_coll_accelerator_reduce (sbuf , rbuf , count , dtype , op , -1 , NULL ,
102
- module );
95
+ ptrdiff_t gap ;
96
+ char * rbuf1 = NULL , * sbuf1 = NULL , * rbuf2 = NULL ;
97
+ size_t bufsize ;
98
+ int rc ;
99
+
100
+ bufsize = opal_datatype_span (& dtype -> super , count , & gap );
101
+
102
+ rc = mca_coll_accelerator_check_buf ((void * )sbuf );
103
+ if (rc < 0 ) {
104
+ return rc ;
105
+ }
106
+
107
+ if ((MPI_IN_PLACE != sbuf ) && (rc > 0 )) {
108
+ sbuf1 = (char * )malloc (bufsize );
109
+ if (NULL == sbuf1 ) {
110
+ return OMPI_ERR_OUT_OF_RESOURCE ;
111
+ }
112
+ mca_coll_accelerator_memcpy (sbuf1 , sbuf , bufsize );
113
+ sbuf = sbuf1 - gap ;
114
+ }
115
+
116
+ rc = mca_coll_accelerator_check_buf (rbuf );
117
+ if (rc < 0 ) {
118
+ return rc ;
119
+ }
120
+
121
+ if (rc > 0 ) {
122
+ rbuf1 = (char * )malloc (bufsize );
123
+ if (NULL == rbuf1 ) {
124
+ if (NULL != sbuf1 ) free (sbuf1 );
125
+ return OMPI_ERR_OUT_OF_RESOURCE ;
126
+ }
127
+ mca_coll_accelerator_memcpy (rbuf1 , rbuf , bufsize );
128
+ rbuf2 = rbuf ; /* save away original buffer */
129
+ rbuf = rbuf1 - gap ;
130
+ }
131
+
132
+ ompi_op_reduce (op , (void * )sbuf , rbuf , count , dtype );
133
+ rc = OMPI_SUCCESS ;
134
+
135
+ if (NULL != sbuf1 ) {
136
+ free (sbuf1 );
137
+ }
138
+ if (NULL != rbuf1 ) {
139
+ rbuf = rbuf2 ;
140
+ mca_coll_accelerator_memcpy (rbuf , rbuf1 , bufsize );
141
+ free (rbuf1 );
142
+ }
143
+ return rc ;
103
144
}
0 commit comments