Skip to content

Commit

Permalink
Determine MPI Data Types in col_on_comm() & dst_on_comm() to prevent…
Browse files Browse the repository at this point in the history
… displacements overflow. (Fix for #2156) (#2157)

Determine MPI data types in col_on_comm() & dst_on_comm() to prevent
displacement overflow.

TYPE: bug fix

KEYWORDS: prevent displacements overflow in MPI_Gatherv() and
MPI_Scatterv() operations

SOURCE: Benjamin Kirk & Negin Sobhani (NSF NCAR / CISL)

DESCRIPTION OF CHANGES:
Problem:
The MPI_Gatherv() and MPI_Scatterv() operations require integer
displacements into the communications buffers. Historically all data
has been passed as MPI_CHAR, so the counts and displacements are
expressed in bytes, i.e. *typesize times larger than necessary. For
large domain sizes this can cause the displace[] offsets to exceed the
maximum int value and wrap to negative values.

Solution:
This change introduces additional error checking and then uses the
function MPI_Type_match_size() (available since MPI-2.0) to determine a
suitable MPI_Datatype for the input *typesize. As a result, the
displace[] offsets are expressed in units of the data type's extent
rather than in bytes, and are therefore less likely to overflow.

ISSUE: Fixes #2156 

LIST OF MODIFIED FILES: 
M       frame/collect_on_comm.c

TESTS CONDUCTED: 
Cases that previously failed now run.

RELEASE NOTE: 
Determine MPI data types in col_on_comm() & dst_on_comm() to prevent
displacement overflow.
  • Loading branch information
benkirk authored Feb 5, 2025
1 parent 33ce70c commit af81014
Showing 1 changed file with 54 additions and 19 deletions.
73 changes: 54 additions & 19 deletions frame/collect_on_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@
# endif
#endif


/* Forward declarations of the gather (col_on_comm) and scatter (dst_on_comm)
 * worker routines defined later in this file.  In both, the first argument
 * is named Fcomm and the second typesize; col_on_comm converts *Fcomm with
 * MPI_Comm_f2c().
 * NOTE(review): the remaining parameters presumably mirror COLLECT_ON_COMM's
 * inbuf/ninbuf/outbuf/noutbuf, and the meaning of the trailing int is not
 * visible in this excerpt -- confirm against the definitions. */
int col_on_comm ( int *, int *, void *, int *, void *, int *, int);
int dst_on_comm ( int *, int *, void *, int *, void *, int *, int);

void
void
COLLECT_ON_COMM ( int * comm, int * typesize ,
void * inbuf, int *ninbuf , void * outbuf, int * noutbuf )
{
Expand All @@ -67,8 +67,9 @@ col_on_comm ( int * Fcomm, int * typesize ,
int *displace ;
int noutbuf_loc ;
int root_task ;
MPI_Datatype dtype;
int ierr = -1;
MPI_Comm *comm, dummy_comm ;
int ierr ;

comm = &dummy_comm ;
*comm = MPI_Comm_f2c( *Fcomm ) ;
Expand All @@ -90,28 +91,45 @@ col_on_comm ( int * Fcomm, int * typesize ,
for ( p = 1 , displace[0] = 0 , noutbuf_loc = recvcounts[0] ; p < ntasks ; p++ ) {
displace[p] = displace[p-1]+recvcounts[p-1] ;
noutbuf_loc = noutbuf_loc + recvcounts[p] ;

/* check for overflow: displace is the partial sum of recvcounts, which can overflow for large problems. */
if (displace[p] < 0) {
#ifndef MS_SUA
fprintf(stderr,"%s %d buffer offset overflow!!\n",__FILE__,__LINE__) ;
fprintf(stderr," ---> p = %d,\n ---> displace[%d] = %d,\n ---> typesize = %d\n",
p, p, displace[p], *typesize);
#endif
MPI_Abort(MPI_COMM_WORLD,1) ;
}
}

if ( noutbuf_loc > * noutbuf )
{
#ifndef MS_SUA
fprintf(stderr,"FATAL ERROR: collect_on_comm: noutbuf_loc (%d) > noutbuf (%d)\n",
noutbuf_loc , * noutbuf ) ;
noutbuf_loc , * noutbuf ) ;
fprintf(stderr,"WILL NOT perform the collection operation\n") ;
#endif
MPI_Abort(MPI_COMM_WORLD,1) ;
}

/* multiply everything by the size of the type */
for ( p = 0 ; p < ntasks ; p++ ) {
displace[p] *= *typesize ;
recvcounts[p] *= *typesize ;
}

/* handle different sized data types appropriately. */
ierr = MPI_Type_match_size (MPI_TYPECLASS_REAL, *typesize, &dtype);
if (MPI_SUCCESS != ierr) {
ierr = MPI_Type_match_size (MPI_TYPECLASS_INTEGER, *typesize, &dtype);
if (MPI_SUCCESS != ierr) {
#ifndef MS_SUA
fprintf(stderr,"%s %d FATAL ERROR: unhandled typesize = %d!!\n", __FILE__,__LINE__,*typesize) ;
#endif
MPI_Abort(MPI_COMM_WORLD,1) ;
}
}

ierr = MPI_Gatherv( inbuf , *ninbuf * *typesize , MPI_CHAR ,
outbuf , recvcounts , displace, MPI_CHAR ,
root_task , *comm ) ;
ierr = MPI_Gatherv( inbuf , *ninbuf, dtype,
outbuf , recvcounts , displace, dtype,
root_task , *comm ) ;
#ifndef MS_SUA
if ( ierr != 0 ) fprintf(stderr,"%s %d MPI_Gatherv returns %d\n",__FILE__,__LINE__,ierr ) ;
#endif
Expand Down Expand Up @@ -152,6 +170,8 @@ dst_on_comm ( int * Fcomm, int * typesize ,
int *displace ;
int noutbuf_loc ;
int root_task ;
MPI_Datatype dtype;
int ierr = -1;
MPI_Comm *comm, dummy_comm ;

comm = &dummy_comm ;
Expand All @@ -171,18 +191,34 @@ dst_on_comm ( int * Fcomm, int * typesize ,
for ( p = 1 , displace[0] = 0 , noutbuf_loc = sendcounts[0] ; p < ntasks ; p++ ) {
displace[p] = displace[p-1]+sendcounts[p-1] ;
noutbuf_loc = noutbuf_loc + sendcounts[p] ;

/* check for overflow: displace is the partial sum of sendcounts, which can overflow for large problems. */
if ( (displace[p] < 0) || (noutbuf_loc < 0) ) {
#ifndef MS_SUA
fprintf(stderr,"%s %d buffer offset overflow!!\n",__FILE__,__LINE__) ;
fprintf(stderr," ---> p = %d,\n ---> displace[%d] = %d,\n ---> noutbuf_loc = %d,\n ---> typesize = %d\n",
p, p, displace[p], noutbuf_loc, *typesize);
#endif
MPI_Abort(MPI_COMM_WORLD,1) ;
}
}
}

/* multiply everything by the size of the type */
for ( p = 0 ; p < ntasks ; p++ ) {
displace[p] *= *typesize ;
sendcounts[p] *= *typesize ;
/* handle different sized data types appropriately. */
ierr = MPI_Type_match_size (MPI_TYPECLASS_REAL, *typesize, &dtype);
if (MPI_SUCCESS != ierr) {
ierr = MPI_Type_match_size (MPI_TYPECLASS_INTEGER, *typesize, &dtype);
if (MPI_SUCCESS != ierr) {
#ifndef MS_SUA
fprintf(stderr,"%s %d FATAL ERROR: unhandled typesize = %d!!\n", __FILE__,__LINE__,*typesize) ;
#endif
MPI_Abort(MPI_COMM_WORLD,1) ;
}
}

MPI_Scatterv( inbuf , sendcounts , displace, MPI_CHAR ,
outbuf , *noutbuf * *typesize , MPI_CHAR ,
root_task , *comm ) ;
MPI_Scatterv( inbuf, sendcounts, displace, dtype,
outbuf, *noutbuf, dtype,
root_task, *comm ) ;

free(sendcounts) ;
free(displace) ;
Expand Down Expand Up @@ -241,4 +277,3 @@ rlim_ ()
}
#endif
#endif

0 comments on commit af81014

Please sign in to comment.