aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Source/FieldSolver/SpectralSolver/AnyFFT.H91
-rw-r--r--Source/FieldSolver/SpectralSolver/Make.package5
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralFieldData.H17
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp218
-rw-r--r--Source/FieldSolver/SpectralSolver/WrapCuFFT.cpp119
-rw-r--r--Source/FieldSolver/SpectralSolver/WrapFFTW.cpp71
-rw-r--r--Source/Utils/WarpX_Complex.H33
7 files changed, 311 insertions, 243 deletions
diff --git a/Source/FieldSolver/SpectralSolver/AnyFFT.H b/Source/FieldSolver/SpectralSolver/AnyFFT.H
new file mode 100644
index 000000000..a02f86343
--- /dev/null
+++ b/Source/FieldSolver/SpectralSolver/AnyFFT.H
@@ -0,0 +1,91 @@
+#ifndef ANYFFT_H_
+#define ANYFFT_H_
+
+#ifdef AMREX_USE_GPU
+# include <cufft.h>
+#else
+# include <fftw3.h>
+#endif
+
+#include <AMReX_LayoutData.H>
+
+/**
+ * Wrapper around FFT libraries. The header file defines the API and the base types
+ * (Complex and VendorFFTPlan), and the implementation for different FFT libraries is
+ * done in different cpp files. This wrapper only depends on the underlying FFT library
+ * AND on AMReX (There is no dependence on WarpX).
+ */
+namespace AnyFFT
+{
+ // First, define library-dependent types (complex, FFT plan)
+
+ /** Complex type for FFT, depends on FFT library */
+#ifdef AMREX_USE_GPU
+# ifdef AMREX_USE_FLOAT
+ using Complex = cuComplex;
+# else
+ using Complex = cuDoubleComplex;
+# endif
+#else
+# ifdef AMREX_USE_FLOAT
+ using Complex = fftwf_complex;
+# else
+ using Complex = fftw_complex;
+# endif
+#endif
+
+ /** Library-dependent FFT plans type, which holds one fft plan per box
+ * (plans are only initialized for the boxes that are owned by the local MPI rank).
+ */
+#ifdef AMREX_USE_GPU
+ using VendorFFTPlan = cufftHandle;
+#else
+# ifdef AMREX_USE_FLOAT
+ using VendorFFTPlan = fftwf_plan;
+# else
+ using VendorFFTPlan = fftw_plan;
+# endif
+#endif
+
+ // Second, define library-independent API
+
+ /** Direction in which the FFT is performed. */
+ enum struct direction {R2C, C2R};
+
+ /** This struct contains the vendor FFT plan and additional metadata
+ */
+ struct FFTplan
+ {
+ amrex::Real* m_real_array; /**< pointer to real array */
+ Complex* m_complex_array; /**< pointer to complex array */
+ VendorFFTPlan m_plan; /**< Vendor FFT plan */
+ direction m_dir; /**< direction (C2R or R2C) */
+ int m_dim; /**< Dimensionality of the FFT plan */
+ };
+
+ /** Collection of FFT plans, one FFTplan per box */
+ using FFTplans = amrex::LayoutData<FFTplan>;
+
+ /** \brief create FFT plan for the backend FFT library.
+ * \param[in] real_size Size of the real array, along each dimension.
+ * Only the first dim elements are used.
+ * \param[out] real_array Real array from/to where R2C/C2R FFT is performed
+ * \param[out] complex_array Complex array to/from where R2C/C2R FFT is performed
+ * \param[in] dir direction, either R2C or C2R
+ * \param[in] dim direction, number of dimensions of the arrays. Must be <= AMREX_SPACEDIM.
+ */
+ FFTplan CreatePlan(const amrex::IntVect& real_size, amrex::Real * const real_array,
+ Complex * const complex_array, const direction dir, const int dim);
+
+ /** \brief Destroy library FFT plan.
+ * \param[out] fft_plan plan to destroy
+ */
+ void DestroyPlan(FFTplan& fft_plan);
+
+ /** \brief Perform FFT with backend library.
+ * \param[out] fft_plan plan for which the FFT is performed
+ */
+ void Execute(FFTplan& fft_plan);
+}
+
+#endif // ANYFFT_H_
diff --git a/Source/FieldSolver/SpectralSolver/Make.package b/Source/FieldSolver/SpectralSolver/Make.package
index 549347135..ba41619a1 100644
--- a/Source/FieldSolver/SpectralSolver/Make.package
+++ b/Source/FieldSolver/SpectralSolver/Make.package
@@ -1,6 +1,11 @@
CEXE_sources += SpectralSolver.cpp
CEXE_sources += SpectralFieldData.cpp
CEXE_sources += SpectralKSpace.cpp
+ifeq ($(USE_CUDA),TRUE)
+ CEXE_sources += WrapCuFFT.cpp
+else
+ CEXE_sources += WrapFFTW.cpp
+endif
ifeq ($(USE_RZ),TRUE)
CEXE_sources += SpectralSolverRZ.cpp
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldData.H b/Source/FieldSolver/SpectralSolver/SpectralFieldData.H
index d80ed9c96..f618fda35 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldData.H
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldData.H
@@ -10,6 +10,8 @@
#include "Utils/WarpX_Complex.H"
#include "SpectralKSpace.H"
+#include "AnyFFT.H"
+
#include <AMReX_MultiFab.H>
#include <string>
@@ -38,19 +40,6 @@ struct SpectralPMLIndex {
class SpectralFieldData
{
- // Define the FFTplans type, which holds one fft plan per box
- // (plans are only initialized for the boxes that are owned by
- // the local MPI rank)
-#ifdef AMREX_USE_GPU
- using FFTplans = amrex::LayoutData<cufftHandle>;
-#else
-# ifdef AMREX_USE_FLOAT
- using FFTplans = amrex::LayoutData<fftwf_plan>;
-# else
- using FFTplans = amrex::LayoutData<fftw_plan>;
-# endif
-#endif
-
public:
SpectralFieldData( const amrex::BoxArray& realspace_ba,
const SpectralKSpace& k_space,
@@ -72,7 +61,7 @@ class SpectralFieldData
// right before/after the Fourier transform
SpectralField tmpSpectralField; // contains Complexs
amrex::MultiFab tmpRealField; // contains Reals
- FFTplans forward_plan, backward_plan;
+ AnyFFT::FFTplans forward_plan, backward_plan;
// Correcting "shift" factors when performing FFT from/to
// a cell-centered grid in real space, instead of a nodal grid
SpectralShiftFactor xshift_FFTfromCell, xshift_FFTtoCell,
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
index 76299b7de..0f49e695b 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
@@ -9,26 +9,10 @@
#include <map>
-
-
#if WARPX_USE_PSATD
using namespace amrex;
-#ifdef AMREX_USE_GPU
-# ifdef AMREX_USE_FLOAT
-using cuPrecisionComplex = cuComplex;
-# else
-using cuPrecisionComplex = cuDoubleComplex;
-# endif
-#else
-# ifdef AMREX_USE_FLOAT
-using fftw_precision_complex = fftwf_complex;
-# else
-using fftw_precision_complex = fftw_complex;
-# endif
-#endif
-
/* \brief Initialize fields in spectral space, and FFT plans */
SpectralFieldData::SpectralFieldData( const amrex::BoxArray& realspace_ba,
const SpectralKSpace& k_space,
@@ -73,8 +57,8 @@ SpectralFieldData::SpectralFieldData( const amrex::BoxArray& realspace_ba,
#endif
// Allocate and initialize the FFT plans
- forward_plan = FFTplans(spectralspace_ba, dm);
- backward_plan = FFTplans(spectralspace_ba, dm);
+ forward_plan = AnyFFT::FFTplans(spectralspace_ba, dm);
+ backward_plan = AnyFFT::FFTplans(spectralspace_ba, dm);
// Loop over boxes and allocate the corresponding plan
// for each box owned by the local MPI proc
for ( MFIter mfi(spectralspace_ba, dm); mfi.isValid(); ++mfi ){
@@ -82,98 +66,16 @@ SpectralFieldData::SpectralFieldData( const amrex::BoxArray& realspace_ba,
// differ when using real-to-complex FFT. When initializing
// the FFT plan, the valid dimensions are those of the real-space box.
IntVect fft_size = realspace_ba[mfi].length();
-#ifdef AMREX_USE_GPU
- // Create cuFFT plans
- // Creating 3D plan for real to complex -- double precision
- // Assuming CUDA is used for programming GPU
- // Note that D2Z is inherently forward plan
- // and Z2D is inherently backward plan
- cufftResult result;
-# if (AMREX_SPACEDIM == 3)
- result = cufftPlan3d( &forward_plan[mfi], fft_size[2], fft_size[1],fft_size[0],
-# ifdef AMREX_USE_FLOAT
- CUFFT_R2C);
-# else
- CUFFT_D2Z);
-# endif
- if ( result != CUFFT_SUCCESS ) {
- amrex::Print() << " cufftplan3d forward failed! Error: " <<
- cufftErrorToString(result) << "\n";
- }
- result = cufftPlan3d( &backward_plan[mfi], fft_size[2], fft_size[1],fft_size[0],
-# ifdef AMREX_USE_FLOAT
- CUFFT_C2R);
-# else
- CUFFT_Z2D);
-# endif
- if ( result != CUFFT_SUCCESS ) {
- amrex::Print() << " cufftplan3d backward failed! Error: " <<
- cufftErrorToString(result) << "\n";
- }
-# else
- result = cufftPlan2d( &forward_plan[mfi], fft_size[1], fft_size[0],
-# ifdef AMREX_USE_FLOAT
- CUFFT_R2C);
-# else
- CUFFT_D2Z);
-# endif
- if ( result != CUFFT_SUCCESS ) {
- amrex::Print() << " cufftplan2d forward failed! Error: " <<
- cufftErrorToString(result) << "\n";
- }
-
- result = cufftPlan2d( &backward_plan[mfi], fft_size[1], fft_size[0],
-# ifdef AMREX_USE_FLOAT
- CUFFT_C2R);
-# else
- CUFFT_Z2D);
-# endif
- if ( result != CUFFT_SUCCESS ) {
- amrex::Print() << " cufftplan2d backward failed! Error: " <<
- cufftErrorToString(result) << "\n";
- }
-# endif
+ forward_plan[mfi] = AnyFFT::CreatePlan(
+ fft_size, tmpRealField[mfi].dataPtr(),
+ reinterpret_cast<AnyFFT::Complex*>( tmpSpectralField[mfi].dataPtr()),
+ AnyFFT::direction::R2C, AMREX_SPACEDIM);
-#else
- // Create FFTW plans
- forward_plan[mfi] =
- // Swap dimensions: AMReX FAB are Fortran-order but FFTW is C-order
-# if (AMREX_SPACEDIM == 3)
-# ifdef AMREX_USE_FLOAT
- fftwf_plan_dft_r2c_3d( fft_size[2], fft_size[1], fft_size[0],
-# else
- fftw_plan_dft_r2c_3d( fft_size[2], fft_size[1], fft_size[0],
-# endif
-# else
-# ifdef AMREX_USE_FLOAT
- fftwf_plan_dft_r2c_2d( fft_size[1], fft_size[0],
-# else
- fftw_plan_dft_r2c_2d( fft_size[1], fft_size[0],
-# endif
-# endif
- tmpRealField[mfi].dataPtr(),
- reinterpret_cast<fftw_precision_complex*>( tmpSpectralField[mfi].dataPtr() ),
- FFTW_ESTIMATE );
- backward_plan[mfi] =
- // Swap dimensions: AMReX FAB are Fortran-order but FFTW is C-order
-# if (AMREX_SPACEDIM == 3)
-# ifdef AMREX_USE_FLOAT
- fftwf_plan_dft_c2r_3d( fft_size[2], fft_size[1], fft_size[0],
-# else
- fftw_plan_dft_c2r_3d( fft_size[2], fft_size[1], fft_size[0],
-# endif
-# else
-# ifdef AMREX_USE_FLOAT
- fftwf_plan_dft_c2r_2d( fft_size[1], fft_size[0],
-# else
- fftw_plan_dft_c2r_2d( fft_size[1], fft_size[0],
-# endif
-# endif
- reinterpret_cast<fftw_precision_complex*>( tmpSpectralField[mfi].dataPtr() ),
- tmpRealField[mfi].dataPtr(),
- FFTW_ESTIMATE );
-#endif
+ backward_plan[mfi] = AnyFFT::CreatePlan(
+ fft_size, tmpRealField[mfi].dataPtr(),
+ reinterpret_cast<AnyFFT::Complex*>( tmpSpectralField[mfi].dataPtr()),
+ AnyFFT::direction::C2R, AMREX_SPACEDIM);
}
}
@@ -182,20 +84,8 @@ SpectralFieldData::~SpectralFieldData()
{
if (tmpRealField.size() > 0){
for ( MFIter mfi(tmpRealField); mfi.isValid(); ++mfi ){
-#ifdef AMREX_USE_GPU
- // Destroy cuFFT plans
- cufftDestroy( forward_plan[mfi] );
- cufftDestroy( backward_plan[mfi] );
-#else
- // Destroy FFTW plans
-# ifdef AMREX_USE_FLOAT
- fftwf_destroy_plan( forward_plan[mfi] );
- fftwf_destroy_plan( backward_plan[mfi] );
-# else
- fftw_destroy_plan( forward_plan[mfi] );
- fftw_destroy_plan( backward_plan[mfi] );
-# endif
-#endif
+ AnyFFT::DestroyPlan(forward_plan[mfi]);
+ AnyFFT::DestroyPlan(backward_plan[mfi]);
}
}
}
@@ -243,34 +133,7 @@ SpectralFieldData::ForwardTransform( const MultiFab& mf,
}
// Perform Fourier transform from `tmpRealField` to `tmpSpectralField`
-#ifdef AMREX_USE_GPU
- // Perform Fast Fourier Transform on GPU using cuFFT
- // make sure that this is done on the same
- // GPU stream as the above copy
- cufftResult result;
- cudaStream_t stream = amrex::Gpu::Device::cudaStream();
- cufftSetStream ( forward_plan[mfi], stream);
-# ifdef AMREX_USE_FLOAT
- result = cufftExecR2C(
-# else
- result = cufftExecD2Z(
-# endif
- forward_plan[mfi],
- tmpRealField[mfi].dataPtr(),
- reinterpret_cast<cuPrecisionComplex*>(
- tmpSpectralField[mfi].dataPtr()) );
- if ( result != CUFFT_SUCCESS ) {
- amrex::Print() <<
- " forward transform using cufftExec failed ! Error: " <<
- cufftErrorToString(result) << "\n";
- }
-#else
-# ifdef AMREX_USE_FLOAT
- fftwf_execute( forward_plan[mfi] );
-# else
- fftw_execute( forward_plan[mfi] );
-# endif
-#endif
+ AnyFFT::Execute(forward_plan[mfi]);
// Copy the spectral-space field `tmpSpectralField` to the appropriate
// index of the FabArray `fields` (specified by `field_index`)
@@ -358,34 +221,7 @@ SpectralFieldData::BackwardTransform( MultiFab& mf,
}
// Perform Fourier transform from `tmpSpectralField` to `tmpRealField`
-#ifdef AMREX_USE_GPU
- // Perform Fast Fourier Transform on GPU using cuFFT.
- // make sure that this is done on the same
- // GPU stream as the above copy
- cufftResult result;
- cudaStream_t stream = amrex::Gpu::Device::cudaStream();
- cufftSetStream ( backward_plan[mfi], stream);
-# ifdef AMREX_USE_FLOAT
- result = cufftExecC2R(
-# else
- result = cufftExecZ2D(
-# endif
- backward_plan[mfi],
- reinterpret_cast<cuPrecisionComplex*>(
- tmpSpectralField[mfi].dataPtr()),
- tmpRealField[mfi].dataPtr() );
- if ( result != CUFFT_SUCCESS ) {
- amrex::Print() <<
- " Backward transform using cufftexec failed! Error: " <<
- cufftErrorToString(result) << "\n";
- }
-#else
-# ifdef AMREX_USE_FLOAT
- fftwf_execute( backward_plan[mfi] );
-# else
- fftw_execute( backward_plan[mfi] );
-# endif
-#endif
+ AnyFFT::Execute(backward_plan[mfi]);
// Copy the temporary field `tmpRealField` to the real-space field `mf`
// (only in the valid cells ; not in the guard cells)
@@ -430,30 +266,4 @@ SpectralFieldData::BackwardTransform( MultiFab& mf,
}
}
-#ifdef AMREX_USE_GPU
-std::string
-SpectralFieldData::cufftErrorToString (const cufftResult& err)
-{
- const auto res2string = std::map<cufftResult, std::string>{
- {CUFFT_SUCCESS, "CUFFT_SUCCESS"},
- {CUFFT_INVALID_PLAN,"CUFFT_INVALID_PLAN"},
- {CUFFT_ALLOC_FAILED,"CUFFT_ALLOC_FAILED"},
- {CUFFT_INVALID_TYPE,"CUFFT_INVALID_TYPE"},
- {CUFFT_INVALID_VALUE,"CUFFT_INVALID_VALUE"},
- {CUFFT_INTERNAL_ERROR,"CUFFT_INTERNAL_ERROR"},
- {CUFFT_EXEC_FAILED,"CUFFT_EXEC_FAILED"},
- {CUFFT_SETUP_FAILED,"CUFFT_SETUP_FAILED"},
- {CUFFT_INVALID_SIZE,"CUFFT_INVALID_SIZE"},
- {CUFFT_UNALIGNED_DATA,"CUFFT_UNALIGNED_DATA"}};
-
- const auto it = res2string.find(err);
- if(it != res2string.end()){
- return it->second;
- }
- else{
- return std::to_string(err) +
- " (unknown error code)";
- }
-}
-#endif
#endif // WARPX_USE_PSATD
diff --git a/Source/FieldSolver/SpectralSolver/WrapCuFFT.cpp b/Source/FieldSolver/SpectralSolver/WrapCuFFT.cpp
new file mode 100644
index 000000000..3793eebe0
--- /dev/null
+++ b/Source/FieldSolver/SpectralSolver/WrapCuFFT.cpp
@@ -0,0 +1,119 @@
+#include "AnyFFT.H"
+
+namespace AnyFFT
+{
+
+#ifdef AMREX_USE_FLOAT
+ cufftType VendorR2C = CUFFT_R2C;
+ cufftType VendorC2R = CUFFT_C2R;
+#else
+ cufftType VendorR2C = CUFFT_D2Z;
+ cufftType VendorC2R = CUFFT_Z2D;
+#endif
+
+ std::string cufftErrorToString (const cufftResult& err);
+
+ FFTplan CreatePlan(const amrex::IntVect& real_size, amrex::Real * const real_array,
+ Complex * const complex_array, const direction dir, const int dim)
+ {
+ FFTplan fft_plan;
+
+ // Initialize fft_plan.m_plan with the vendor fft plan.
+ cufftResult result;
+ if (dir == direction::R2C){
+ if (dim == 3) {
+ result = cufftPlan3d(
+ &(fft_plan.m_plan), real_size[2], real_size[1], real_size[0], VendorR2C);
+ } else if (dim == 2) {
+ result = cufftPlan2d(
+ &(fft_plan.m_plan), real_size[1], real_size[0], VendorR2C);
+ } else {
+ amrex::Abort("only dim=2 and dim=3 have been implemented");
+ }
+ } else {
+ if (dim == 3) {
+ result = cufftPlan3d(
+ &(fft_plan.m_plan), real_size[2], real_size[1], real_size[0], VendorC2R);
+ } else if (dim == 2) {
+ result = cufftPlan2d(
+ &(fft_plan.m_plan), real_size[1], real_size[0], VendorC2R);
+ } else {
+ amrex::Abort("only dim=2 and dim=3 have been implemented");
+ }
+ }
+
+ if ( result != CUFFT_SUCCESS ) {
+ amrex::Print() << " cufftplan failed! Error: " <<
+ cufftErrorToString(result) << "\n";
+ }
+
+ // Store meta-data in fft_plan
+ fft_plan.m_real_array = real_array;
+ fft_plan.m_complex_array = complex_array;
+ fft_plan.m_dir = dir;
+ fft_plan.m_dim = dim;
+
+ return fft_plan;
+ }
+
+ void DestroyPlan(FFTplan& fft_plan)
+ {
+ cufftDestroy( fft_plan.m_plan );
+ }
+
+ void Execute(FFTplan& fft_plan){
+ // make sure that this is done on the same GPU stream as the above copy
+ cudaStream_t stream = amrex::Gpu::Device::cudaStream();
+ cufftSetStream ( fft_plan.m_plan, stream);
+ cufftResult result;
+ if (fft_plan.m_dir == direction::R2C){
+#ifdef AMREX_USE_FLOAT
+ result = cufftExecR2C(fft_plan.m_plan, fft_plan.m_real_array, fft_plan.m_complex_array);
+#else
+ result = cufftExecD2Z(fft_plan.m_plan, fft_plan.m_real_array, fft_plan.m_complex_array);
+#endif
+ } else if (fft_plan.m_dir == direction::C2R){
+#ifdef AMREX_USE_FLOAT
+ result = cufftExecC2R(fft_plan.m_plan, fft_plan.m_complex_array, fft_plan.m_real_array);
+#else
+ result = cufftExecZ2D(fft_plan.m_plan, fft_plan.m_complex_array, fft_plan.m_real_array);
+#endif
+ } else {
+ amrex::Abort("direction must be AnyFFT::direction::R2C or AnyFFT::direction::C2R");
+ }
+ if ( result != CUFFT_SUCCESS ) {
+ amrex::Print() << " forward transform using cufftExec failed ! Error: " <<
+ cufftErrorToString(result) << "\n";
+ }
+ }
+
+ /** \brief This method converts a cufftResult
+ * into the corresponding string
+ *
+ * @param[in] err a cufftResult
+ * @return an std::string
+ */
+ std::string cufftErrorToString (const cufftResult& err)
+ {
+ const auto res2string = std::map<cufftResult, std::string>{
+ {CUFFT_SUCCESS, "CUFFT_SUCCESS"},
+ {CUFFT_INVALID_PLAN,"CUFFT_INVALID_PLAN"},
+ {CUFFT_ALLOC_FAILED,"CUFFT_ALLOC_FAILED"},
+ {CUFFT_INVALID_TYPE,"CUFFT_INVALID_TYPE"},
+ {CUFFT_INVALID_VALUE,"CUFFT_INVALID_VALUE"},
+ {CUFFT_INTERNAL_ERROR,"CUFFT_INTERNAL_ERROR"},
+ {CUFFT_EXEC_FAILED,"CUFFT_EXEC_FAILED"},
+ {CUFFT_SETUP_FAILED,"CUFFT_SETUP_FAILED"},
+ {CUFFT_INVALID_SIZE,"CUFFT_INVALID_SIZE"},
+ {CUFFT_UNALIGNED_DATA,"CUFFT_UNALIGNED_DATA"}};
+
+ const auto it = res2string.find(err);
+ if(it != res2string.end()){
+ return it->second;
+ }
+ else{
+ return std::to_string(err) +
+ " (unknown error code)";
+ }
+ }
+}
diff --git a/Source/FieldSolver/SpectralSolver/WrapFFTW.cpp b/Source/FieldSolver/SpectralSolver/WrapFFTW.cpp
new file mode 100644
index 000000000..2780771ef
--- /dev/null
+++ b/Source/FieldSolver/SpectralSolver/WrapFFTW.cpp
@@ -0,0 +1,71 @@
+#include "AnyFFT.H"
+
+namespace AnyFFT
+{
+#ifdef AMREX_USE_FLOAT
+ const auto VendorCreatePlanR2C3D = fftwf_plan_dft_r2c_3d;
+ const auto VendorCreatePlanC2R3D = fftwf_plan_dft_c2r_3d;
+ const auto VendorCreatePlanR2C2D = fftwf_plan_dft_r2c_2d;
+ const auto VendorCreatePlanC2R2D = fftwf_plan_dft_c2r_2d;
+#else
+ const auto VendorCreatePlanR2C3D = fftw_plan_dft_r2c_3d;
+ const auto VendorCreatePlanC2R3D = fftw_plan_dft_c2r_3d;
+ const auto VendorCreatePlanR2C2D = fftw_plan_dft_r2c_2d;
+ const auto VendorCreatePlanC2R2D = fftw_plan_dft_c2r_2d;
+#endif
+
+ FFTplan CreatePlan(const amrex::IntVect& real_size, amrex::Real * const real_array,
+ Complex * const complex_array, const direction dir, const int dim)
+ {
+ FFTplan fft_plan;
+
+ // Initialize fft_plan.m_plan with the vendor fft plan.
+ // Swap dimensions: AMReX FAB are Fortran-order but FFTW is C-order
+ if (dir == direction::R2C){
+ if (dim == 3) {
+ fft_plan.m_plan = VendorCreatePlanR2C3D(
+ real_size[2], real_size[1], real_size[0], real_array, complex_array, FFTW_ESTIMATE);
+ } else if (dim == 2) {
+ fft_plan.m_plan = VendorCreatePlanR2C2D(
+ real_size[1], real_size[0], real_array, complex_array, FFTW_ESTIMATE);
+ } else {
+ amrex::Abort("only dim=2 and dim=3 have been implemented");
+ }
+ } else if (dir == direction::C2R){
+ if (dim == 3) {
+ fft_plan.m_plan = VendorCreatePlanC2R3D(
+ real_size[2], real_size[1], real_size[0], complex_array, real_array, FFTW_ESTIMATE);
+ } else if (dim == 2) {
+ fft_plan.m_plan = VendorCreatePlanC2R2D(
+ real_size[1], real_size[0], complex_array, real_array, FFTW_ESTIMATE);
+ } else {
+ amrex::Abort("only dim=2 and dim=3 have been implemented. Should be easy to add dim=1.");
+ }
+ }
+
+ // Store meta-data in fft_plan
+ fft_plan.m_real_array = real_array;
+ fft_plan.m_complex_array = complex_array;
+ fft_plan.m_dir = dir;
+ fft_plan.m_dim = dim;
+
+ return fft_plan;
+ }
+
+ void DestroyPlan(FFTplan& fft_plan)
+ {
+# ifdef AMREX_USE_FLOAT
+ fftwf_destroy_plan( fft_plan.m_plan );
+# else
+ fftw_destroy_plan( fft_plan.m_plan );
+# endif
+ }
+
+ void Execute(FFTplan& fft_plan){
+# ifdef AMREX_USE_FLOAT
+ fftwf_execute( fft_plan.m_plan );
+# else
+ fftw_execute( fft_plan.m_plan );
+# endif
+ }
+}
diff --git a/Source/Utils/WarpX_Complex.H b/Source/Utils/WarpX_Complex.H
index dbcfd1a87..e5a7f4602 100644
--- a/Source/Utils/WarpX_Complex.H
+++ b/Source/Utils/WarpX_Complex.H
@@ -8,6 +8,10 @@
#ifndef WARPX_COMPLEX_H_
#define WARPX_COMPLEX_H_
+#ifdef WARPX_USE_PSATD
+# include "FieldSolver/SpectralSolver/AnyFFT.H"
+#endif
+
#include <AMReX_REAL.H>
#include <AMReX_Gpu.H>
#include <AMReX_GpuComplex.H>
@@ -17,32 +21,11 @@
// Defines a complex type on GPU & CPU
using Complex = amrex::GpuComplex<amrex::Real>;
-#ifdef AMREX_USE_GPU
-# ifdef WARPX_USE_PSATD
-# include <cufft.h>
-# ifdef AMREX_USE_FLOAT
-static_assert( sizeof(Complex) == sizeof(cuComplex),
- "The complex types in WarpX and cuFFT do not match.");
-# else
-static_assert( sizeof(Complex) == sizeof(cuDoubleComplex),
- "The complex types in WarpX and cuFFT do not match.");
-# endif
-# endif // WARPX_USE_PSATD
-
-#else
-
-# ifdef WARPX_USE_PSATD
-# include <fftw3.h>
-# ifdef AMREX_USE_FLOAT
-static_assert( sizeof(Complex) == sizeof(fftwf_complex),
- "The complex types in WarpX and FFTW do not match.");
-# else
-static_assert( sizeof(Complex) == sizeof(fftw_complex),
- "The complex types in WarpX and FFTW do not match.");
-# endif
-# endif // WARPX_USE_PSATD
+#ifdef WARPX_USE_PSATD
+static_assert(sizeof(Complex) == sizeof(AnyFFT::Complex),
+ "The complex type in WarpX and the FFT library do not match.");
+#endif
-#endif // AMREX_USE_GPU
static_assert(sizeof(Complex) == sizeof(amrex::Real[2]),
"Unexpected complex type.");