18
18
* Copyright (c) 2018 Amazon.com, Inc. or its affiliates. All Rights reserved.
19
19
* Copyright (c) 2019 Research Organization for Information Science
20
20
* and Technology (RIST). All rights reserved.
21
- * Copyright (c) 2018 Triad National Security, LLC. All rights
21
+ * Copyright (c) 2018-2025 Triad National Security, LLC. All rights
22
22
* reserved.
23
23
* Copyright (c) 2021 IBM Corporation. All rights reserved.
24
24
* $COPYRIGHT$
@@ -61,12 +61,16 @@ BEGIN_C_DECLS
61
61
*/
62
62
typedef void (ompi_op_c_handler_fn_t )(const void * , void * , int * ,
63
63
struct ompi_datatype_t * * );
64
+ typedef void (ompi_op_c_handler_bc_fn_t )(const void * , void * , size_t * ,
65
+ struct ompi_datatype_t * * );
64
66
65
67
/**
66
68
* Typedef for fortran user-defined MPI_Ops.
67
69
*/
68
70
typedef void (ompi_op_fortran_handler_fn_t )(const void * , void * ,
69
71
MPI_Fint * , MPI_Fint * );
72
+ typedef void (ompi_op_fortran_handler_bc_fn_t )(const void * , void * ,
73
+ size_t * , MPI_Fint * );
70
74
71
75
/**
72
76
* Typedef for Java op functions intercept (used for user-defined
@@ -98,8 +102,8 @@ typedef void (ompi_op_java_handler_fn_t)(const void *, void *, int *,
98
102
#define OMPI_OP_FLAGS_FLOAT_ASSOC 0x0020
99
103
/** Set if the callback function is commutative */
100
104
#define OMPI_OP_FLAGS_COMMUTE 0x0040
101
-
102
-
105
+ /** Set if the callback function is using bigcount */
106
+ #define OMPI_OP_FLAGS_BIGCOUNT 0x0080
103
107
104
108
105
109
/*
@@ -152,8 +156,12 @@ struct ompi_op_t {
152
156
ompi_op_base_op_fns_t intrinsic ;
153
157
/** C handler function pointer */
154
158
ompi_op_c_handler_fn_t * c_fn ;
159
+ /** C handler function pointer - bigcount*/
160
+ ompi_op_c_handler_bc_fn_t * c_fn_bc ;
155
161
/** Fortran handler function pointer */
156
162
ompi_op_fortran_handler_fn_t * fort_fn ;
163
+ /** Fortran handler function pointer - bigcount*/
164
+ ompi_op_fortran_handler_bc_fn_t * fort_fn_bc ;
157
165
/** Java intercept function data */
158
166
struct {
159
167
/* The OMPI C++ callback/intercept function */
@@ -333,6 +341,8 @@ int ompi_op_init(void);
333
341
*
334
342
* @param commute Boolean indicating whether the operation is
335
343
* commutative or not
344
+ * @param bigcount Boolean indicating whether or not the op is
345
+ * using the bigcount (MPI_Count) interface
336
346
* @param func Function pointer of the user-defined reduction function
337
347
*
338
348
* @returns op Pointer to the ompi_op_t that will be
@@ -355,6 +365,7 @@ int ompi_op_init(void);
355
365
* manually.
356
366
*/
357
367
ompi_op_t * ompi_op_create_user (bool commute ,
368
+ bool bigcount ,
358
369
ompi_op_fortran_handler_fn_t func );
359
370
360
371
/**
@@ -512,11 +523,9 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
512
523
* in iterations of counts <= INT_MAX since it has an `int *len`
513
524
* parameter.
514
525
*
515
- * Note: When we add BigCount support then we can distinguish between
516
- * a reduction operation with `int *len` and `MPI_Count *len`. At which
517
- * point we can avoid this loop.
518
526
*/
519
- if ( OPAL_UNLIKELY (full_count > INT_MAX ) ) {
527
+ if (OPAL_UNLIKELY ((full_count > INT_MAX ) &&
528
+ (0 == (op -> o_flags & OMPI_OP_FLAGS_BIGCOUNT )))) {
520
529
size_t done_count = 0 , shift ;
521
530
int iter_count ;
522
531
ptrdiff_t ext , lb ;
@@ -578,8 +587,12 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
578
587
/* User-defined function */
579
588
if (0 != (op -> o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC )) {
580
589
f_dtype = OMPI_INT_2_FINT (dtype -> d_f_to_c_index );
581
- f_count = OMPI_INT_2_FINT (count );
582
- op -> o_func .fort_fn (source , target , & f_count , & f_dtype );
590
+ if (0 == (op -> o_flags & OMPI_OP_FLAGS_BIGCOUNT )) {
591
+ f_count = OMPI_INT_2_FINT (count );
592
+ op -> o_func .fort_fn (source , target , & f_count , & f_dtype );
593
+ } else {
594
+ op -> o_func .fort_fn_bc (source , target , & full_count , & f_dtype );
595
+ }
583
596
return ;
584
597
} else if (0 != (op -> o_flags & OMPI_OP_FLAGS_JAVA_FUNC )) {
585
598
op -> o_func .java_data .intercept_fn (source , target , & count , & dtype ,
@@ -588,15 +601,25 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
588
601
op -> o_func .java_data .object );
589
602
return ;
590
603
}
591
- op -> o_func .c_fn (source , target , & count , & dtype );
604
+ if (0 == (op -> o_flags & OMPI_OP_FLAGS_BIGCOUNT )) {
605
+ op -> o_func .c_fn (source , target , & count , & dtype );
606
+ } else {
607
+ op -> o_func .c_fn_bc (source , target , & full_count , & dtype );
608
+ }
592
609
return ;
593
610
}
594
611
595
612
/**
 * Apply a user-defined MPI_Op in three-buffer form.
 *
 * The contents of source1 are first copied into result; the user callback
 * is then invoked with source2 as its input buffer and result as its
 * in/out buffer.
 *
 * The callback flavor is selected from op->o_flags, mirroring the dispatch
 * done in ompi_op_reduce():
 *  - OMPI_OP_FLAGS_FORTRAN_FUNC: the callback receives the Fortran datatype
 *    index (dtype->d_f_to_c_index) instead of a C datatype pointer.
 *  - OMPI_OP_FLAGS_BIGCOUNT: the callback receives the count as a
 *    `size_t *`; otherwise it receives an `int *` (C) or `MPI_Fint *`
 *    (Fortran) and full_count must not exceed INT_MAX.
 *
 * @param op          User-defined operation to apply
 * @param source1     First operand; copied into result before the callback
 * @param source2     Second operand; passed as the callback's input buffer
 * @param result      Output buffer; passed as the callback's in/out buffer
 * @param full_count  Number of elements of type dtype
 * @param dtype       Datatype of the elements
 *
 * NOTE(review): ops flagged OMPI_OP_FLAGS_JAVA_FUNC are not special-cased
 * here and still go through the C handler pointers; confirm whether Java
 * ops can reach this path.
 */
static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, void * restrict source2,
                                       void * restrict result, size_t full_count, struct ompi_datatype_t *dtype)
{
    ompi_datatype_copy_content_same_ddt (dtype, full_count, (char*)result, (char*)source1);

    if (0 != (op->o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC)) {
        /* Fortran callbacks take the datatype as a Fortran handle index,
         * exactly as ompi_op_reduce() passes it. */
        MPI_Fint f_dtype = OMPI_INT_2_FINT(dtype->d_f_to_c_index);

        if (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)) {
            assert(full_count <= INT_MAX);
            /* protected by loop in only caller of this function */
            MPI_Fint f_count = OMPI_INT_2_FINT((int)full_count);
            op->o_func.fort_fn (source2, result, &f_count, &f_dtype);
        } else {
            op->o_func.fort_fn_bc (source2, result, &full_count, &f_dtype);
        }
        return;
    }

