diff options
Diffstat (limited to 'Source/FieldSolver/SpectralSolver')
-rw-r--r-- | Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H | 10 | ||||
-rw-r--r-- | Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp | 39 |
2 files changed, 36 insertions, 13 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H index e4d522c9f..8b131894d 100644 --- a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H +++ b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H @@ -11,6 +11,7 @@ #include "SpectralFieldData.H" #include "SpectralHankelTransform/SpectralHankelTransformer.H" #include "SpectralKSpaceRZ.H" +#include "FieldSolver/SpectralSolver/AnyFFT.H" #include <AMReX_MultiFab.H> @@ -25,13 +26,8 @@ class SpectralFieldDataRZ // Define the FFTplans type, which holds one fft plan per box // (plans are only initialized for the boxes that are owned by // the local MPI rank) -#if defined(AMREX_USE_CUDA) - using FFTplans = amrex::LayoutData<cufftHandle>; -#elif defined(AMREX_USE_HIP) - using FFTplans = amrex::LayoutData<rocfft_plan>; -#else - using FFTplans = amrex::LayoutData<fftw_plan>; -#endif + using FFTplans = amrex::LayoutData<AnyFFT::VendorFFTPlan>; + // Similarly, define the Hankel transformers and filter for each box. using MultiSpectralHankelTransformer = amrex::LayoutData<SpectralHankelTransformer>; diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp index b13c3121a..eecae3037 100644 --- a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp +++ b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp @@ -7,10 +7,14 @@ #include "SpectralFieldDataRZ.H" #include "Utils/WarpXUtil.H" +#include "FieldSolver/SpectralSolver/AnyFFT.H" #include "WarpX.H" #include <ablastr/warn_manager/WarnManager.H> +#include <AMReX_Config.H> + + using amrex::operator""_rt; /* \brief Initialize fields in spectral space, and FFT plans @@ -162,21 +166,31 @@ SpectralFieldDataRZ::SpectralFieldDataRZ (const int lev, howmany_dims[1].os = 1; forward_plan[mfi] = // Note that AMReX FAB are Fortran-order. - fftw_plan_guru_dft(1, // int rank +#ifdef AMREX_USE_FLOAT + fftwf_plan_guru_dft +#else + fftw_plan_guru_dft +#endif + (1, // int rank dims, 2, // int howmany_rank, howmany_dims, - reinterpret_cast<fftw_complex*>(tempHTransformed[mfi].dataPtr()), // fftw_complex *in - reinterpret_cast<fftw_complex*>(tmpSpectralField[mfi].dataPtr()), // fftw_complex *out + reinterpret_cast<AnyFFT::Complex*>(tempHTransformed[mfi].dataPtr()), // complex *in + reinterpret_cast<AnyFFT::Complex*>(tmpSpectralField[mfi].dataPtr()), // complex *out FFTW_FORWARD, // int sign FFTW_ESTIMATE); // unsigned flags backward_plan[mfi] = - fftw_plan_guru_dft(1, // int rank +#ifdef AMREX_USE_FLOAT + fftwf_plan_guru_dft +#else + fftw_plan_guru_dft +#endif + (1, // int rank dims, 2, // int howmany_rank, howmany_dims, - reinterpret_cast<fftw_complex*>(tmpSpectralField[mfi].dataPtr()), // fftw_complex *in - reinterpret_cast<fftw_complex*>(tempHTransformed[mfi].dataPtr()), // fftw_complex *out + reinterpret_cast<AnyFFT::Complex*>(tmpSpectralField[mfi].dataPtr()), // complex *in + reinterpret_cast<AnyFFT::Complex*>(tempHTransformed[mfi].dataPtr()), // complex *out FFTW_BACKWARD, // int sign FFTW_ESTIMATE); // unsigned flags #endif @@ -201,8 +215,13 @@ SpectralFieldDataRZ::~SpectralFieldDataRZ() rocfft_plan_destroy(backward_plan[mfi]); #else // Destroy FFTW plans. +# ifdef AMREX_USE_FLOAT + fftwf_destroy_plan(forward_plan[mfi]); + fftwf_destroy_plan(backward_plan[mfi]); +# else fftw_destroy_plan(forward_plan[mfi]); fftw_destroy_plan(backward_plan[mfi]); +# endif #endif } } @@ -280,7 +299,11 @@ SpectralFieldDataRZ::FABZForwardTransform (amrex::MFIter const & mfi, amrex::Box amrex::The_Arena()->free(buffer); result = rocfft_execution_info_destroy(execinfo); #else +# ifdef AMREX_USE_FLOAT + fftwf_execute(forward_plan[mfi]); +# else fftw_execute(forward_plan[mfi]); +# endif #endif // Copy the spectral-space field `tmpSpectralField` to the appropriate @@ -393,7 +416,11 @@ SpectralFieldDataRZ::FABZBackwardTransform (amrex::MFIter const & mfi, amrex::Bo amrex::The_Arena()->free(buffer); result = rocfft_execution_info_destroy(execinfo); #else +# ifdef AMREX_USE_FLOAT + fftwf_execute(backward_plan[mfi]); +# else fftw_execute(backward_plan[mfi]); +# endif #endif // Copy the interleaved complex to the split complex. |