Skip to content

Commit 951beff

Browse files
committed
accelerator: add interface to retrieve memkind
add an API to the accelerator component to retrieve the memory_alloc_kind information that is supported by the component. The values stored/returned are based on the side document that is about to be ratified, see https://github.com/mpi-forum/mem-alloc/blob/main/mem_alloc.tex Signed-off-by: Edgar Gabriel <[email protected]>
1 parent 53459cf commit 951beff

File tree

6 files changed

+72
-11
lines changed

6 files changed

+72
-11
lines changed

ompi/info/info_memkind.c

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,9 @@ static int ompi_info_memkind_get_available(int *num_memkinds, ompi_memkind_t **m
138138
}
139139

140140
int tmp_num = 2;
141-
#if 0
142141
if (0 != strcmp(opal_accelerator_base_selected_component.base_version.mca_component_name, "null")) {
143142
tmp_num++;
144143
}
145-
#endif
146144

147145
ompi_info_memkind_available = (ompi_memkind_t *) malloc (tmp_num * sizeof(ompi_memkind_t));
148146
if (NULL == ompi_info_memkind_available) {
@@ -164,11 +162,9 @@ static int ompi_info_memkind_get_available(int *num_memkinds, ompi_memkind_t **m
164162
ompi_info_memkind_available[1].im_restrictors[1] = strdup ("win_allocate");
165163
ompi_info_memkind_available[1].im_restrictors[2] = strdup ("win_allocate_shared");
166164

167-
#if 0
168165
if (tmp_num > 2) {
169-
opal_accelerator.get_memkind_info (&ompi_info_memkind_available[2]);
166+
opal_accelerator.get_memkind (&ompi_info_memkind_available[2]);
170167
}
171-
#endif
172168
ompi_info_memkind_num_available = tmp_num;
173169

174170
exit:

opal/mca/accelerator/accelerator.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
* reserved.
55
* Copyright (c) Amazon.com, Inc. or its affiliates.
66
* All Rights reserved.
7-
* Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights reserved.
7+
* Copyright (c) 2023-2025 Advanced Micro Devices, Inc. All Rights reserved.
88
* Copyright (c) 2024 The University of Tennessee and The University
99
* of Tennessee Research Foundation. All rights
1010
* reserved.
@@ -81,6 +81,7 @@
8181

8282
#include "opal/class/opal_object.h"
8383
#include "opal/mca/mca.h"
84+
#include "ompi/info/info_memkind.h"
8485

8586
BEGIN_C_DECLS
8687

@@ -654,6 +655,12 @@ typedef int (*opal_accelerator_base_module_get_num_devices_fn_t)(int *num_device
654655
*/
655656
typedef int (*opal_accelerator_base_module_get_mem_bw_fn_t)(int device, float *bw);
656657

658+
/**
659+
* Get the memkind information of the accelerator component.
660+
* @param[OUT] memkind Memory alloc kinds supported by component
661+
*
662+
*/
663+
typedef void (*opal_accelerator_base_module_get_memkind_fn_t)(ompi_memkind_t *memkind);
657664

658665
/*
659666
* the standard public API data structure
@@ -700,6 +707,7 @@ typedef struct {
700707

701708
opal_accelerator_base_module_get_num_devices_fn_t num_devices;
702709
opal_accelerator_base_module_get_mem_bw_fn_t get_mem_bw;
710+
opal_accelerator_base_module_get_memkind_fn_t get_memkind;
703711
} opal_accelerator_base_module_t;
704712

705713
/**

opal/mca/accelerator/cuda/accelerator_cuda.c

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "opal/mca/rcache/rcache.h"
2727
#include "opal/util/show_help.h"
2828
#include "opal/util/proc.h"
29+
#include "ompi/info/info_memkind.h"
2930
/* Accelerator API's */
3031
static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *flags);
3132
static int accelerator_cuda_create_stream(int dev_id, opal_accelerator_stream_t **stream);
@@ -80,6 +81,7 @@ static int accelerator_cuda_get_buffer_id(int dev_id, const void *addr, opal_acc
8081
static int accelerator_cuda_sync_stream(opal_accelerator_stream_t *stream);
8182
static int accelerator_cuda_get_num_devices(int *num_devices);
8283
static int accelerator_cuda_get_mem_bw(int device, float *bw);
84+
static void accelerator_cuda_get_memkind(ompi_memkind_t *memkind);
8385

8486
#define GET_STREAM(_stream) \
8587
((_stream) == MCA_ACCELERATOR_STREAM_DEFAULT ? 0 : *((CUstream *) (_stream)->stream))
@@ -125,7 +127,8 @@ opal_accelerator_base_module_t opal_accelerator_cuda_module =
125127
accelerator_cuda_get_buffer_id,
126128

127129
accelerator_cuda_get_num_devices,
128-
accelerator_cuda_get_mem_bw
130+
accelerator_cuda_get_mem_bw,
131+
accelerator_cuda_get_memkind
129132
};
130133

131134
static inline int opal_accelerator_cuda_delayed_init_check(void)
@@ -1218,3 +1221,15 @@ static int accelerator_cuda_get_mem_bw(int device, float *bw)
12181221
*bw = opal_accelerator_cuda_mem_bw[device];
12191222
return OPAL_SUCCESS;
12201223
}
1224+
1225+
/**
 * Report the memory allocation kind provided by the CUDA component.
 *
 * Sets the name to "cuda" and fills in three restrictors: "host",
 * "device" and "managed". Every string is duplicated with strdup();
 * the results are not checked for allocation failure here.
 * NOTE(review): presumably the caller owns and frees these strings - confirm.
 *
 * The function name has to match the forward declaration near the top of
 * this file and the entry in opal_accelerator_cuda_module, both of which
 * use accelerator_cuda_get_memkind (no "mca_" prefix). With a different
 * name the referenced static function would be declared but never defined.
 *
 * @param[OUT] memkind  Structure that receives the memkind description
 */
static void accelerator_cuda_get_memkind (ompi_memkind_t *memkind)
1226+
{
1227+
memkind->im_name = strdup("cuda");
1228+
memkind->im_no_restrictors = false;
1229+
memkind->im_num_restrictors = 3;
1230+
memkind->im_restrictors[0] = strdup("host");
1231+
memkind->im_restrictors[1] = strdup("device");
1232+
memkind->im_restrictors[2] = strdup("managed");
1233+
1234+
return;
1235+
}

opal/mca/accelerator/null/accelerator_null_component.c

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include "accelerator_null_component.h"
2525
#include "opal/constants.h"
26+
#include "ompi/info/info_memkind.h"
2627
#include <string.h>
2728

2829
/*
@@ -94,6 +95,7 @@ static int accelerator_null_sync_stream(opal_accelerator_stream_t *stream);
9495
static int accelerator_null_get_num_devices(int *num_devices);
9596

9697
static int accelerator_null_get_mem_bw(int device, float *bw);
98+
static void accelerator_null_get_memkind(ompi_memkind_t *memkind);
9799

98100
/*
99101
* Instantiate the public struct with all of our public information
@@ -174,7 +176,8 @@ opal_accelerator_base_module_t opal_accelerator_null_module =
174176
accelerator_null_get_buffer_id,
175177

176178
accelerator_null_get_num_devices,
177-
accelerator_null_get_mem_bw
179+
accelerator_null_get_mem_bw,
180+
accelerator_null_get_memkind
178181
};
179182

180183
static int accelerator_null_open(void)
@@ -393,3 +396,12 @@ static int accelerator_null_get_mem_bw(int device, float *bw)
393396
*bw = 1.0; // return something that is not 0
394397
return OPAL_SUCCESS;
395398
}
399+
400+
/**
 * Memkind query for the null accelerator component.
 *
 * Reports "nothing supported": the name is set to NULL and the restrictor
 * count to 0. The im_restrictors entries themselves are left untouched.
 *
 * @param[OUT] memkind  Structure that receives the (empty) description
 */
static void accelerator_null_get_memkind (ompi_memkind_t *memkind)
401+
{
402+
memkind->im_name = NULL;
403+
memkind->im_no_restrictors = false;
404+
memkind->im_num_restrictors = 0;
405+
406+
return;
407+
}

opal/mca/accelerator/rocm/accelerator_rocm_module.c

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "opal/mca/accelerator/base/base.h"
1717
#include "opal/constants.h"
1818
#include "opal/util/output.h"
19+
#include "ompi/info/info_memkind.h"
1920

2021
/* Accelerator API's */
2122
static int mca_accelerator_rocm_check_addr(const void *addr, int *dev_id, uint64_t *flags);
@@ -74,6 +75,7 @@ static int mca_accelerator_rocm_sync_stream(opal_accelerator_stream_t *stream);
7475
static int mca_accelerator_rocm_get_num_devices(int *num_devices);
7576

7677
static int mca_accelerator_rocm_get_mem_bw(int device, float *bw);
78+
static void mca_accelerator_rocm_get_memkind(ompi_memkind_t *memkind);
7779

7880
#define GET_STREAM(_stream) (_stream == MCA_ACCELERATOR_STREAM_DEFAULT ? 0 : *((hipStream_t *)_stream->stream))
7981

@@ -118,7 +120,8 @@ opal_accelerator_base_module_t opal_accelerator_rocm_module =
118120
mca_accelerator_rocm_get_buffer_id,
119121

120122
mca_accelerator_rocm_get_num_devices,
121-
mca_accelerator_rocm_get_mem_bw
123+
mca_accelerator_rocm_get_mem_bw,
124+
mca_accelerator_rocm_get_memkind
122125
};
123126

