Skip to content

Commit 53c4fb8

Browse files
authored
Merge pull request #13228 from Sergei-Lebedev/topic/fix_ucc_inplace_v5
v5.0.x: refactor UCC collective operations to handle MPI_IN_PLACE correctly
2 parents 48805a2 + 2affc32 commit 53c4fb8

8 files changed: 89 additions and 42 deletions.

ompi/mca/coll/ucc/coll_ucc_allgather.c

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,20 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t
1515
ucc_coll_req_h *req,
1616
mca_coll_ucc_req_t *coll_req)
1717
{
18-
ucc_datatype_t ucc_sdt, ucc_rdt;
18+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
19+
bool is_inplace = (MPI_IN_PLACE == sbuf);
1920
int comm_size = ompi_comm_size(ucc_module->comm);
2021

21-
if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount) ||
22+
if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) ||
2223
!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) {
2324
goto fallback;
2425
}
25-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
26+
2627
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
28+
if (!is_inplace) {
29+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
30+
}
31+
2732
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
2833
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
2934
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
@@ -50,7 +55,7 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t
5055
}
5156
};
5257

53-
if (MPI_IN_PLACE == sbuf) {
58+
if (is_inplace) {
5459
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
5560
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
5661
}

ompi/mca/coll/ucc/coll_ucc_allgatherv.c

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,14 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t
1717
ucc_coll_req_h *req,
1818
mca_coll_ucc_req_t *coll_req)
1919
{
20-
ucc_datatype_t ucc_sdt, ucc_rdt;
20+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
21+
bool is_inplace = (MPI_IN_PLACE == sbuf);
2122

22-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
2323
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
24+
if (!is_inplace) {
25+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
26+
}
27+
2428
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
2529
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
2630
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
@@ -48,7 +52,7 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t
4852
}
4953
};
5054

51-
if (MPI_IN_PLACE == sbuf) {
55+
if (is_inplace) {
5256
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
5357
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
5458
}

ompi/mca/coll/ucc/coll_ucc_alltoall.c

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,20 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s
1515
ucc_coll_req_h *req,
1616
mca_coll_ucc_req_t *coll_req)
1717
{
18-
ucc_datatype_t ucc_sdt, ucc_rdt;
18+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
19+
bool is_inplace = (MPI_IN_PLACE == sbuf);
1920
int comm_size = ompi_comm_size(ucc_module->comm);
2021

21-
if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size) ||
22+
if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size)) ||
2223
!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) {
2324
goto fallback;
2425
}
25-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
26+
2627
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
28+
if (!is_inplace) {
29+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
30+
}
31+
2732
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
2833
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
2934
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
@@ -50,7 +55,7 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s
5055
}
5156
};
5257

53-
if (MPI_IN_PLACE == sbuf) {
58+
if (is_inplace) {
5459
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
5560
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
5661
}

ompi/mca/coll/ucc/coll_ucc_alltoallv.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, const i
1717
ucc_coll_req_h *req,
1818
mca_coll_ucc_req_t *coll_req)
1919
{
20-
ucc_datatype_t ucc_sdt, ucc_rdt;
20+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
21+
bool is_inplace = (MPI_IN_PLACE == sbuf);
2122

22-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
2323
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
24+
if (!is_inplace) {
25+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
26+
}
27+
2428
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
2529
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
2630
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
@@ -49,7 +53,7 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, const i
4953
}
5054
};
5155

52-
if (MPI_IN_PLACE == sbuf) {
56+
if (is_inplace) {
5357
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
5458
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
5559
}

ompi/mca/coll/ucc/coll_ucc_gather.c

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,35 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om
1717
ucc_coll_req_h *req,
1818
mca_coll_ucc_req_t *coll_req)
1919
{
20-
ucc_datatype_t ucc_sdt, ucc_rdt;
20+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
21+
bool is_inplace = (MPI_IN_PLACE == sbuf);
2122
int comm_rank = ompi_comm_rank(ucc_module->comm);
2223
int comm_size = ompi_comm_size(ucc_module->comm);
2324

24-
if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) {
25-
goto fallback;
26-
}
27-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
2825
if (comm_rank == root) {
29-
if (!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) {
26+
if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) ||
27+
!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) {
3028
goto fallback;
3129
}
30+
3231
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
33-
if ((COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ||
34-
(MPI_IN_PLACE != sbuf && COLL_UCC_DT_UNSUPPORTED == ucc_sdt)) {
32+
if (!is_inplace) {
33+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
34+
}
35+
36+
if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ||
37+
(COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) {
3538
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
36-
(COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ?
37-
rdtype->super.name : sdtype->super.name);
39+
(COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ?
40+
sdtype->super.name : rdtype->super.name);
3841
goto fallback;
3942
}
4043
} else {
44+
if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) {
45+
goto fallback;
46+
}
47+
48+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
4149
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) {
4250
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
4351
sdtype->super.name);
@@ -64,7 +72,7 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om
6472
},
6573
};
6674

