From bcf1d5b9ec75c3879cd9def4cae8b025df2c35d1 Mon Sep 17 00:00:00 2001 From: "Matthew G. F. Dosanjh" Date: Thu, 1 Aug 2024 13:42:45 -0600 Subject: [PATCH] Added a partcomm fix for mismatched types Signed-off-by: Matthew G. F. Dosanjh --- ompi/mca/part/persist/part_persist.h | 28 +++++++++++++------- ompi/mca/part/persist/part_persist_request.h | 2 ++ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/ompi/mca/part/persist/part_persist.h b/ompi/mca/part/persist/part_persist.h index eea447c274c..9ed0d05ea23 100644 --- a/ompi/mca/part/persist/part_persist.h +++ b/ompi/mca/part/persist/part_persist.h @@ -258,19 +258,32 @@ mca_part_persist_progress(void) req->my_recv_tag = req->setup_info[1].setup_tag; req->real_parts = req->setup_info[1].num_parts; req->real_count = req->setup_info[1].count; + req->real_dt_size = req->setup_info[1].dt_size; + err = opal_datatype_type_size(&(req->req_datatype->super), &dt_size_); if(OMPI_SUCCESS != err) return OMPI_ERROR; dt_size = (dt_size_ > (size_t) INT_MAX) ? MPI_UNDEFINED : (int) dt_size_; int32_t bytes = req->real_count * dt_size; - /* Set up persistent sends */ + + + /* Set up persistent sends */ req->persist_reqs = (ompi_request_t**) malloc(sizeof(ompi_request_t*)*(req->real_parts)); req->flags = (int*) calloc(req->real_parts,sizeof(int)); - for(i = 0; i < req->real_parts; i++) { - void *buf = ((void*) (((char*)req->req_addr) + (bytes * i))); - err = MCA_PML_CALL(irecv_init(buf, req->real_count, req->req_datatype, req->world_peer, req->my_send_tag+i, ompi_part_persist.part_comm, &(req->persist_reqs[i]))); - } + + if(req->real_dt_size == dt_size) { + + for(i = 0; i < req->real_parts; i++) { + void *buf = ((void*) (((char*)req->req_addr) + (bytes * i))); + err = MCA_PML_CALL(irecv_init(buf, req->real_count, req->req_datatype, req->world_peer, req->my_send_tag+i, ompi_part_persist.part_comm, &(req->persist_reqs[i]))); + } + } else { + for(i = 0; i < req->real_parts; i++) { + void *buf = ((void*) (((char*)req->req_addr) + (req->real_count * req->real_dt_size * i))); + err = MCA_PML_CALL(irecv_init(buf, req->real_count * req->real_dt_size, MPI_BYTE, req->world_peer, req->my_send_tag+i, ompi_part_persist.part_comm, &(req->persist_reqs[i]))); + } + } err = req->persist_reqs[0]->req_start(req->real_parts, (&(req->persist_reqs[0]))); /* Send back a message */ @@ -372,7 +385,6 @@ mca_part_persist_precv_init(void *buf, dt_size = (dt_size_ > (size_t) INT_MAX) ? MPI_UNDEFINED : (int) dt_size_; req->req_bytes = parts * count * dt_size; - /* Set ompi request initial values */ req->req_ompi.req_persistent = true; req->req_part_complete = true; @@ -433,8 +445,6 @@ mca_part_persist_psend_init(const void* buf, dt_size = (dt_size_ > (size_t) INT_MAX) ? MPI_UNDEFINED : (int) dt_size_; req->req_bytes = parts * count * dt_size; - - /* non-blocking send set-up data */ req->setup_info[0].world_rank = ompi_comm_rank(&ompi_mpi_comm_world.comm); req->setup_info[0].start_tag = ompi_part_persist.next_send_tag; ompi_part_persist.next_send_tag += parts; @@ -445,7 +455,7 @@ mca_part_persist_psend_init(const void* buf, req->real_parts = parts; req->setup_info[0].count = count; req->real_count = count; - + req->setup_info[0].dt_size = dt_size; req->flags = (int*) calloc(req->real_parts, sizeof(int)); diff --git a/ompi/mca/part/persist/part_persist_request.h b/ompi/mca/part/persist/part_persist_request.h index 3eea69109d2..ba55c6bd920 100644 --- a/ompi/mca/part/persist/part_persist_request.h +++ b/ompi/mca/part/persist/part_persist_request.h @@ -40,6 +40,7 @@ struct ompi_mca_persist_setup_t { int start_tag; int setup_tag; size_t num_parts; + size_t dt_size; size_t count; }; @@ -71,6 +72,7 @@ struct mca_part_persist_request_t { size_t real_parts; /**< internal number of partitions */ size_t real_count; + size_t real_dt_size; /**< receiver needs to know how large the sender's datatype is. */ size_t part_size; ompi_request_t** persist_reqs; /**< requests for persistent sends/recvs */