Skip to content

Commit a650419

Browse files
coll/ucc: fix bigcount support
fixing support of bigcount in UCC coll component, coll flags were not set correctly
1 parent 42c7353 commit a650419

File tree

5 files changed

+28
-39
lines changed

5 files changed

+28
-39
lines changed

ompi/mca/coll/ucc/coll_ucc_allgatherv.c

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
#include "coll_ucc_common.h"
1111

12-
static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int scount,
12+
static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount,
1313
struct ompi_datatype_t *sdtype,
1414
void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
1515
struct ompi_datatype_t *rdtype,
@@ -19,12 +19,13 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc
1919
{
2020
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
2121
bool is_inplace = (MPI_IN_PLACE == sbuf);
22+
uint64_t flags = 0;
2223

2324
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
2425
if (!is_inplace) {
2526
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
2627
}
27-
28+
2829
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
2930
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
3031
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
@@ -33,13 +34,13 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc
3334
goto fallback;
3435
}
3536

36-
uint64_t flags = ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0;
37-
flags |= ompi_disp_array_is_64bit(rdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0;
37+
flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) |
38+
(ompi_disp_array_is_64bit(rdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) |
39+
(is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0);
3840

3941
ucc_coll_args_t coll = {
42+
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
4043
.flags = flags,
41-
.mask = 0,
42-
.flags = 0,
4344
.coll_type = UCC_COLL_TYPE_ALLGATHERV,
4445
.src.info = {
4546
.buffer = (void*)sbuf,
@@ -56,10 +57,6 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc
5657
}
5758
};
5859

59-
if (is_inplace) {
60-
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
61-
coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
62-
}
6360
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
6461
return UCC_OK;
6562
fallback:

ompi/mca/coll/ucc/coll_ucc_alltoallv.c

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co
1919
{
2020
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
2121
bool is_inplace = (MPI_IN_PLACE == sbuf);
22+
uint64_t flags = 0;
2223

2324
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
2425
if (!is_inplace) {
2526
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
2627
}
27-
28+
2829
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
2930
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
3031
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
@@ -34,13 +35,13 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co
3435
}
3536

3637
/* Assumes that send counts/displs and recv counts/displs are both 32-bit or both 64-bit */
37-
uint64_t flags = ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0;
38-
flags |= ompi_disp_array_is_64bit(sdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0;
38+
flags = (ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) |
39+
(ompi_disp_array_is_64bit(sdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) |
40+
(is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0);
3941

4042
ucc_coll_args_t coll = {
43+
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
4144
.flags = flags,
42-
.mask = 0,
43-
.flags = 0,
4445
.coll_type = UCC_COLL_TYPE_ALLTOALLV,
4546
.src.info_v = {
4647
.buffer = (void*)sbuf,
@@ -58,10 +59,6 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co
5859
}
5960
};
6061

61-
if (is_inplace) {
62-
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
63-
coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
64-
}
6562
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
6663
return UCC_OK;
6764
fallback:

ompi/mca/coll/ucc/coll_ucc_gatherv.c

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc
2020
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
2121
bool is_inplace = (MPI_IN_PLACE == sbuf);
2222
int comm_rank = ompi_comm_rank(ucc_module->comm);
23+
uint64_t flags = 0;
2324

2425
if (comm_rank == root) {
2526
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
@@ -42,13 +43,13 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc
4243
}
4344
}
4445

45-
uint64_t flags = ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0;
46-
flags |= ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0;
46+
flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) |
47+
(ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) |
48+
(is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0);
4749

4850
ucc_coll_args_t coll = {
51+
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
4952
.flags = flags,
50-
.mask = 0,
51-
.flags = 0,
5253
.coll_type = UCC_COLL_TYPE_GATHERV,
5354
.root = root,
5455
.src.info = {
@@ -66,10 +67,6 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc
6667
},
6768
};
6869

69-
if (is_inplace) {
70-
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
71-
coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
72-
}
7370
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
7471
return UCC_OK;
7572
fallback:

ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ ucc_status_t mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi
2121
size_t total_count;
2222
int i;
2323
int comm_size = ompi_comm_size(ucc_module->comm);
24+
uint64_t flags = 0;
2425

2526
if (MPI_IN_PLACE == sbuf) {
2627
/* TODO: UCC defines inplace differently:
@@ -46,10 +47,11 @@ ucc_status_t mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi
4647
total_count += ompi_count_array_get(rcounts, i);
4748
}
4849

50+
flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0);
51+
4952
ucc_coll_args_t coll = {
50-
.flags = ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0,
51-
.mask = 0,
52-
.flags = 0,
53+
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
54+
.flags = flags,
5355
.coll_type = UCC_COLL_TYPE_REDUCE_SCATTERV,
5456
.src.info = {
5557
.buffer = (void*)sbuf,

ompi/mca/coll/ucc/coll_ucc_scatterv.c

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t sco
2121
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
2222
bool is_inplace = (MPI_IN_PLACE == rbuf);
2323
int comm_rank = ompi_comm_rank(ucc_module->comm);
24-
24+
uint64_t flags = 0;
2525
if (comm_rank == root) {
2626
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
2727
if (!is_inplace) {
@@ -44,13 +44,13 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t sco
4444
}
4545
}
4646

47-
uint64_t flags = ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0;
48-
flags |= ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0;
47+
flags = (ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) |
48+
(ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) |
49+
(is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0);
4950

5051
ucc_coll_args_t coll = {
52+
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
5153
.flags = flags,
52-
.mask = 0,
53-
.flags = 0,
5454
.coll_type = UCC_COLL_TYPE_SCATTERV,
5555
.root = root,
5656
.src.info_v = {
@@ -68,10 +68,6 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t sco
6868
},
6969
};
7070

71-
if (is_inplace) {
72-
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
73-
coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
74-
}
7571
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
7672
return UCC_OK;
7773
fallback:

0 commit comments

Comments
 (0)