aboutsummaryrefslogtreecommitdiff
path: root/Source/FieldSolver
diff options
context:
space:
mode:
Diffstat (limited to 'Source/FieldSolver')
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp23
1 files changed, 18 insertions, 5 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
index dc7f58f48..c98c79835 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
@@ -77,8 +77,13 @@ SpectralFieldDataRZ::SpectralFieldDataRZ (amrex::BoxArray const & realspace_ba,
int ostride = grid_size[0];
int odist = 1;
int batch = grid_size[0]; // number of ffts
+# ifdef AMREX_USE_FLOAT
+ auto cufft_type = CUFFT_C2C;
+# else
+ auto cufft_type = CUFFT_Z2Z;
+# endif
result = cufftPlanMany(&forward_plan[mfi], 1, fft_length, inembed, istride, idist,
- onembed, ostride, odist, CUFFT_Z2Z, batch);
+ onembed, ostride, odist, cufft_type, batch);
if (result != CUFFT_SUCCESS) {
amrex::AllPrint() << " cufftPlanMany failed! \n";
}
@@ -225,9 +230,13 @@ SpectralFieldDataRZ::FABZForwardTransform (amrex::MFIter const & mfi, amrex::Box
cudaStream_t stream = amrex::Gpu::Device::cudaStream();
cufftSetStream(forward_plan[mfi], stream);
for (int mode=0 ; mode < n_rz_azimuthal_modes ; mode++) {
+# ifdef AMREX_USE_FLOAT
+ result = cufftExecC2C(forward_plan[mfi],
+# else
result = cufftExecZ2Z(forward_plan[mfi],
- reinterpret_cast<cuDoubleComplex*>(tempHTransformed[mfi].dataPtr(mode)), // cuDoubleComplex *in
- reinterpret_cast<cuDoubleComplex*>(tmpSpectralField[mfi].dataPtr(mode)), // cuDoubleComplex *out
+# endif
+ reinterpret_cast<AnyFFT::Complex*>(tempHTransformed[mfi].dataPtr(mode)), // Complex *in
+ reinterpret_cast<AnyFFT::Complex*>(tmpSpectralField[mfi].dataPtr(mode)), // Complex *out
CUFFT_FORWARD);
if (result != CUFFT_SUCCESS) {
amrex::AllPrint() << " forward transform using cufftExecZ2Z failed ! \n";
@@ -330,9 +339,13 @@ SpectralFieldDataRZ::FABZBackwardTransform (amrex::MFIter const & mfi, amrex::Bo
cudaStream_t stream = amrex::Gpu::Device::cudaStream();
cufftSetStream(forward_plan[mfi], stream);
for (int mode=0 ; mode < n_rz_azimuthal_modes ; mode++) {
+# ifdef AMREX_USE_FLOAT
+ result = cufftExecC2C(forward_plan[mfi],
+# else
result = cufftExecZ2Z(forward_plan[mfi],
- reinterpret_cast<cuDoubleComplex*>(tmpSpectralField[mfi].dataPtr(mode)), // cuDoubleComplex *in
- reinterpret_cast<cuDoubleComplex*>(tempHTransformed[mfi].dataPtr(mode)), // cuDoubleComplex *out
+# endif
+ reinterpret_cast<AnyFFT::Complex*>(tmpSpectralField[mfi].dataPtr(mode)), // Complex *in
+ reinterpret_cast<AnyFFT::Complex*>(tempHTransformed[mfi].dataPtr(mode)), // Complex *out
CUFFT_INVERSE);
if (result != CUFFT_SUCCESS) {
amrex::AllPrint() << " backwardtransform using cufftExecZ2Z failed ! \n";