aboutsummaryrefslogtreecommitdiff
path: root/Source/FieldSolver/SpectralSolver
diff options
context:
space:
mode:
Diffstat (limited to 'Source/FieldSolver/SpectralSolver')
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H10
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp39
2 files changed, 36 insertions, 13 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H
index e4d522c9f..8b131894d 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H
@@ -11,6 +11,7 @@
#include "SpectralFieldData.H"
#include "SpectralHankelTransform/SpectralHankelTransformer.H"
#include "SpectralKSpaceRZ.H"
+#include "FieldSolver/SpectralSolver/AnyFFT.H"
#include <AMReX_MultiFab.H>
@@ -25,13 +26,8 @@ class SpectralFieldDataRZ
// Define the FFTplans type, which holds one fft plan per box
// (plans are only initialized for the boxes that are owned by
// the local MPI rank)
-#if defined(AMREX_USE_CUDA)
- using FFTplans = amrex::LayoutData<cufftHandle>;
-#elif defined(AMREX_USE_HIP)
- using FFTplans = amrex::LayoutData<rocfft_plan>;
-#else
- using FFTplans = amrex::LayoutData<fftw_plan>;
-#endif
+ using FFTplans = amrex::LayoutData<AnyFFT::VendorFFTPlan>;
+
// Similarly, define the Hankel transformers and filter for each box.
using MultiSpectralHankelTransformer = amrex::LayoutData<SpectralHankelTransformer>;
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
index b13c3121a..eecae3037 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
@@ -7,10 +7,14 @@
#include "SpectralFieldDataRZ.H"
#include "Utils/WarpXUtil.H"
+#include "FieldSolver/SpectralSolver/AnyFFT.H"
#include "WarpX.H"
#include <ablastr/warn_manager/WarnManager.H>
+#include <AMReX_Config.H>
+
+
using amrex::operator""_rt;
/* \brief Initialize fields in spectral space, and FFT plans
@@ -162,21 +166,31 @@ SpectralFieldDataRZ::SpectralFieldDataRZ (const int lev,
howmany_dims[1].os = 1;
forward_plan[mfi] =
// Note that AMReX FAB are Fortran-order.
- fftw_plan_guru_dft(1, // int rank
+#ifdef AMREX_USE_FLOAT
+ fftwf_plan_guru_dft
+#else
+ fftw_plan_guru_dft
+#endif
+ (1, // int rank
dims,
2, // int howmany_rank,
howmany_dims,
- reinterpret_cast<fftw_complex*>(tempHTransformed[mfi].dataPtr()), // fftw_complex *in
- reinterpret_cast<fftw_complex*>(tmpSpectralField[mfi].dataPtr()), // fftw_complex *out
+ reinterpret_cast<AnyFFT::Complex*>(tempHTransformed[mfi].dataPtr()), // complex *in
+ reinterpret_cast<AnyFFT::Complex*>(tmpSpectralField[mfi].dataPtr()), // complex *out
FFTW_FORWARD, // int sign
FFTW_ESTIMATE); // unsigned flags
backward_plan[mfi] =
- fftw_plan_guru_dft(1, // int rank
+#ifdef AMREX_USE_FLOAT
+ fftwf_plan_guru_dft
+#else
+ fftw_plan_guru_dft
+#endif
+ (1, // int rank
dims,
2, // int howmany_rank,
howmany_dims,
- reinterpret_cast<fftw_complex*>(tmpSpectralField[mfi].dataPtr()), // fftw_complex *in
- reinterpret_cast<fftw_complex*>(tempHTransformed[mfi].dataPtr()), // fftw_complex *out
+ reinterpret_cast<AnyFFT::Complex*>(tmpSpectralField[mfi].dataPtr()), // complex *in
+ reinterpret_cast<AnyFFT::Complex*>(tempHTransformed[mfi].dataPtr()), // complex *out
FFTW_BACKWARD, // int sign
FFTW_ESTIMATE); // unsigned flags
#endif
@@ -201,8 +215,13 @@ SpectralFieldDataRZ::~SpectralFieldDataRZ()
rocfft_plan_destroy(backward_plan[mfi]);
#else
// Destroy FFTW plans.
+# ifdef AMREX_USE_FLOAT
+ fftwf_destroy_plan(forward_plan[mfi]);
+ fftwf_destroy_plan(backward_plan[mfi]);
+# else
fftw_destroy_plan(forward_plan[mfi]);
fftw_destroy_plan(backward_plan[mfi]);
+# endif
#endif
}
}
@@ -280,7 +299,11 @@ SpectralFieldDataRZ::FABZForwardTransform (amrex::MFIter const & mfi, amrex::Box
amrex::The_Arena()->free(buffer);
result = rocfft_execution_info_destroy(execinfo);
#else
+# ifdef AMREX_USE_FLOAT
+ fftwf_execute(forward_plan[mfi]);
+# else
fftw_execute(forward_plan[mfi]);
+# endif
#endif
// Copy the spectral-space field `tmpSpectralField` to the appropriate
@@ -393,7 +416,11 @@ SpectralFieldDataRZ::FABZBackwardTransform (amrex::MFIter const & mfi, amrex::Bo
amrex::The_Arena()->free(buffer);
result = rocfft_execution_info_destroy(execinfo);
#else
+# ifdef AMREX_USE_FLOAT
+ fftwf_execute(backward_plan[mfi]);
+# else
fftw_execute(backward_plan[mfi]);
+# endif
#endif
// Copy the interleaved complex to the split complex.