aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Axel Huebl <axel.huebl@plasma.ninja> 2022-01-27 15:35:22 -0800
committerGravatar GitHub <noreply@github.com> 2022-01-27 15:35:22 -0800
commit72c9c98548d4a6b78810a88746612ad135300038 (patch)
tree7de6aa9229fe7da8177deb2a42392eacb3be4b18
parentd839c70c7342c3b336ddc535a1b4ff909d30a23b (diff)
downloadWarpX-72c9c98548d4a6b78810a88746612ad135300038.tar.gz
WarpX-72c9c98548d4a6b78810a88746612ad135300038.tar.zst
WarpX-72c9c98548d4a6b78810a88746612ad135300038.zip
Fix: Move MPI Thread Level Check (#2786)
* Fix: Move MPI Thread Level Check Move the MPI thread level check (requested vs. provided) into `InitData()`. This is needed, since we will access the warning logger, which itself needs to be accessed via a WarpX instance. This is also cleaner than just moving this behind the constructor in `main.cpp`, as we have less functions to call around the `WarpX` object usage. * Fix MPI=OFF
-rw-r--r--Source/Initialization/WarpXInitData.cpp3
-rw-r--r--Source/Utils/MPIInitHelpers.H11
-rw-r--r--Source/Utils/MPIInitHelpers.cpp24
-rw-r--r--Source/main.cpp4
4 files changed, 28 insertions, 14 deletions
diff --git a/Source/Initialization/WarpXInitData.cpp b/Source/Initialization/WarpXInitData.cpp
index 3d7325e62..73193e069 100644
--- a/Source/Initialization/WarpXInitData.cpp
+++ b/Source/Initialization/WarpXInitData.cpp
@@ -22,6 +22,7 @@
#include "Filter/NCIGodfreyFilter.H"
#include "Particles/MultiParticleContainer.H"
#include "Parallelization/WarpXCommUtil.H"
+#include "Utils/MPIInitHelpers.H"
#include "Utils/WarpXAlgorithmSelection.H"
#include "Utils/WarpXConst.H"
#include "Utils/WarpXProfilerWrapper.H"
@@ -112,6 +113,8 @@ void
WarpX::InitData ()
{
WARPX_PROFILE("WarpX::InitData()");
+ utils::warpx_check_mpi_thread_level();
+
Print() << "WarpX (" << WarpX::Version() << ")\n";
#ifdef WARPX_QED
Print() << "PICSAR (" << WarpX::PicsarVersion() << ")\n";
diff --git a/Source/Utils/MPIInitHelpers.H b/Source/Utils/MPIInitHelpers.H
index 7147ee0e1..d0c2f8e1d 100644
--- a/Source/Utils/MPIInitHelpers.H
+++ b/Source/Utils/MPIInitHelpers.H
@@ -11,6 +11,13 @@
namespace utils
{
+ /** Return the required MPI threading
+ *
+ * @return the MPI_THREAD_* level required for MPI_Init_thread
+ */
+ int
+ warpx_mpi_thread_required ();
+
/** Initialize MPI
*
* @return pair(required, provided) of MPI thread level from MPI_Init_thread
@@ -21,11 +28,9 @@ namespace utils
/** Check if the requested MPI thread level is valid
*
* Prints warnings and notes otherwise.
- *
- * @param mpi_thread_levels pair(required, provided) of MPI thread level from MPI_Init_thread
*/
void
- warpx_check_mpi_thread_level (std::pair< int, int > const mpi_thread_levels);
+ warpx_check_mpi_thread_level ();
} // namespace utils
diff --git a/Source/Utils/MPIInitHelpers.cpp b/Source/Utils/MPIInitHelpers.cpp
index da9409e31..4dc5b9ae8 100644
--- a/Source/Utils/MPIInitHelpers.cpp
+++ b/Source/Utils/MPIInitHelpers.cpp
@@ -22,11 +22,10 @@
namespace utils
{
- std::pair< int, int >
- warpx_mpi_init (int argc, char* argv[])
+ int
+ warpx_mpi_thread_required ()
{
int thread_required = -1;
- int thread_provided = -1;
#ifdef AMREX_USE_MPI
thread_required = MPI_THREAD_SINGLE; // equiv. to MPI_Init
# ifdef AMREX_USE_OMP
@@ -35,6 +34,16 @@ namespace utils
# ifdef AMREX_MPI_THREAD_MULTIPLE // i.e. for async_io
thread_required = MPI_THREAD_MULTIPLE;
# endif
+#endif
+ return thread_required;
+ }
+
+ std::pair< int, int >
+ warpx_mpi_init (int argc, char* argv[])
+ {
+ int thread_required = warpx_mpi_thread_required();
+ int thread_provided = -1;
+#ifdef AMREX_USE_MPI
MPI_Init_thread(&argc, &argv, thread_required, &thread_provided);
#else
amrex::ignore_unused(argc, argv);
@@ -43,11 +52,12 @@ namespace utils
}
void
- warpx_check_mpi_thread_level (std::pair< int, int > const mpi_thread_levels)
+ warpx_check_mpi_thread_level ()
{
#ifdef AMREX_USE_MPI
- auto const thread_required = mpi_thread_levels.first;
- auto const thread_provided = mpi_thread_levels.second;
+ int thread_required = warpx_mpi_thread_required();
+ int thread_provided = -1;
+ MPI_Query_thread(&thread_provided);
auto mtn = amrex::ParallelDescriptor::mpi_level_to_string;
std::stringstream ss;
@@ -65,8 +75,6 @@ namespace utils
<< "communication performance.";
WarpX::GetInstance().RecordWarning("MPI", ss.str());
}
-#else
- amrex::ignore_unused(mpi_thread_levels);
#endif
}
diff --git a/Source/main.cpp b/Source/main.cpp
index d6939a6f1..5587d4f93 100644
--- a/Source/main.cpp
+++ b/Source/main.cpp
@@ -36,12 +36,10 @@ int main(int argc, char* argv[])
{
using namespace amrex;
- auto mpi_thread_levels = utils::warpx_mpi_init(argc, argv);
+ utils::warpx_mpi_init(argc, argv);
warpx_amrex_init(argc, argv);
- utils::warpx_check_mpi_thread_level(mpi_thread_levels);
-
#if defined(AMREX_USE_HIP) && defined(WARPX_USE_PSATD)
rocfft_setup();
#endif