diff options
author | 2020-10-07 09:25:00 -0700 | |
---|---|---|
committer | 2020-10-07 09:25:00 -0700 | |
commit | b7996947408d0d9c4a01e7c8a19bb9dbf905c341 (patch) | |
tree | eeb3ee5586f8dda7d77adbb61568e7a68e464eb4 /Source/FieldSolver/SpectralSolver | |
parent | 530bbda9f58a7909f3dc2d0676a680e7669a4976 (diff) | |
download | WarpX-b7996947408d0d9c4a01e7c8a19bb9dbf905c341.tar.gz WarpX-b7996947408d0d9c4a01e7c8a19bb9dbf905c341.tar.zst WarpX-b7996947408d0d9c4a01e7c8a19bb9dbf905c341.zip |
RZ + single precision fix (#1417)
* RZ_SP fix
* RZ SP fix
* Fix cufft for single precision RZ.
Co-authored-by: Weiqun Zhang <weiqunzhang@lbl.gov>
Diffstat (limited to 'Source/FieldSolver/SpectralSolver')
-rw-r--r-- | Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp index dc7f58f48..c98c79835 100644 --- a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp +++ b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp @@ -77,8 +77,13 @@ SpectralFieldDataRZ::SpectralFieldDataRZ (amrex::BoxArray const & realspace_ba, int ostride = grid_size[0]; int odist = 1; int batch = grid_size[0]; // number of ffts +# ifdef AMREX_USE_FLOAT + auto cufft_type = CUFFT_C2C; +# else + auto cufft_type = CUFFT_Z2Z; +# endif result = cufftPlanMany(&forward_plan[mfi], 1, fft_length, inembed, istride, idist, - onembed, ostride, odist, CUFFT_Z2Z, batch); + onembed, ostride, odist, cufft_type, batch); if (result != CUFFT_SUCCESS) { amrex::AllPrint() << " cufftPlanMany failed! \n"; } @@ -225,9 +230,13 @@ SpectralFieldDataRZ::FABZForwardTransform (amrex::MFIter const & mfi, amrex::Box cudaStream_t stream = amrex::Gpu::Device::cudaStream(); cufftSetStream(forward_plan[mfi], stream); for (int mode=0 ; mode < n_rz_azimuthal_modes ; mode++) { +# ifdef AMREX_USE_FLOAT + result = cufftExecC2C(forward_plan[mfi], +# else result = cufftExecZ2Z(forward_plan[mfi], - reinterpret_cast<cuDoubleComplex*>(tempHTransformed[mfi].dataPtr(mode)), // cuDoubleComplex *in - reinterpret_cast<cuDoubleComplex*>(tmpSpectralField[mfi].dataPtr(mode)), // cuDoubleComplex *out +# endif + reinterpret_cast<AnyFFT::Complex*>(tempHTransformed[mfi].dataPtr(mode)), // Complex *in + reinterpret_cast<AnyFFT::Complex*>(tmpSpectralField[mfi].dataPtr(mode)), // Complex *out CUFFT_FORWARD); if (result != CUFFT_SUCCESS) { amrex::AllPrint() << " forward transform using cufftExecZ2Z failed ! \n"; @@ -330,9 +339,13 @@ SpectralFieldDataRZ::FABZBackwardTransform (amrex::MFIter const & mfi, amrex::Bo cudaStream_t stream = amrex::Gpu::Device::cudaStream(); cufftSetStream(forward_plan[mfi], stream); for (int mode=0 ; mode < n_rz_azimuthal_modes ; mode++) { +# ifdef AMREX_USE_FLOAT + result = cufftExecC2C(forward_plan[mfi], +# else result = cufftExecZ2Z(forward_plan[mfi], - reinterpret_cast<cuDoubleComplex*>(tmpSpectralField[mfi].dataPtr(mode)), // cuDoubleComplex *in - reinterpret_cast<cuDoubleComplex*>(tempHTransformed[mfi].dataPtr(mode)), // cuDoubleComplex *out +# endif + reinterpret_cast<AnyFFT::Complex*>(tmpSpectralField[mfi].dataPtr(mode)), // Complex *in + reinterpret_cast<AnyFFT::Complex*>(tempHTransformed[mfi].dataPtr(mode)), // Complex *out CUFFT_INVERSE); if (result != CUFFT_SUCCESS) { amrex::AllPrint() << " backwardtransform using cufftExecZ2Z failed ! \n"; |