Skip to content

Commit c71d630

Browse files
authored
Merge pull request #13030 from hppritcha/add_support_for_user_defined_bigops
ops: add support for user-defined big count ops
2 parents 398b8d4 + c484f68 commit c71d630

File tree

3 files changed

+69
-18
lines changed

3 files changed

+69
-18
lines changed

ompi/mpi/c/op_create.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
* Copyright (c) 2008-2009 Cisco Systems, Inc. All rights reserved.
1313
* Copyright (c) 2015 Research Organization for Information Science
1414
* and Technology (RIST). All rights reserved.
15+
* Copyright (c) 2025 Triad National Security, LLC. All rights
16+
* reserved.
1517
* $COPYRIGHT$
1618
*
1719
* Additional copyrights may follow
@@ -57,6 +59,7 @@ int MPI_Op_create(MPI_User_function * function, int commute, MPI_Op * op)
5759
/* Create and cache the op. Sets a refcount of 1. */
5860

5961
*op = ompi_op_create_user(OPAL_INT_TO_BOOL(commute),
62+
false,
6063
(ompi_op_fortran_handler_fn_t *) function);
6164
if (NULL == *op) {
6265
err = MPI_ERR_INTERN;

ompi/op/op.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
* Copyright (c) 2015 Research Organization for Information Science
1818
* and Technology (RIST). All rights reserved.
1919
* Copyright (c) 2018 FUJITSU LIMITED. All rights reserved.
20-
* Copyright (c) 2018 Triad National Security, LLC. All rights
20+
* Copyright (c) 2018-2025 Triad National Security, LLC. All rights
2121
* reserved.
2222
* $COPYRIGHT$
2323
*
@@ -353,6 +353,7 @@ static int ompi_op_finalize (void)
353353
* Create a new MPI_Op
354354
*/
355355
ompi_op_t *ompi_op_create_user(bool commute,
356+
bool bigcount,
356357
ompi_op_fortran_handler_fn_t func)
357358
{
358359
ompi_op_t *new_op;
@@ -382,6 +383,9 @@ ompi_op_t *ompi_op_create_user(bool commute,
382383
if (commute) {
383384
new_op->o_flags |= OMPI_OP_FLAGS_COMMUTE;
384385
}
386+
if(bigcount) {
387+
new_op->o_flags |= OMPI_OP_FLAGS_BIGCOUNT;
388+
}
385389

386390
opal_string_copy(new_op->o_name, "USER OP", sizeof(new_op->o_name));
387391
new_op->o_name[sizeof(new_op->o_name) - 1] = '\0';

ompi/op/op.h

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
* Copyright (c) 2018 Amazon.com, Inc. or its affiliates. All Rights reserved.
1919
* Copyright (c) 2019 Research Organization for Information Science
2020
* and Technology (RIST). All rights reserved.
21-
* Copyright (c) 2018 Triad National Security, LLC. All rights
21+
* Copyright (c) 2018-2025 Triad National Security, LLC. All rights
2222
* reserved.
2323
* Copyright (c) 2021 IBM Corporation. All rights reserved.
2424
* $COPYRIGHT$
@@ -61,12 +61,16 @@ BEGIN_C_DECLS
6161
*/
6262
typedef void (ompi_op_c_handler_fn_t)(const void *, void *, int *,
6363
struct ompi_datatype_t **);
64+
typedef void (ompi_op_c_handler_bc_fn_t)(const void *, void *, size_t *,
65+
struct ompi_datatype_t **);
6466

6567
/**
6668
* Typedef for fortran user-defined MPI_Ops.
6769
*/
6870
typedef void (ompi_op_fortran_handler_fn_t)(const void *, void *,
6971
MPI_Fint *, MPI_Fint *);
72+
typedef void (ompi_op_fortran_handler_bc_fn_t)(const void *, void *,
73+
size_t *, MPI_Fint *);
7074

7175
/**
7276
* Typedef for Java op functions intercept (used for user-defined
@@ -98,8 +102,8 @@ typedef void (ompi_op_java_handler_fn_t)(const void *, void *, int *,
98102
#define OMPI_OP_FLAGS_FLOAT_ASSOC 0x0020
99103
/** Set if the callback function is commutative */
100104
#define OMPI_OP_FLAGS_COMMUTE 0x0040
101-
102-
105+
/** Set if the callback function is using bigcount */
106+
#define OMPI_OP_FLAGS_BIGCOUNT 0x0080
103107

104108

105109
/*
@@ -152,8 +156,12 @@ struct ompi_op_t {
152156
ompi_op_base_op_fns_t intrinsic;
153157
/** C handler function pointer */
154158
ompi_op_c_handler_fn_t *c_fn;
159+
/** C handler function pointer - bigcount*/
160+
ompi_op_c_handler_bc_fn_t *c_fn_bc;
155161
/** Fortran handler function pointer */
156162
ompi_op_fortran_handler_fn_t *fort_fn;
163+
/** Fortran handler function pointer - bigcount*/
164+
ompi_op_fortran_handler_bc_fn_t *fort_fn_bc;
157165
/** Java intercept function data */
158166
struct {
159167
/* The OMPI C++ callback/intercept function */
@@ -333,6 +341,8 @@ int ompi_op_init(void);
333341
*
334342
* @param commute Boolean indicating whether the operation is
335343
* commutative or not
344+
* @param bigcount Boolean indicating whether or not the op is
345+
* using the bigcount (MPI_Count) interface
336346
* @param func Function pointer of the user-defined reduction function
337347
*
338348
* @returns op Pointer to the ompi_op_t that will be
@@ -355,6 +365,7 @@ int ompi_op_init(void);
355365
* manually.
356366
*/
357367
ompi_op_t *ompi_op_create_user(bool commute,
368+
bool bigcount,
358369
ompi_op_fortran_handler_fn_t func);
359370

360371
/**
@@ -512,11 +523,9 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
512523
* in iterations of counts <= INT_MAX since it has an `int *len`
513524
* parameter.
514525
*
515-
* Note: When we add BigCount support then we can distinguish between
516-
* a reduction operation with `int *len` and `MPI_Count *len`. At which
517-
* point we can avoid this loop.
518526
*/
519-
if( OPAL_UNLIKELY(full_count > INT_MAX) ) {
527+
if(OPAL_UNLIKELY((full_count > INT_MAX) &&
528+
(0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)))) {
520529
size_t done_count = 0, shift;
521530
int iter_count;
522531
ptrdiff_t ext, lb;
@@ -578,8 +587,12 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
578587
/* User-defined function */
579588
if (0 != (op->o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC)) {
580589
f_dtype = OMPI_INT_2_FINT(dtype->d_f_to_c_index);
581-
f_count = OMPI_INT_2_FINT(count);
582-
op->o_func.fort_fn(source, target, &f_count, &f_dtype);
590+
if (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)) {
591+
f_count = OMPI_INT_2_FINT(count);
592+
op->o_func.fort_fn(source, target, &f_count, &f_dtype);
593+
} else {
594+
op->o_func.fort_fn_bc(source, target, &full_count, &f_dtype);
595+
}
583596
return;
584597
} else if (0 != (op->o_flags & OMPI_OP_FLAGS_JAVA_FUNC)) {
585598
op->o_func.java_data.intercept_fn(source, target, &count, &dtype,
@@ -588,15 +601,25 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
588601
op->o_func.java_data.object);
589602
return;
590603
}
591-
op->o_func.c_fn(source, target, &count, &dtype);
604+
if (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)) {
605+
op->o_func.c_fn(source, target, &count, &dtype);
606+
} else {
607+
op->o_func.c_fn_bc(source, target, &full_count, &dtype);
608+
}
592609
return;
593610
}
594611

595612
static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, void * restrict source2,
596-
void * restrict result, int count, struct ompi_datatype_t *dtype)
613+
void * restrict result, size_t full_count, struct ompi_datatype_t *dtype)
597614
{
598-
ompi_datatype_copy_content_same_ddt (dtype, count, (char*)result, (char*)source1);
599-
op->o_func.c_fn (source2, result, &count, &dtype);
615+
ompi_datatype_copy_content_same_ddt (dtype, full_count, (char*)result, (char*)source1);
616+
if (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)) {
617+
assert(full_count <= INT_MAX);
618+
int count = (int)full_count; /* protected by loop in only caller of this function */
619+
op->o_func.c_fn (source2, result, &count, &dtype);
620+
} else {
621+
op->o_func.c_fn_bc (source2, result, &full_count, &dtype);
622+
}
600623
}
601624