124127

@@ -946,3 +949,15 @@ static int mca_accelerator_rocm_get_mem_bw(int device, float *bw)
946949
*bw = opal_accelerator_rocm_mem_bw[device];
947950
return OPAL_SUCCESS;
948951
}
952+
953+
/**
 * Report the memory allocation kind provided by the ROCm component.
 *
 * Sets the name to "rocm" and fills in three restrictors: "host",
 * "device" and "managed". Every string is duplicated with strdup();
 * the results are not checked for allocation failure here.
 * NOTE(review): presumably the caller owns and frees these strings - confirm.
 *
 * @param[OUT] memkind  Structure that receives the memkind description
 */
static void mca_accelerator_rocm_get_memkind (ompi_memkind_t *memkind)
954+
{
955+
memkind->im_name = strdup("rocm");
956+
memkind->im_no_restrictors = false;
957+
memkind->im_num_restrictors = 3;
958+
memkind->im_restrictors[0] = strdup("host");
959+
memkind->im_restrictors[1] = strdup("device");
960+
memkind->im_restrictors[2] = strdup("managed");
961+
962+
return;
963+
}

opal/mca/accelerator/ze/accelerator_ze_module.c

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "opal/util/printf.h"
1919
#include "opal/constants.h"
2020
#include "opal/util/output.h"
21+
#include "ompi/info/info_memkind.h"
2122

