diff options
Diffstat (limited to 'Source/FieldSolver')
-rw-r--r-- | Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp index dc7f58f48..c98c79835 100644 --- a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp +++ b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp @@ -77,8 +77,13 @@ SpectralFieldDataRZ::SpectralFieldDataRZ (amrex::BoxArray const & realspace_ba, int ostride = grid_size[0]; int odist = 1; int batch = grid_size[0]; // number of ffts +# ifdef AMREX_USE_FLOAT + auto cufft_type = CUFFT_C2C; +# else + auto cufft_type = CUFFT_Z2Z; +# endif result = cufftPlanMany(&forward_plan[mfi], 1, fft_length, inembed, istride, idist, - onembed, ostride, odist, CUFFT_Z2Z, batch); + onembed, ostride, odist, cufft_type, batch); if (result != CUFFT_SUCCESS) { amrex::AllPrint() << " cufftPlanMany failed! \n"; } @@ -225,9 +230,13 @@ SpectralFieldDataRZ::FABZForwardTransform (amrex::MFIter const & mfi, amrex::Box cudaStream_t stream = amrex::Gpu::Device::cudaStream(); cufftSetStream(forward_plan[mfi], stream); for (int mode=0 ; mode < n_rz_azimuthal_modes ; mode++) { +# ifdef AMREX_USE_FLOAT + result = cufftExecC2C(forward_plan[mfi], +# else result = cufftExecZ2Z(forward_plan[mfi], - reinterpret_cast<cuDoubleComplex*>(tempHTransformed[mfi].dataPtr(mode)), // cuDoubleComplex *in - reinterpret_cast<cuDoubleComplex*>(tmpSpectralField[mfi].dataPtr(mode)), // cuDoubleComplex *out +# endif + reinterpret_cast<AnyFFT::Complex*>(tempHTransformed[mfi].dataPtr(mode)), // Complex *in + reinterpret_cast<AnyFFT::Complex*>(tmpSpectralField[mfi].dataPtr(mode)), // Complex *out CUFFT_FORWARD); if (result != CUFFT_SUCCESS) { amrex::AllPrint() << " forward transform using cufftExecZ2Z failed ! \n"; @@ -330,9 +339,13 @@ SpectralFieldDataRZ::FABZBackwardTransform (amrex::MFIter const & mfi, amrex::Bo cudaStream_t stream = amrex::Gpu::Device::cudaStream(); cufftSetStream(forward_plan[mfi], stream); for (int mode=0 ; mode < n_rz_azimuthal_modes ; mode++) { +# ifdef AMREX_USE_FLOAT + result = cufftExecC2C(forward_plan[mfi], +# else result = cufftExecZ2Z(forward_plan[mfi], - reinterpret_cast<cuDoubleComplex*>(tmpSpectralField[mfi].dataPtr(mode)), // cuDoubleComplex *in - reinterpret_cast<cuDoubleComplex*>(tempHTransformed[mfi].dataPtr(mode)), // cuDoubleComplex *out +# endif + reinterpret_cast<AnyFFT::Complex*>(tmpSpectralField[mfi].dataPtr(mode)), // Complex *in + reinterpret_cast<AnyFFT::Complex*>(tempHTransformed[mfi].dataPtr(mode)), // Complex *out CUFFT_INVERSE); if (result != CUFFT_SUCCESS) { amrex::AllPrint() << " backwardtransform using cufftExecZ2Z failed ! \n"; |