    if (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)) {
        assert(full_count <= INT_MAX);
        int count = (int)full_count; /* protected by loop in only caller of this function */
        op->o_func.c_fn (source2, result, &count, &dtype);
    } else {
        op->o_func.c_fn_bc (source2, result, &full_count, &dtype);
    }
}
601
624
602
625
/**
@@ -618,13 +641,11 @@ static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, v
618
641
* with the values in the source buffer and the result is stored in
619
642
* the target buffer).
620
643
*
621
- * This function will *only* be invoked on intrinsic MPI_Ops.
622
- *
623
644
* Otherwise, this function is the same as ompi_op_reduce.
624
645
*/
625
646
static inline void ompi_3buff_op_reduce (ompi_op_t * op , void * source1 ,
626
647
void * source2 , void * target ,
627
- int count , ompi_datatype_t * dtype )
648
+ size_t full_count , ompi_datatype_t * dtype )
628
649
{
629
650
void * restrict src1 ;
630
651
void * restrict src2 ;
@@ -633,13 +654,36 @@ static inline void ompi_3buff_op_reduce(ompi_op_t * op, void *source1,
633
654
src2 = source2 ;
634
655
tgt = target ;
635
656
657
+ if (OPAL_UNLIKELY ((full_count > INT_MAX ) &&
658
+ (0 == (op -> o_flags & OMPI_OP_FLAGS_BIGCOUNT )))) {
659
+ size_t done_count = 0 , shift , iter_count ;
660
+ ptrdiff_t ext , lb ;
661
+
662
+ ompi_datatype_get_extent (dtype , & lb , & ext );
663
+
664
+ while (done_count < full_count ) {
665
+ if (done_count + INT_MAX > full_count ) {
666
+ iter_count = full_count - done_count ;
667
+ } else {
668
+ iter_count = INT_MAX ;
669
+ }
670
+ shift = done_count * ext ;
671
+ // Recurse one level in iterations of 'int'
672
+ ompi_3buff_op_reduce (op , (char * )source1 + shift , (char * )source2 + shift ,
673
+ (char * )target + shift , iter_count , dtype );
674
+ done_count += iter_count ;
675
+ }
676
+ return ;
677
+ }
678
+
636
679
if (OPAL_LIKELY (ompi_op_is_intrinsic (op ))) {
680
+ int count = (int )full_count ;
637
681
op -> o_3buff_intrinsic .fns [ompi_op_ddt_map [dtype -> id ]](src1 , src2 ,
638
682
tgt , & count ,
639
683
& dtype ,
640
684
op -> o_3buff_intrinsic .modules [ompi_op_ddt_map [dtype -> id ]]);
641
685
} else {
642
- ompi_3buff_op_user (op , src1 , src2 , tgt , count , dtype );
686
+ ompi_3buff_op_user (op , src1 , src2 , tgt , full_count , dtype );
643
687
}
644
688
}
645
689
0 commit comments