2223
/* Accelerator API's */
2324
static int mca_accelerator_ze_check_addr(const void *addr, int *dev_id, uint64_t *flags);
@@ -77,6 +78,7 @@ static int mca_accelerator_ze_sync_stream(opal_accelerator_stream_t *stream);
7778
static int mca_accelerator_ze_get_num_devices(int *num_devices);
7879

7980
static int mca_accelerator_ze_get_mem_bw(int device, float *bw);
81+
static void mca_accelerator_ze_get_memkind(ompi_memkind_t *memkind);
8082

8183
opal_accelerator_base_module_t opal_accelerator_ze_module =
8284
{
@@ -118,7 +120,8 @@ opal_accelerator_base_module_t opal_accelerator_ze_module =
118120

119121
.get_buffer_id = mca_accelerator_ze_get_buffer_id,
120122
.num_devices = mca_accelerator_ze_get_num_devices,
121-
.get_mem_bw = mca_accelerator_ze_get_mem_bw
123+
.get_mem_bw = mca_accelerator_ze_get_mem_bw,
124+
.get_memkind = mca_accelerator_ze_get_memkind
122125
};
123126

124127
static int accelerator_ze_dev_handle_to_dev_id(ze_device_handle_t hDevice)
@@ -872,4 +875,16 @@ static int mca_accelerator_ze_get_mem_bw(int device, float *bw)
872875
* TODO
873876
*/
874877
return OPAL_ERR_NOT_IMPLEMENTED;
875-
}
878+
}
879+
880+
/**
 * Report the memory allocation kind provided by the Level Zero component.
 *
 * Sets the name to "level_zero" and fills in three restrictors: "host",
 * "device" and "shared" (note: "shared", not "managed" as in the CUDA and
 * ROCm components). Every string is duplicated with strdup(); the results
 * are not checked for allocation failure here.
 * NOTE(review): presumably the caller owns and frees these strings - confirm.
 *
 * @param[OUT] memkind  Structure that receives the memkind description
 */
static void mca_accelerator_ze_get_memkind (ompi_memkind_t *memkind)
881+
{
882+
memkind->im_name = strdup("level_zero");
883+
memkind->im_no_restrictors = false;
884+
memkind->im_num_restrictors = 3;
885+
memkind->im_restrictors[0] = strdup("host");
886+
memkind->im_restrictors[1] = strdup("device");
887+
memkind->im_restrictors[2] = strdup("shared");
888+
889+
return;
890+
}

0 commit comments

Comments
 (0)