67-
if (MPI_IN_PLACE == sbuf) {
75+
if (is_inplace) {
6876
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
6977
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
7078
}

ompi/mca/coll/ucc/coll_ucc_gatherv.c

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,24 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc
1717
ucc_coll_req_h *req,
1818
mca_coll_ucc_req_t *coll_req)
1919
{
20-
ucc_datatype_t ucc_sdt, ucc_rdt;
20+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
21+
bool is_inplace = (MPI_IN_PLACE == sbuf);
2122
int comm_rank = ompi_comm_rank(ucc_module->comm);
2223

23-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
2424
if (comm_rank == root) {
2525
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
26-
if ((COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ||
27-
(MPI_IN_PLACE != sbuf && COLL_UCC_DT_UNSUPPORTED == ucc_sdt)) {
26+
if (!is_inplace) {
27+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
28+
}
29+
if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ||
30+
(COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) {
2831
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
29-
(COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ?
30-
rdtype->super.name : sdtype->super.name);
32+
(COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ?
33+
sdtype->super.name : rdtype->super.name);
3134
goto fallback;
3235
}
3336
} else {
37+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
3438
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) {
3539
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
3640
sdtype->super.name);
@@ -58,7 +62,7 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc
5862
},
5963
};
6064

61-
if (MPI_IN_PLACE == sbuf) {
65+
if (is_inplace) {
6266
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
6367
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
6468
}

ompi/mca/coll/ucc/coll_ucc_scatter.c

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,35 @@ ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount,
1818
ucc_coll_req_h *req,
1919
mca_coll_ucc_req_t *coll_req)
2020
{
21-
ucc_datatype_t ucc_sdt, ucc_rdt;
21+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
22+
bool is_inplace = (MPI_IN_PLACE == rbuf);
2223
int comm_rank = ompi_comm_rank(ucc_module->comm);
2324
int comm_size = ompi_comm_size(ucc_module->comm);
2425

25-
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
2626
if (comm_rank == root) {
27+
if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) ||
28+
!ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size)) {
29+
goto fallback;
30+
}
31+
2732
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
33+
if (!is_inplace) {
34+
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
35+
}
36+
2837
if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ||
29-
(MPI_IN_PLACE != rbuf && COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) {
38+
(COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) {
3039
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
3140
(COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ?
3241
sdtype->super.name : rdtype->super.name);
3342
goto fallback;
3443
}
3544
} else {
45+
if (!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) {
46+
goto fallback;
47+
}
48+
49+
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
3650
if (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
3751
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
3852
rdtype->super.name);
@@ -59,7 +73,7 @@ ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount,
5973
},
6074
};
6175

62-
if (MPI_IN_PLACE == rbuf) {
76+
if (is_inplace) {
6377
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
6478
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
6579
}

ompi/mca/coll/ucc/coll_ucc_scatterv.c

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,25 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, const int *scounts,
1818
ucc_coll_req_h *req,
1919
mca_coll_ucc_req_t *coll_req)
2020
{
21-
ucc_datatype_t ucc_sdt, ucc_rdt;
21+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
22+
bool is_inplace = (MPI_IN_PLACE == rbuf);
2223
int comm_rank = ompi_comm_rank(ucc_module->comm);
23-
int comm_size = ompi_comm_size(ucc_module->comm);
2424

25-
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
2625
if (comm_rank == root) {
2726
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
27+
if (!is_inplace) {
28+
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
29+
}
30+
2831
if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ||
29-
(MPI_IN_PLACE != rbuf && COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) {
32+
(COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) {
3033
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
3134
(COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ?
3235
sdtype->super.name : rdtype->super.name);
3336
goto fallback;
3437
}
35-
3638
} else {
39+
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
3740
if (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
3841
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
3942
rdtype->super.name);
@@ -61,7 +64,7 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, const int *scounts,
6164
},
6265
};
6366

64-
if (MPI_IN_PLACE == rbuf) {
67+
if (is_inplace) {
6568
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
6669
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
6770
}

0 commit comments

Comments
 (0)