Skip to content

Commit a342039

Browse files
authored
Avoid communicating su2double/passivedouble when passivedouble/su2mixedfloat are enough (#2788)
* enum class * select the correct buffer * fix request types * fix nompi build * fix * simplify forward type and mixed precision logic * fix and updates * try to fix race condition * updates * more tolerance
1 parent a9a8f68 commit a342039

36 files changed

Lines changed: 841 additions & 701 deletions

Common/include/code_config.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,20 @@ using su2double = double;
139139
using passivedouble = double;
140140

141141
/*--- Define a type for potentially lower precision operations. ---*/
142+
#ifndef CODI_FORWARD_TYPE
142143
#ifdef USE_MIXED_PRECISION
143144
using su2mixedfloat = float;
144145
#else
145146
using su2mixedfloat = passivedouble;
146147
#endif  // USE_MIXED_PRECISION
148+
#else  // CODI_FORWARD_TYPE
149+
/*--- There is no lower precision for forward AD so undefine the macro to simplify
150+
* the logic needed to deal with the multiple type configurations. ---*/
151+
#ifdef USE_MIXED_PRECISION
152+
#undef USE_MIXED_PRECISION
153+
#endif  // USE_MIXED_PRECISION
154+
using su2mixedfloat = su2double;
155+
#endif  // CODI_FORWARD_TYPE
147156

148157
/*--- Detect if OpDiLib has to be used. ---*/
149158
#if defined(HAVE_OMP) && defined(CODI_REVERSE_TYPE)

Common/include/geometry/CGeometry.hpp

Lines changed: 106 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,23 @@ class CGeometry {
284284
in point-to-point comms. */
285285
su2double* bufD_P2PRecv{nullptr}; /*!< \brief Data structure for su2double point-to-point receive. */
286286
su2double* bufD_P2PSend{nullptr}; /*!< \brief Data structure for su2double point-to-point send. */
287+
#ifdef CODI_REVERSE_TYPE
288+
passivedouble* bufPD_P2PRecv{nullptr}; /*!< \brief Data structure for passivedouble point-to-point receive. */
289+
passivedouble* bufPD_P2PSend{nullptr}; /*!< \brief Data structure for passivedouble point-to-point send. */
290+
#endif
291+
#ifdef USE_MIXED_PRECISION
292+
su2mixedfloat* bufF_P2PRecv{nullptr}; /*!< \brief Data structure for su2mixedfloat point-to-point receive. */
293+
su2mixedfloat* bufF_P2PSend{nullptr}; /*!< \brief Data structure for su2mixedfloat point-to-point send. */
294+
#endif
287295
unsigned short* bufS_P2PRecv{nullptr}; /*!< \brief Data structure for unsigned short point-to-point receive. */
288296
unsigned short* bufS_P2PSend{nullptr}; /*!< \brief Data structure for unsigned short point-to-point send. */
289297
SU2_MPI::Request* req_P2PSend{nullptr}; /*!< \brief Data structure for point-to-point send requests. */
290298
SU2_MPI::Request* req_P2PRecv{nullptr}; /*!< \brief Data structure for point-to-point recv requests. */
291299

300+
using PassiveRequest = typename SelectMPIWrapper<passivedouble>::W::Request;
301+
PassiveRequest* reqP_P2PSend{nullptr}; /*!< \brief Data structure for point-to-point send requests of type PassiveRequest. */
302+
PassiveRequest* reqP_P2PRecv{nullptr}; /*!< \brief Data structure for point-to-point recv requests of type PassiveRequest. */
303+
292304
/*--- Data structures for periodic communications. ---*/
293305

294306
int maxCountPerPeriodicPoint{0}; /*!< \brief Maximum number of pieces of data sent per vertex in periodic comms. */
@@ -370,7 +382,7 @@ class CGeometry {
370382
* \param[in] countPerPoint - Number of variables per point.
371383
* \param[in] val_reverse - Boolean controlling forward or reverse communication between neighbors.
372384
*/
373-
void PostP2PRecvs(CGeometry* geometry, const CConfig* config, unsigned short commType, unsigned short countPerPoint,
385+
void PostP2PRecvs(CGeometry* geometry, const CConfig* config, COMM_TYPE commType, unsigned short countPerPoint,
374386
bool val_reverse) const;
375387

376388
/*!
@@ -383,9 +395,98 @@ class CGeometry {
383395
* \param[in] val_iMessage - Index of the message in the order they are stored.
384396
* \param[in] val_reverse - Boolean controlling forward or reverse communication between neighbors.
385397
*/
386-
void PostP2PSends(CGeometry* geometry, const CConfig* config, unsigned short commType, unsigned short countPerPoint,
398+
void PostP2PSends(CGeometry* geometry, const CConfig* config, COMM_TYPE commType, unsigned short countPerPoint,
387399
int val_iMessage, bool val_reverse) const;
388400

401+
/*!
402+
* \brief Returns the COMM_TYPE enumerator for the data type T: su2double -> DOUBLE, passivedouble -> PASSIVE_DOUBLE, su2mixedfloat -> FLOAT, unsigned short -> UNSIGNED_SHORT. Branches are tested in this order, so when several of these types are the same the first match wins.
403+
*/
404+
template <class T>
405+
COMM_TYPE GetCommType() const {
406+
if constexpr (std::is_same_v<T, su2double>) {
407+
return COMM_TYPE::DOUBLE;
408+
} else if constexpr (std::is_same_v<T, passivedouble>) {
409+
return COMM_TYPE::PASSIVE_DOUBLE;
410+
} else if constexpr (std::is_same_v<T, su2mixedfloat>) {
411+
return COMM_TYPE::FLOAT;
412+
} else {
413+
static_assert(std::is_same_v<T, unsigned short>);  // Any other T is rejected at compile time.
414+
return COMM_TYPE::UNSIGNED_SHORT;
415+
}
416+
}
417+
418+
/*!
419+
* \brief Returns the point-to-point send buffer for the data type T. The passivedouble branch is only compiled with CODI_REVERSE_TYPE and the su2mixedfloat branch only with USE_MIXED_PRECISION; a T that matches no compiled branch must be unsigned short.
420+
*/
421+
template <class T>
422+
auto* GetP2PSendBuf() const {
423+
if constexpr (std::is_same_v<T, su2double>) {
424+
return bufD_P2PSend;
425+
#ifdef CODI_REVERSE_TYPE
426+
} else if constexpr (std::is_same_v<T, passivedouble>) {
427+
return bufPD_P2PSend;
428+
#endif
429+
#ifdef USE_MIXED_PRECISION
430+
} else if constexpr (std::is_same_v<T, su2mixedfloat>) {
431+
return bufF_P2PSend;
432+
#endif
433+
} else {
434+
static_assert(std::is_same_v<T, unsigned short>);  // Any other T is rejected at compile time.
435+
return bufS_P2PSend;
436+
}
437+
}
438+
439+
/*!
440+
* \brief Returns the point-to-point receive buffer for the data type T. The passivedouble branch is only compiled with CODI_REVERSE_TYPE and the su2mixedfloat branch only with USE_MIXED_PRECISION; a T that matches no compiled branch must be unsigned short.
441+
*/
442+
template <class T>
443+
auto* GetP2PRecvBuf() const {
444+
if constexpr (std::is_same_v<T, su2double>) {
445+
return bufD_P2PRecv;
446+
#ifdef CODI_REVERSE_TYPE
447+
} else if constexpr (std::is_same_v<T, passivedouble>) {
448+
return bufPD_P2PRecv;
449+
#endif
450+
#ifdef USE_MIXED_PRECISION
451+
} else if constexpr (std::is_same_v<T, su2mixedfloat>) {
452+
return bufF_P2PRecv;
453+
#endif
454+
} else {
455+
static_assert(std::is_same_v<T, unsigned short>);  // Any other T is rejected at compile time.
456+
return bufS_P2PRecv;
457+
}
458+
}
459+
460+
/*!
461+
* \brief Returns the send requests for the data type T: req_P2PSend for su2double and unsigned short, reqP_P2PSend for passivedouble and su2mixedfloat when these are not the same type as su2double.
462+
*/
463+
template <class T>
464+
auto* GetP2PSendReq() const {
465+
if constexpr (std::is_same_v<T, su2double>) {
466+
return req_P2PSend;
467+
} else if constexpr (std::is_same_v<T, passivedouble> || std::is_same_v<T, su2mixedfloat>) {
468+
return reqP_P2PSend;
469+
} else {
470+
static_assert(std::is_same_v<T, unsigned short>);  // Any other T is rejected at compile time.
471+
return req_P2PSend;
472+
}
473+
}
474+
475+
/*!
476+
* \brief Returns the receive requests for the data type T: req_P2PRecv for su2double and unsigned short, reqP_P2PRecv for passivedouble and su2mixedfloat when these are not the same type as su2double.
477+
*/
478+
template <class T>
479+
auto* GetP2PRecvReq() const {
480+
if constexpr (std::is_same_v<T, su2double>) {
481+
return req_P2PRecv;
482+
} else if constexpr (std::is_same_v<T, passivedouble> || std::is_same_v<T, su2mixedfloat>) {
483+
return reqP_P2PRecv;
484+
} else {
485+
static_assert(std::is_same_v<T, unsigned short>);  // Any other T is rejected at compile time.
486+
return req_P2PRecv;
487+
}
488+
}
489+
389490
/*!
390491
* \brief Routine to set up persistent data structures for periodic communications.
391492
* \param[in] geometry - Geometrical definition of the problem.
@@ -408,7 +509,7 @@ class CGeometry {
408509
* \param[in] commType - Enumerated type for the quantity to be communicated.
409510
* \param[in] countPerPeriodicPoint - Number of variables per point.
410511
*/
411-
void PostPeriodicRecvs(CGeometry* geometry, const CConfig* config, unsigned short commType,
512+
void PostPeriodicRecvs(CGeometry* geometry, const CConfig* config, COMM_TYPE commType,
412513
unsigned short countPerPeriodicPoint);
413514

414515
/*!
@@ -420,7 +521,7 @@ class CGeometry {
420521
* \param[in] countPerPeriodicPoint - Number of variables per point.
421522
* \param[in] val_iMessage - Index of the message in the order they are stored.
422523
*/
423-
void PostPeriodicSends(CGeometry* geometry, const CConfig* config, unsigned short commType,
524+
void PostPeriodicSends(CGeometry* geometry, const CConfig* config, COMM_TYPE commType,
424525
unsigned short countPerPeriodicPoint, int val_iMessage) const;
425526

426527
/*!
@@ -431,7 +532,7 @@ class CGeometry {
431532
* \param[out] MPI_TYPE - Enumerated type for the datatype of the quantity to be communicated.
432533
*/
433534
void GetCommCountAndType(const CConfig* config, MPI_QUANTITIES commType, unsigned short& COUNT_PER_POINT,
434-
unsigned short& MPI_TYPE) const;
535+
COMM_TYPE& MPI_TYPE) const;
435536

436537
/*!
437538
* \brief Routine to load a geometric quantity into the data structures for MPI point-to-point communication and to

Common/include/geometry/CPhysicalGeometry.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ class CPhysicalGeometry final : public CGeometry {
220220
*/
221221
void InitiateCommsAll(void* bufSend, const int* nElemSend, SU2_MPI::Request* sendReq, void* bufRecv,
222222
const int* nElemRecv, SU2_MPI::Request* recvReq, unsigned short countPerElem,
223-
unsigned short commType);
223+
COMM_TYPE commType);
224224

225225
/*!
226226
* \brief Routine to complete the set of non-blocking communications launched with InitiateComms() with MPI_Waitany().

Common/include/geometry/meshreader/CCGNSMeshReaderFVM.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class CCGNSMeshReaderFVM final : public CCGNSMeshReaderBase {
7676
*/
7777
void InitiateCommsAll(void* bufSend, const int* nElemSend, SU2_MPI::Request* sendReq, void* bufRecv,
7878
const int* nElemRecv, SU2_MPI::Request* recvReq, unsigned short countPerElem,
79-
unsigned short commType);
79+
COMM_TYPE commType);
8080

8181
/*!
8282
* \brief Routine to complete the set of non-blocking communications launched with InitiateComms() with MPI_Waitany().

Common/include/grid_movement/CLinearElasticity.hpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,8 @@ class CLinearElasticity final : public CVolumetricMovement {
4545

4646
unsigned long nIterMesh; /*!< \brief Number of iterations in the mesh update. */
4747

48-
#ifndef CODI_FORWARD_TYPE
4948
CSysMatrix<su2mixedfloat> StiffMatrix; /*!< \brief Stiffness matrix of the elasticity problem. */
5049
CSysSolve<su2mixedfloat> System; /*!< \brief Linear solver/smoother. */
51-
#else
52-
CSysMatrix<su2double> StiffMatrix;
53-
CSysSolve<su2double> System;
54-
#endif
5550
CSysVector<su2double> LinSysSol;
5651
CSysVector<su2double> LinSysRes;
5752

Common/include/linear_algebra/CSysSolve.hpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,17 @@ class CSysSolve {
425425
* \param[in] config - Definition of the particular problem.
426426
* \param[in] directCall - If this method is called directly, or in AD context.
427427
*/
428-
unsigned long Solve_b(MatrixType& Jacobian, const CSysVector<su2double>& LinSysRes, CSysVector<su2double>& LinSysSol,
429-
CGeometry* geometry, const CConfig* config, bool directCall = true);
428+
unsigned long Solve_b(MatrixType& Jacobian, const VectorType& LinSysRes, VectorType& LinSysSol, CGeometry* geometry,
429+
const CConfig* config, bool directCall = true);
430+
431+
template <class OtherType, su2enable_if<!std::is_same_v<ScalarType, OtherType>> = 0>  // Overload enabled only when OtherType differs from ScalarType.
432+
unsigned long Solve_b(MatrixType& Jacobian, const CSysVector<OtherType>& LinSysRes, CSysVector<OtherType>& LinSysSol,
433+
CGeometry* geometry, const CConfig* config, bool directCall = true) {
434+
HandleTemporariesIn(LinSysRes, LinSysSol);  // NOTE(review): presumably prepares the vectors behind LinSysRes_ptr / LinSysSol_ptr - confirm.
435+
auto iter = Solve_b(Jacobian, *LinSysRes_ptr, *LinSysSol_ptr, geometry, config, directCall);  // Forwards all other arguments unchanged.
436+
HandleTemporariesOut(LinSysSol);  // NOTE(review): presumably writes the solution back into LinSysSol - confirm.
437+
return iter;  // Value returned by the inner Solve_b call.
438+
}
430439

431440
/*!
432441
* \brief Get the number of iterations.

Common/include/option_structure.hpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,17 @@ const int MASTER_NODE = 0; /*!< \brief Master node for MPI parallelization.
113113
const int SINGLE_NODE = 1; /*!< \brief There is only a node in the MPI parallelization. */
114114
const int SINGLE_ZONE = 1; /*!< \brief There is only a zone. */
115115

116-
const unsigned short COMM_TYPE_UNSIGNED_LONG = 1; /*!< \brief Communication type for unsigned long. */
117-
const unsigned short COMM_TYPE_LONG = 2; /*!< \brief Communication type for long. */
118-
const unsigned short COMM_TYPE_UNSIGNED_SHORT = 3; /*!< \brief Communication type for unsigned short. */
119-
const unsigned short COMM_TYPE_DOUBLE = 4; /*!< \brief Communication type for double. */
120-
const unsigned short COMM_TYPE_CHAR = 5; /*!< \brief Communication type for char. */
121-
const unsigned short COMM_TYPE_SHORT = 6; /*!< \brief Communication type for short. */
122-
const unsigned short COMM_TYPE_INT = 7; /*!< \brief Communication type for int. */
116+
enum class COMM_TYPE {  // Enumerators use the default values, starting at 0.
117+
UNSIGNED_LONG, /*!< \brief Communication type for unsigned long. */
118+
LONG, /*!< \brief Communication type for long. */
119+
UNSIGNED_SHORT, /*!< \brief Communication type for unsigned short. */
120+
FLOAT, /*!< \brief Communication type for su2mixedfloat. */
121+
DOUBLE, /*!< \brief Communication type for su2double. */
122+
PASSIVE_DOUBLE, /*!< \brief Communication type for passivedouble. */
123+
CHAR, /*!< \brief Communication type for char. */
124+
SHORT, /*!< \brief Communication type for short. */
125+
INT, /*!< \brief Communication type for int. */
126+
};
123127

124128
/*!
125129
* \brief Types of geometric entities based on VTK nomenclature

Common/include/parallelization/mpi_structure.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ extern MediTypes* mediTypes;
8181

8282
#else
8383
class CBaseMPIWrapper;
84-
typedef CBaseMPIWrapper SU2_MPI;
84+
using SU2_MPI = CBaseMPIWrapper;
8585
#endif // defined CODI_REVERSE_TYPE || defined CODI_FORWARD_TYPE
8686

8787
/*!
@@ -632,7 +632,7 @@ struct SelectMPIWrapper<passivedouble> {
632632
#endif
633633

634634
/*--- Specialize for the low precision type. ---*/
635-
#if defined USE_MIXED_PRECISION
635+
#if defined(USE_MIXED_PRECISION)
636636
template <>
637637
struct SelectMPIWrapper<su2mixedfloat> {
638638
#if defined HAVE_MPI

0 commit comments

Comments
 (0)