aboutsummaryrefslogtreecommitdiff
path: root/Source
diff options
context:
space:
mode:
Diffstat (limited to 'Source')
-rw-r--r--Source/Initialization/WarpXInitData.cpp3
-rw-r--r--Source/Utils/MPIInitHelpers.H11
-rw-r--r--Source/Utils/MPIInitHelpers.cpp24
-rw-r--r--Source/main.cpp4
4 files changed, 28 insertions, 14 deletions
diff --git a/Source/Initialization/WarpXInitData.cpp b/Source/Initialization/WarpXInitData.cpp
index 3d7325e62..73193e069 100644
--- a/Source/Initialization/WarpXInitData.cpp
+++ b/Source/Initialization/WarpXInitData.cpp
@@ -22,6 +22,7 @@
#include "Filter/NCIGodfreyFilter.H"
#include "Particles/MultiParticleContainer.H"
#include "Parallelization/WarpXCommUtil.H"
+#include "Utils/MPIInitHelpers.H"
#include "Utils/WarpXAlgorithmSelection.H"
#include "Utils/WarpXConst.H"
#include "Utils/WarpXProfilerWrapper.H"
@@ -112,6 +113,8 @@ void
WarpX::InitData ()
{
WARPX_PROFILE("WarpX::InitData()");
+ utils::warpx_check_mpi_thread_level();
+
Print() << "WarpX (" << WarpX::Version() << ")\n";
#ifdef WARPX_QED
Print() << "PICSAR (" << WarpX::PicsarVersion() << ")\n";
diff --git a/Source/Utils/MPIInitHelpers.H b/Source/Utils/MPIInitHelpers.H
index 7147ee0e1..d0c2f8e1d 100644
--- a/Source/Utils/MPIInitHelpers.H
+++ b/Source/Utils/MPIInitHelpers.H
@@ -11,6 +11,13 @@
namespace utils
{
+ /** Return the required MPI threading
+ *
+ * @return the MPI_THREAD_* level required for MPI_Init_thread
+ */
+ int
+ warpx_mpi_thread_required ();
+
/** Initialize MPI
*
* @return pair(required, provided) of MPI thread level from MPI_Init_thread
@@ -21,11 +28,9 @@ namespace utils
/** Check if the requested MPI thread level is valid
*
* Prints warnings and notes otherwise.
- *
- * @param mpi_thread_levels pair(required, provided) of MPI thread level from MPI_Init_thread
*/
void
- warpx_check_mpi_thread_level (std::pair< int, int > const mpi_thread_levels);
+ warpx_check_mpi_thread_level ();
} // namespace utils
diff --git a/Source/Utils/MPIInitHelpers.cpp b/Source/Utils/MPIInitHelpers.cpp
index da9409e31..4dc5b9ae8 100644
--- a/Source/Utils/MPIInitHelpers.cpp
+++ b/Source/Utils/MPIInitHelpers.cpp
@@ -22,11 +22,10 @@
namespace utils
{
- std::pair< int, int >
- warpx_mpi_init (int argc, char* argv[])
+ int
+ warpx_mpi_thread_required ()
{
int thread_required = -1;
- int thread_provided = -1;
#ifdef AMREX_USE_MPI
thread_required = MPI_THREAD_SINGLE; // equiv. to MPI_Init
# ifdef AMREX_USE_OMP
@@ -35,6 +34,16 @@ namespace utils
# ifdef AMREX_MPI_THREAD_MULTIPLE // i.e. for async_io
thread_required = MPI_THREAD_MULTIPLE;
# endif
+#endif
+ return thread_required;
+ }
+
+ std::pair< int, int >
+ warpx_mpi_init (int argc, char* argv[])
+ {
+ int thread_required = warpx_mpi_thread_required();
+ int thread_provided = -1;
+#ifdef AMREX_USE_MPI
MPI_Init_thread(&argc, &argv, thread_required, &thread_provided);
#else
amrex::ignore_unused(argc, argv);
@@ -43,11 +52,12 @@ namespace utils
}
void
- warpx_check_mpi_thread_level (std::pair< int, int > const mpi_thread_levels)
+ warpx_check_mpi_thread_level ()
{
#ifdef AMREX_USE_MPI
- auto const thread_required = mpi_thread_levels.first;
- auto const thread_provided = mpi_thread_levels.second;
+ int thread_required = warpx_mpi_thread_required();
+ int thread_provided = -1;
+ MPI_Query_thread(&thread_provided);
auto mtn = amrex::ParallelDescriptor::mpi_level_to_string;
std::stringstream ss;
@@ -65,8 +75,6 @@ namespace utils
<< "communication performance.";
WarpX::GetInstance().RecordWarning("MPI", ss.str());
}
-#else
- amrex::ignore_unused(mpi_thread_levels);
#endif
}
diff --git a/Source/main.cpp b/Source/main.cpp
index d6939a6f1..5587d4f93 100644
--- a/Source/main.cpp
+++ b/Source/main.cpp
@@ -36,12 +36,10 @@ int main(int argc, char* argv[])
{
using namespace amrex;
- auto mpi_thread_levels = utils::warpx_mpi_init(argc, argv);
+ utils::warpx_mpi_init(argc, argv);
warpx_amrex_init(argc, argv);
- utils::warpx_check_mpi_thread_level(mpi_thread_levels);
-
#if defined(AMREX_USE_HIP) && defined(WARPX_USE_PSATD)
rocfft_setup();
#endif