diff options
Diffstat (limited to 'Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp')
-rw-r--r-- | Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp index 86fe8f7bd..6d7d18b5f 100644 --- a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp +++ b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp @@ -8,6 +8,7 @@ #include "SpectralFieldData.H" #include "Utils/WarpXAlgorithmSelection.H" +#include "Utils/WarpXUtil.H" #include "WarpX.H" #include <AMReX_Array4.H> @@ -104,6 +105,7 @@ SpectralFieldData::SpectralFieldData( const int lev, const bool periodic_single_box) { amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev); + bool do_costs = WarpXUtilLoadBalance::doCosts(cost, realspace_ba, dm); m_periodic_single_box = periodic_single_box; @@ -147,7 +149,7 @@ SpectralFieldData::SpectralFieldData( const int lev, // Loop over boxes and allocate the corresponding plan // for each box owned by the local MPI proc for ( MFIter mfi(spectralspace_ba, dm); mfi.isValid(); ++mfi ){ - if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers) + if (do_costs) { amrex::Gpu::synchronize(); } @@ -168,7 +170,7 @@ SpectralFieldData::SpectralFieldData( const int lev, reinterpret_cast<AnyFFT::Complex*>( tmpSpectralField[mfi].dataPtr()), AnyFFT::direction::C2R, AMREX_SPACEDIM); - if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers) + if (do_costs) { amrex::Gpu::synchronize(); wt = amrex::second() - wt; @@ -193,10 +195,11 @@ SpectralFieldData::~SpectralFieldData() * (in the spectral field specified by `field_index`) */ void SpectralFieldData::ForwardTransform (const int lev, - const MultiFab& mf, const int field_index, + const MultiFab& mf, const int field_index, const int i_comp, const IntVect& stag) { amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev); + bool do_costs = WarpXUtilLoadBalance::doCosts(cost, mf.boxArray(), mf.DistributionMap()); // Check field index type, in order to apply proper shift in spectral space #if (AMREX_SPACEDIM >= 2) @@ -215,7 +218,7 @@ SpectralFieldData::ForwardTransform (const int lev, // Note: we do NOT OpenMP parallelize here, since we use OpenMP threads for // the FFTs on each box! for ( MFIter mfi(mf); mfi.isValid(); ++mfi ){ - if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers) + if (do_costs) { amrex::Gpu::synchronize(); } @@ -283,7 +286,7 @@ SpectralFieldData::ForwardTransform (const int lev, }); } - if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers) + if (do_costs) { amrex::Gpu::synchronize(); wt = amrex::second() - wt; @@ -303,6 +306,7 @@ SpectralFieldData::BackwardTransform (const int lev, const amrex::IntVect& fill_guards) { amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev); + bool do_costs = WarpXUtilLoadBalance::doCosts(cost, mf.boxArray(), mf.DistributionMap()); // Check field index type, in order to apply proper shift in spectral space #if (AMREX_SPACEDIM >= 2) @@ -339,7 +343,7 @@ SpectralFieldData::BackwardTransform (const int lev, // Note: we do NOT OpenMP parallelize here, since we use OpenMP threads for // the iFFTs on each box! for ( MFIter mfi(mf); mfi.isValid(); ++mfi ){ - if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers) + if (do_costs) { amrex::Gpu::synchronize(); } @@ -443,7 +447,7 @@ SpectralFieldData::BackwardTransform (const int lev, }); } - if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers) + if (do_costs) { amrex::Gpu::synchronize(); wt = amrex::second() - wt; |