Skip to content

Commit 5a88242

Browse files
committed
accelerator: add interface to retrieve memkind
add an API to the accelerator component to retrieve the memory_alloc_kind information that is supported by the component. The values stored/returned are based on the side document that is about to be ratified, see https://github.com/mpi-forum/mem-alloc/blob/main/mem_alloc.tex Signed-off-by: Edgar Gabriel <[email protected]>
1 parent ce5ff37 commit 5a88242

File tree

6 files changed

+116
-11
lines changed

6 files changed

+116
-11
lines changed

ompi/info/info_memkind.c

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,9 @@ static int ompi_info_memkind_get_available(int *num_memkinds, ompi_memkind_t **m
166166
}
167167

168168
int tmp_num = 2;
169-
#if 0
170169
if (0 != strcmp(opal_accelerator_base_selected_component.base_version.mca_component_name, "null")) {
171170
tmp_num++;
172171
}
173-
#endif
174172

175173
ompi_info_memkind_available = (ompi_memkind_t *) malloc (tmp_num * sizeof(ompi_memkind_t));
176174
if (NULL == ompi_info_memkind_available) {
@@ -190,11 +188,12 @@ static int ompi_info_memkind_get_available(int *num_memkinds, ompi_memkind_t **m
190188
ompi_info_memkind_available[1].im_restrictors[1] = strdup ("win_allocate");
191189
ompi_info_memkind_available[1].im_restrictors[2] = strdup ("win_allocate_shared");
192190

193-
#if 0
194191
if (tmp_num > 2) {
195-
opal_accelerator.get_memkind_info (&ompi_info_memkind_available[2]);
192+
ompi_info_memkind_available[2].im_num_restrictors = OMPI_MAX_NUM_MEMKIND_RESTRICTORS;
193+
opal_accelerator.get_memkind (&(ompi_info_memkind_available[2].im_name),
194+
&(ompi_info_memkind_available[2].im_num_restrictors),
195+
(char**)ompi_info_memkind_available[2].im_restrictors);
196196
}
197-
#endif
198197
ompi_info_memkind_num_available = tmp_num;
199198

200199
exit:

opal/mca/accelerator/accelerator.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
* reserved.
55
* Copyright (c) Amazon.com, Inc. or its affiliates.
66
* All Rights reserved.
7-
* Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights reserved.
7+
* Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All Rights reserved.
88
* Copyright (c) 2024 The University of Tennessee and The University
99
* of Tennessee Research Foundation. All rights
1010
* reserved.
@@ -654,6 +654,22 @@ typedef int (*opal_accelerator_base_module_get_num_devices_fn_t)(int *num_device
654654
*/
655655
typedef int (*opal_accelerator_base_module_get_mem_bw_fn_t)(int device, float *bw);
656656

657+
/**
658+
* Get the memkind information of the accelerator component.
659+
*
660+
* @param[OUT] name Name of memory alloc kinds supported by component.
661+
* This field will have to be released by the calling function.
662+
* @param[INOUT] num_restrictors As input, this parameter represents the lenght of the
663+
* restrictors array allocated by the caller.
664+
* At return, this variable will indicate the number of
665+
* restrictors set by the function
666+
* @param[OUT] restrictors Array of restrictors supported by the component.
667+
* The array of char* pointers has been allocated by the caller.
668+
* The elements of the array will have to be released by the caller.
669+
*
670+
*/
671+
typedef void (*opal_accelerator_base_module_get_memkind_fn_t)(char **name, int *num_restrictors,
672+
char **restrictors);
657673

658674
/*
659675
* the standard public API data structure
@@ -700,6 +716,7 @@ typedef struct {
700716

701717
opal_accelerator_base_module_get_num_devices_fn_t num_devices;
702718
opal_accelerator_base_module_get_mem_bw_fn_t get_mem_bw;
719+
opal_accelerator_base_module_get_memkind_fn_t get_memkind;
703720
} opal_accelerator_base_module_t;
704721

705722
/**

opal/mca/accelerator/cuda/accelerator_cuda.c

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "opal/mca/rcache/rcache.h"
2727
#include "opal/util/show_help.h"
2828
#include "opal/util/proc.h"
29+
2930
/* Accelerator API's */
3031
static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *flags);
3132
static int accelerator_cuda_create_stream(int dev_id, opal_accelerator_stream_t **stream);
@@ -80,10 +81,14 @@ static int accelerator_cuda_get_buffer_id(int dev_id, const void *addr, opal_acc
8081
static int accelerator_cuda_sync_stream(opal_accelerator_stream_t *stream);
8182
static int accelerator_cuda_get_num_devices(int *num_devices);
8283
static int accelerator_cuda_get_mem_bw(int device, float *bw);
84+
static void accelerator_cuda_get_memkind(char **name, int *num_restrictors, char **restrictors);
8385

8486
#define GET_STREAM(_stream) \
8587
((_stream) == MCA_ACCELERATOR_STREAM_DEFAULT ? 0 : *((CUstream *) (_stream)->stream))
8688

89+
// This value is based on the memory kind MPI side document
90+
#define MCA_ACCELERATOR_CUDA_NUM_RESTRICTORS 3
91+
8792
opal_accelerator_base_module_t opal_accelerator_cuda_module =
8893
{
8994
accelerator_cuda_check_addr,
@@ -125,7 +130,8 @@ opal_accelerator_base_module_t opal_accelerator_cuda_module =
125130
accelerator_cuda_get_buffer_id,
126131

127132
accelerator_cuda_get_num_devices,
128-
accelerator_cuda_get_mem_bw
133+
accelerator_cuda_get_mem_bw,
134+
accelerator_cuda_get_memkind
129135
};
130136

131137
static inline int opal_accelerator_cuda_delayed_init_check(void)
@@ -1218,3 +1224,24 @@ static int accelerator_cuda_get_mem_bw(int device, float *bw)
12181224
*bw = opal_accelerator_cuda_mem_bw[device];
12191225
return OPAL_SUCCESS;
12201226
}
1227+
1228+
static void accelerator_cuda_get_memkind (char **name, int *num_restrictors, char **restrictors)
1229+
{
1230+
int n_restrictors = *num_restrictors > MCA_ACCELERATOR_CUDA_NUM_RESTRICTORS ?
1231+
MCA_ACCELERATOR_CUDA_NUM_RESTRICTORS : *num_restrictors;
1232+
1233+
*name = strdup("cuda");
1234+
1235+
if (n_restrictors > 0) {
1236+
restrictors[0] = strdup("host");
1237+
}
1238+
if (n_restrictors > 1) {
1239+
restrictors[1] = strdup("device");
1240+
}
1241+
if (n_restrictors > 2) {
1242+
restrictors[2] = strdup("managed");
1243+
}
1244+
*num_restrictors = n_restrictors;
1245+
1246+
return;
1247+
}

opal/mca/accelerator/null/accelerator_null_component.c

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ static int accelerator_null_sync_stream(opal_accelerator_stream_t *stream);
9494
static int accelerator_null_get_num_devices(int *num_devices);
9595

9696
static int accelerator_null_get_mem_bw(int device, float *bw);
97+
static void accelerator_null_get_memkind(char **name, int *num_restrictors, char **restrictors);
9798

9899
/*
99100
* Instantiate the public struct with all of our public information
@@ -174,7 +175,8 @@ opal_accelerator_base_module_t opal_accelerator_null_module =
174175
accelerator_null_get_buffer_id,
175176

176177
accelerator_null_get_num_devices,
177-
accelerator_null_get_mem_bw
178+
accelerator_null_get_mem_bw,
179+
accelerator_null_get_memkind
178180
};
179181

180182
static int accelerator_null_open(void)
@@ -393,3 +395,11 @@ static int accelerator_null_get_mem_bw(int device, float *bw)
393395
*bw = 1.0; // return something that is not 0
394396
return OPAL_SUCCESS;
395397
}
398+
399+
static void accelerator_null_get_memkind (char **name, int *num_restrictors, char **restrictors)
400+
{
401+
*name = NULL;
402+
*num_restrictors = 0;
403+
404+
return;
405+
}

opal/mca/accelerator/rocm/accelerator_rocm_module.c

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,13 @@ static int mca_accelerator_rocm_sync_stream(opal_accelerator_stream_t *stream);
7474
static int mca_accelerator_rocm_get_num_devices(int *num_devices);
7575

7676
static int mca_accelerator_rocm_get_mem_bw(int device, float *bw);
77+
static void mca_accelerator_rocm_get_memkind(char **name, int *num_restrictors, char **restrictors);
7778

7879
#define GET_STREAM(_stream) (_stream == MCA_ACCELERATOR_STREAM_DEFAULT ? 0 : *((hipStream_t *)_stream->stream))
7980

81+
// This value is based on the memory kind MPI side document
82+
#define MCA_ACCELERATOR_ROCM_NUM_RESTRICTORS 3
83+
8084
opal_accelerator_base_module_t opal_accelerator_rocm_module =
8185
{
8286
mca_accelerator_rocm_check_addr,
@@ -118,7 +122,8 @@ opal_accelerator_base_module_t opal_accelerator_rocm_module =
118122
mca_accelerator_rocm_get_buffer_id,
119123

120124
mca_accelerator_rocm_get_num_devices,
121-
mca_accelerator_rocm_get_mem_bw
125+
mca_accelerator_rocm_get_mem_bw,
126+
mca_accelerator_rocm_get_memkind
122127
};
123128

124129

@@ -946,3 +951,24 @@ static int mca_accelerator_rocm_get_mem_bw(int device, float *bw)
946951
*bw = opal_accelerator_rocm_mem_bw[device];
947952
return OPAL_SUCCESS;
948953
}
954+
955+
static void mca_accelerator_rocm_get_memkind (char **name, int *num_restrictors, char **restrictors)
956+
{
957+
int n_restrictors = *num_restrictors > MCA_ACCELERATOR_ROCM_NUM_RESTRICTORS ?
958+
MCA_ACCELERATOR_ROCM_NUM_RESTRICTORS : *num_restrictors;
959+
960+
*name = strdup("rocm");
961+
962+
if (n_restrictors > 0) {
963+
restrictors[0] = strdup("host");
964+
}
965+
if (n_restrictors > 1) {
966+
restrictors[1] = strdup("device");
967+
}
968+
if (n_restrictors > 2) {
969+
restrictors[2] = strdup("managed");
970+
}
971+
*num_restrictors = n_restrictors;
972+
973+
return;
974+
}

opal/mca/accelerator/ze/accelerator_ze_module.c

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ static int mca_accelerator_ze_sync_stream(opal_accelerator_stream_t *stream);
7777
static int mca_accelerator_ze_get_num_devices(int *num_devices);
7878

7979
static int mca_accelerator_ze_get_mem_bw(int device, float *bw);
80+
static void mca_accelerator_ze_get_memkind(char **name, int *num_restrictors, char **restrictors);
81+
82+
// This value is based on the memory kind MPI side document
83+
#define MCA_ACCELERATOR_ZE_NUM_RESTRICTORS 3
8084

8185
opal_accelerator_base_module_t opal_accelerator_ze_module =
8286
{
@@ -118,7 +122,8 @@ opal_accelerator_base_module_t opal_accelerator_ze_module =
118122

119123
.get_buffer_id = mca_accelerator_ze_get_buffer_id,
120124
.num_devices = mca_accelerator_ze_get_num_devices,
121-
.get_mem_bw = mca_accelerator_ze_get_mem_bw
125+
.get_mem_bw = mca_accelerator_ze_get_mem_bw,
126+
.get_memkind = mca_accelerator_ze_get_memkind
122127
};
123128

124129
static int accelerator_ze_dev_handle_to_dev_id(ze_device_handle_t hDevice)
@@ -872,4 +877,25 @@ static int mca_accelerator_ze_get_mem_bw(int device, float *bw)
872877
* TODO
873878
*/
874879
return OPAL_ERR_NOT_IMPLEMENTED;
875-
}
880+
}
881+
882+
static void mca_accelerator_ze_get_memkind (char **name, int *num_restrictors, char **restrictors)
883+
{
884+
int n_restrictors = *num_restrictors > MCA_ACCELERATOR_ZE_NUM_RESTRICTORS ?
885+
MCA_ACCELERATOR_ZE_NUM_RESTRICTORS : *num_restrictors;
886+
887+
*name = strdup("level_zero");
888+
889+
if (n_restrictors > 0) {
890+
restrictors[0] = strdup("host");
891+
}
892+
if (n_restrictors > 1) {
893+
restrictors[1] = strdup("device");
894+
}
895+
if (n_restrictors > 2) {
896+
restrictors[2] = strdup("shared");
897+
}
898+
*num_restrictors = n_restrictors;
899+
900+
return;
901+
}

0 commit comments

Comments
 (0)