aboutsummaryrefslogtreecommitdiff
path: root/Source
diff options
context:
space:
mode:
Diffstat (limited to 'Source')
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp18
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp30
-rw-r--r--Source/Utils/WarpXUtil.H17
-rw-r--r--Source/Utils/WarpXUtil.cpp12
4 files changed, 58 insertions, 19 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
index 86fe8f7bd..6d7d18b5f 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
@@ -8,6 +8,7 @@
#include "SpectralFieldData.H"
#include "Utils/WarpXAlgorithmSelection.H"
+#include "Utils/WarpXUtil.H"
#include "WarpX.H"
#include <AMReX_Array4.H>
@@ -104,6 +105,7 @@ SpectralFieldData::SpectralFieldData( const int lev,
const bool periodic_single_box)
{
amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev);
+ bool do_costs = WarpXUtilLoadBalance::doCosts(cost, realspace_ba, dm);
m_periodic_single_box = periodic_single_box;
@@ -147,7 +149,7 @@ SpectralFieldData::SpectralFieldData( const int lev,
// Loop over boxes and allocate the corresponding plan
// for each box owned by the local MPI proc
for ( MFIter mfi(spectralspace_ba, dm); mfi.isValid(); ++mfi ){
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
}
@@ -168,7 +170,7 @@ SpectralFieldData::SpectralFieldData( const int lev,
reinterpret_cast<AnyFFT::Complex*>( tmpSpectralField[mfi].dataPtr()),
AnyFFT::direction::C2R, AMREX_SPACEDIM);
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
wt = amrex::second() - wt;
@@ -193,10 +195,11 @@ SpectralFieldData::~SpectralFieldData()
* (in the spectral field specified by `field_index`) */
void
SpectralFieldData::ForwardTransform (const int lev,
- const MultiFab& mf, const int field_index,
+ const MultiFab& mf, const int field_index,
const int i_comp, const IntVect& stag)
{
amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev);
+ bool do_costs = WarpXUtilLoadBalance::doCosts(cost, mf.boxArray(), mf.DistributionMap());
// Check field index type, in order to apply proper shift in spectral space
#if (AMREX_SPACEDIM >= 2)
@@ -215,7 +218,7 @@ SpectralFieldData::ForwardTransform (const int lev,
// Note: we do NOT OpenMP parallelize here, since we use OpenMP threads for
// the FFTs on each box!
for ( MFIter mfi(mf); mfi.isValid(); ++mfi ){
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
}
@@ -283,7 +286,7 @@ SpectralFieldData::ForwardTransform (const int lev,
});
}
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
wt = amrex::second() - wt;
@@ -303,6 +306,7 @@ SpectralFieldData::BackwardTransform (const int lev,
const amrex::IntVect& fill_guards)
{
amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev);
+ bool do_costs = WarpXUtilLoadBalance::doCosts(cost, mf.boxArray(), mf.DistributionMap());
// Check field index type, in order to apply proper shift in spectral space
#if (AMREX_SPACEDIM >= 2)
@@ -339,7 +343,7 @@ SpectralFieldData::BackwardTransform (const int lev,
// Note: we do NOT OpenMP parallelize here, since we use OpenMP threads for
// the iFFTs on each box!
for ( MFIter mfi(mf); mfi.isValid(); ++mfi ){
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
}
@@ -443,7 +447,7 @@ SpectralFieldData::BackwardTransform (const int lev,
});
}
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
wt = amrex::second() - wt;
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
index a44ecb47e..57eac3a2c 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
@@ -408,6 +408,7 @@ SpectralFieldDataRZ::ForwardTransform (const int lev,
int const i_comp)
{
amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev);
+ bool do_costs = WarpXUtilLoadBalance::doCosts(cost, field_mf.boxArray(), field_mf.DistributionMap());
// Check field index type, in order to apply proper shift in spectral space.
// Only cell centered in r is supported.
@@ -430,7 +431,7 @@ SpectralFieldDataRZ::ForwardTransform (const int lev,
// Loop over boxes.
for (amrex::MFIter mfi(field_mf); mfi.isValid(); ++mfi){
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
}
@@ -451,7 +452,7 @@ SpectralFieldDataRZ::ForwardTransform (const int lev,
FABZForwardTransform(mfi, realspace_bx, tempHTransformedSplit, field_index, is_nodal_z);
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
wt = amrex::second() - wt;
@@ -469,6 +470,7 @@ SpectralFieldDataRZ::ForwardTransform (const int lev,
amrex::MultiFab const & field_mf_t, int const field_index_t)
{
amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev);
+ bool do_costs = WarpXUtilLoadBalance::doCosts(cost, field_mf_r.boxArray(), field_mf_r.DistributionMap());
// Check field index type, in order to apply proper shift in spectral space.
// Only cell centered in r is supported.
@@ -486,7 +488,7 @@ SpectralFieldDataRZ::ForwardTransform (const int lev,
// Loop over boxes.
for (amrex::MFIter mfi(field_mf_r); mfi.isValid(); ++mfi){
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
}
@@ -515,7 +517,7 @@ SpectralFieldDataRZ::ForwardTransform (const int lev,
FABZForwardTransform(mfi, realspace_bx, tempHTransformedSplit_p, field_index_r, is_nodal_z);
FABZForwardTransform(mfi, realspace_bx, tempHTransformedSplit_m, field_index_t, is_nodal_z);
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
wt = amrex::second() - wt;
@@ -532,6 +534,7 @@ SpectralFieldDataRZ::BackwardTransform (const int lev,
int const i_comp)
{
amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev);
+ bool do_costs = WarpXUtilLoadBalance::doCosts(cost, field_mf.boxArray(), field_mf.DistributionMap());
// Check field index type, in order to apply proper shift in spectral space.
bool const is_nodal_z = field_mf.is_nodal(1);
@@ -548,7 +551,7 @@ SpectralFieldDataRZ::BackwardTransform (const int lev,
// Loop over boxes.
for (amrex::MFIter mfi(field_mf); mfi.isValid(); ++mfi){
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
}
@@ -600,7 +603,7 @@ SpectralFieldDataRZ::BackwardTransform (const int lev,
field_mf_array(i,j,k,ic) = sign*field_mf_copy_array(ii,j,k,icomp);
});
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
wt = amrex::second() - wt;
@@ -617,6 +620,7 @@ SpectralFieldDataRZ::BackwardTransform (const int lev,
amrex::MultiFab& field_mf_t, int const field_index_t)
{
amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev);
+ bool do_costs = WarpXUtilLoadBalance::doCosts(cost, field_mf_r.boxArray(), field_mf_r.DistributionMap());
// Check field index type, in order to apply proper shift in spectral space.
bool const is_nodal_z = field_mf_r.is_nodal(1);
@@ -632,7 +636,7 @@ SpectralFieldDataRZ::BackwardTransform (const int lev,
// Loop over boxes.
for (amrex::MFIter mfi(field_mf_r); mfi.isValid(); ++mfi){
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
}
@@ -695,7 +699,7 @@ SpectralFieldDataRZ::BackwardTransform (const int lev,
}
});
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
wt = amrex::second() - wt;
@@ -727,10 +731,11 @@ void
SpectralFieldDataRZ::ApplyFilter (const int lev, int const field_index)
{
amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev);
+ bool do_costs = WarpXUtilLoadBalance::doCosts(cost, binomialfilter.boxArray(), binomialfilter.DistributionMap());
for (amrex::MFIter mfi(binomialfilter); mfi.isValid(); ++mfi){
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
}
@@ -756,7 +761,7 @@ SpectralFieldDataRZ::ApplyFilter (const int lev, int const field_index)
fields_arr(i,j,k,ic) *= filter_r_arr[ir]*filter_z_arr[j];
});
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
wt = amrex::second() - wt;
@@ -771,10 +776,11 @@ SpectralFieldDataRZ::ApplyFilter (const int lev, int const field_index1,
int const field_index2, int const field_index3)
{
amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(lev);
+ bool do_costs = WarpXUtilLoadBalance::doCosts(cost, binomialfilter.boxArray(), binomialfilter.DistributionMap());
for (amrex::MFIter mfi(binomialfilter); mfi.isValid(); ++mfi){
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
}
@@ -804,7 +810,7 @@ SpectralFieldDataRZ::ApplyFilter (const int lev, int const field_index1,
fields_arr(i,j,k,ic3) *= filter_r_arr[ir]*filter_z_arr[j];
});
- if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
+ if (do_costs)
{
amrex::Gpu::synchronize();
wt = amrex::second() - wt;
diff --git a/Source/Utils/WarpXUtil.H b/Source/Utils/WarpXUtil.H
index 61952da17..dfbed3e82 100644
--- a/Source/Utils/WarpXUtil.H
+++ b/Source/Utils/WarpXUtil.H
@@ -8,8 +8,11 @@
#ifndef WARPX_UTILS_H_
#define WARPX_UTILS_H_
+#include <AMReX_BoxArray.H>
+#include <AMReX_DistributionMapping.H>
#include <AMReX_Extension.H>
#include <AMReX_GpuQualifiers.H>
+#include <AMReX_LayoutData.H>
#include <AMReX_ParmParse.H>
#include <AMReX_Parser.H>
#include <AMReX_REAL.H>
@@ -397,4 +400,18 @@ namespace WarpXUtilStr
}
+namespace WarpXUtilLoadBalance
+{
+ /** \brief We only want to update the cost data if the grids we are working on
+ * are the main grids, i.e. not the PML grids. This function returns whether
+ * this is the case or not.
+ * @param[in] cost pointer to the cost data
+ * @param[in] ba the grids to check
+ * @param[in] dm the dmap to check
+ * @return consistent whether the grids are consistent or not.
+ */
+ bool doCosts (const amrex::LayoutData<amrex::Real>* cost, const amrex::BoxArray ba,
+ const amrex::DistributionMapping& dm);
+}
+
#endif //WARPX_UTILS_H_
diff --git a/Source/Utils/WarpXUtil.cpp b/Source/Utils/WarpXUtil.cpp
index 03b3768a7..f7b89ee0d 100644
--- a/Source/Utils/WarpXUtil.cpp
+++ b/Source/Utils/WarpXUtil.cpp
@@ -740,3 +740,15 @@ namespace WarpXUtilStr
}
}
+
+namespace WarpXUtilLoadBalance
+{
+ bool doCosts (const amrex::LayoutData<amrex::Real>* costs, const amrex::BoxArray ba,
+ const amrex::DistributionMapping& dm)
+ {
+ bool consistent = costs && (dm == costs->DistributionMap()) &&
+ (ba.CellEqual(costs->boxArray())) &&
+ (WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers);
+ return consistent;
+ }
+}