aboutsummaryrefslogtreecommitdiff
path: root/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp')
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp45
1 files changed, 36 insertions, 9 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
index 3fd7177e5..7d712ba03 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
@@ -25,13 +25,18 @@ SpectralFieldData::SpectralFieldData( const BoxArray& realspace_ba,
tmpRealField = SpectralField(realspace_ba, dm, 1, 0);
tmpSpectralField = SpectralField(spectralspace_ba, dm, 1, 0);
- // Allocate the vectors that allow to shift between nodal and cell-centered
- for (int i_dim=0; i_dim<AMREX_SPACEDIM; i_dim++) {
- shift_C2N[i_dim] = k_space.AllocateAndFillSpectralShiftFactor(
- dm, i_dim, ShiftType::CenteredToNodal );
- shift_N2C[i_dim] = k_space.AllocateAndFillSpectralShiftFactor(
- dm, i_dim, ShiftType::NodalToCentered );
- }
+ // Allocate the coefficients that allow to shift between nodal and cell-centered
+ xshift_C2N = k_space.getSpectralShiftFactor(dm, 0, ShiftType::CenteredToNodal);
+ xshift_N2C = k_space.getSpectralShiftFactor(dm, 0, ShiftType::NodalToCentered);
+#if (AMREX_SPACEDIM == 3)
+ yshift_C2N = k_space.getSpectralShiftFactor(dm, 1, ShiftType::CenteredToNodal);
+ yshift_N2C = k_space.getSpectralShiftFactor(dm, 1, ShiftType::NodalToCentered);
+ zshift_C2N = k_space.getSpectralShiftFactor(dm, 2, ShiftType::CenteredToNodal);
+ zshift_N2C = k_space.getSpectralShiftFactor(dm, 2, ShiftType::NodalToCentered);
+#else
+ zshift_C2N = k_space.getSpectralShiftFactor(dm, 1, ShiftType::CenteredToNodal);
+ zshift_N2C = k_space.getSpectralShiftFactor(dm, 1, ShiftType::NodalToCentered);
+#endif
// Allocate and initialize the FFT plans
forward_plan = FFTplans(spectralspace_ba, dm);
@@ -139,6 +144,15 @@ void
SpectralFieldData::BackwardTransform( MultiFab& mf,
const int field_index, const int i_comp )
{
+ // Check field index type, in order to apply proper shift in spectral space
+ const bool is_nodal_x = mf.is_nodal(0);
+#if (AMREX_SPACEDIM == 3)
+ const bool is_nodal_y = mf.is_nodal(1);
+ const bool is_nodal_z = mf.is_nodal(2);
+#else
+ const bool is_nodal_z = mf.is_nodal(1);
+#endif
+
// Loop over boxes
for ( MFIter mfi(mf); mfi.isValid(); ++mfi ){
@@ -148,10 +162,22 @@ SpectralFieldData::BackwardTransform( MultiFab& mf,
SpectralField& field = getSpectralField( field_index );
Array4<const Complex> field_arr = field[mfi].array();
Array4<Complex> tmp_arr = tmpSpectralField[mfi].array();
+ const Complex* xshift_C2N_arr = xshift_C2N[mfi].dataPtr();
+ const Complex* yshift_C2N_arr = yshift_C2N[mfi].dataPtr();
+ const Complex* zshift_C2N_arr = zshift_C2N[mfi].dataPtr();
+ // Loop over indices within one box
const Box spectralspace_bx = tmpSpectralField[mfi].box();
ParallelFor( spectralspace_bx,
[=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept {
- tmp_arr(i,j,k) = field_arr(i,j,k);
+ Complex spectral_field_value = field_arr(i,j,k);
+ // Apply proper shift in each dimension
+ if (is_nodal_x==false) spectral_field_value *= xshift_C2N_arr[i];
+#if (AMREX_SPACEDIM == 3)
+ if (is_nodal_y==false) spectral_field_value *= yshift_C2N_arr[j];
+#endif
+ if (is_nodal_z==false) spectral_field_value *= zshift_C2N_arr[k];
+ // Copy field into temporary array
+ tmp_arr(i,j,k) = spectral_field_value;
});
}
@@ -170,7 +196,8 @@ SpectralFieldData::BackwardTransform( MultiFab& mf,
// Normalize (divide by 1/N) since the FFT result in a factor N
{
Box bx = mf[mfi].box();
- const Box realspace_bx = bx.enclosedCells(); // discards last point in each nodal direction
+ const Box realspace_bx = bx.enclosedCells();
+ // `enclosedells` discards last point in each nodal direction
AMREX_ALWAYS_ASSERT( realspace_bx == tmpRealField[mfi].box() );
Array4<Real> mf_arr = mf[mfi].array();
Array4<const Complex> tmp_arr = tmpRealField[mfi].array();