602625
/**
@@ -618,13 +641,11 @@ static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, v
618641
* with the values in the source buffer and the result is stored in
619642
* the target buffer).
620643
*
621-
* This function will *only* be invoked on intrinsic MPI_Ops.
622-
*
623644
* Otherwise, this function is the same as ompi_op_reduce.
624645
*/
625646
static inline void ompi_3buff_op_reduce(ompi_op_t * op, void *source1,
626647
void *source2, void *target,
627-
int count, ompi_datatype_t * dtype)
648+
size_t full_count, ompi_datatype_t * dtype)
628649
{
629650
void *restrict src1;
630651
void *restrict src2;
@@ -633,13 +654,36 @@ static inline void ompi_3buff_op_reduce(ompi_op_t * op, void *source1,
633654
src2 = source2;
634655
tgt = target;
635656

657+
if(OPAL_UNLIKELY((full_count > INT_MAX) &&
658+
(0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)))) {
659+
size_t done_count = 0, shift, iter_count;
660+
ptrdiff_t ext, lb;
661+
662+
ompi_datatype_get_extent(dtype, &lb, &ext);
663+
664+
while(done_count < full_count) {
665+
if(done_count + INT_MAX > full_count) {
666+
iter_count = full_count - done_count;
667+
} else {
668+
iter_count = INT_MAX;
669+
}
670+
shift = done_count * ext;
671+
// Recurse one level in iterations of 'int'
672+
ompi_3buff_op_reduce(op, (char*)source1 + shift, (char *)source2 + shift,
673+
(char*)target + shift, iter_count, dtype);
674+
done_count += iter_count;
675+
}
676+
return;
677+
}
678+
636679
if (OPAL_LIKELY(ompi_op_is_intrinsic (op))) {
680+
int count = (int)full_count;
637681
op->o_3buff_intrinsic.fns[ompi_op_ddt_map[dtype->id]](src1, src2,
638682
tgt, &count,
639683
&dtype,
640684
op->o_3buff_intrinsic.modules[ompi_op_ddt_map[dtype->id]]);
641685
} else {
642-
ompi_3buff_op_user (op, src1, src2, tgt, count, dtype);
686+
ompi_3buff_op_user (op, src1, src2, tgt, full_count, dtype);
643687
}
644688
}
645689

0 commit comments

Comments
 (0)