aboutsummaryrefslogtreecommitdiff
path: root/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
diff options
context:
space:
mode:
authorGravatar Weiqun Zhang <WeiqunZhang@lbl.gov> 2020-10-07 09:23:53 -0700
committerGravatar GitHub <noreply@github.com> 2020-10-07 09:23:53 -0700
commit530bbda9f58a7909f3dc2d0676a680e7669a4976 (patch)
tree6f3eddf0b55e81e6e4ffc2abe3a676a61fa6db2f /Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
parent49ed40b5610705c7f587fdad7c33349df4f7a878 (diff)
downloadWarpX-530bbda9f58a7909f3dc2d0676a680e7669a4976.tar.gz
WarpX-530bbda9f58a7909f3dc2d0676a680e7669a4976.tar.zst
WarpX-530bbda9f58a7909f3dc2d0676a680e7669a4976.zip
rocFFT support (#1410)
* rocFFT support * rocfft in 2d rz PSATD solver
Diffstat (limited to 'Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp')
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp109
1 files changed, 100 insertions, 9 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
index 736c2dcfa..dc7f58f48 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
@@ -53,8 +53,8 @@ SpectralFieldDataRZ::SpectralFieldDataRZ (amrex::BoxArray const & realspace_ba,
// Allocate and initialize the FFT plans and Hankel transformer.
forward_plan = FFTplans(spectralspace_ba, dm);
-#ifndef AMREX_USE_GPU
- // The backward plan is not needed with GPU since it would be the same
+#ifndef AMREX_USE_CUDA
+ // The backward plan is not needed with CUDA since it would be the same
// as the forward plan anyway.
backward_plan = FFTplans(spectralspace_ba, dm);
#endif
@@ -64,7 +64,7 @@ SpectralFieldDataRZ::SpectralFieldDataRZ (amrex::BoxArray const & realspace_ba,
// for each box owned by the local MPI proc.
for (amrex::MFIter mfi(spectralspace_ba, dm); mfi.isValid(); ++mfi){
amrex::IntVect grid_size = realspace_ba[mfi].length();
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
// Create cuFFT plan.
// This is alway complex to complex.
// This plan is for one azimuthal mode only.
@@ -80,10 +80,56 @@ SpectralFieldDataRZ::SpectralFieldDataRZ (amrex::BoxArray const & realspace_ba,
result = cufftPlanMany(&forward_plan[mfi], 1, fft_length, inembed, istride, idist,
onembed, ostride, odist, CUFFT_Z2Z, batch);
if (result != CUFFT_SUCCESS) {
- amrex::Print() << " cufftPlanMany failed! \n";
+ amrex::AllPrint() << " cufftPlanMany failed! \n";
}
// The backward plane is the same as the forward since the direction is passed when executed.
+#elif defined(AMREX_USE_HIP)
+ const std::size_t fft_length[] = {static_cast<std::size_t>(grid_size[1])};
+ const std::size_t stride[] = {static_cast<std::size_t>(grid_size[0])};
+ rocfft_plan_description description;
+ rocfft_status result;
+ result = rocfft_plan_description_create(&description);
+ result = rocfft_plan_description_set_data_layout(description,
+ rocfft_array_type_complex_interleaved,
+ rocfft_array_type_complex_interleaved,
+ nullptr, nullptr,
+ 1, stride, 1,
+ 1, stride, 1);
+
+ result = rocfft_plan_create(&(forward_plan[mfi]),
+ rocfft_placement_notinplace,
+ rocfft_transform_type_complex_forward,
+#ifdef AMREX_USE_FLOAT
+ rocfft_precision_single,
+#else
+ rocfft_precision_double,
+#endif
+ 1, fft_length,
+ grid_size[0], // number of transforms
+ description);
+ if (result != rocfft_status_success) {
+ amrex::AllPrint() << " rocfft_plan_create failed! \n";
+ }
+ result = rocfft_plan_create(&(backward_plan[mfi]),
+ rocfft_placement_notinplace,
+ rocfft_transform_type_complex_inverse,
+#ifdef AMREX_USE_FLOAT
+ rocfft_precision_single,
+#else
+ rocfft_precision_double,
+#endif
+ 1, fft_length,
+ grid_size[0], // number of transforms
+ description);
+ if (result != rocfft_status_success) {
+ amrex::AllPrint() << " rocfft_plan_create failed! \n";
+ }
+
+ result = rocfft_plan_description_destroy(description);
+ if (result != rocfft_status_success) {
+ amrex::AllPrint() << " rocfft_plan_description_destroy failed! \n";
+ }
#else
// Create FFTW plans.
fftw_iodim dims[1];
@@ -129,10 +175,13 @@ SpectralFieldDataRZ::~SpectralFieldDataRZ()
{
if (fields.size() > 0){
for (amrex::MFIter mfi(fields); mfi.isValid(); ++mfi){
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
// Destroy cuFFT plans.
cufftDestroy(forward_plan[mfi]);
// cufftDestroy(backward_plan[mfi]); // This was never allocated.
+#elif defined(AMREX_USE_HIP)
+ rocfft_plan_destroy(forward_plan[mfi]);
+ rocfft_plan_destroy(backward_plan[mfi]);
#else
// Destroy FFTW plans.
fftw_destroy_plan(forward_plan[mfi]);
@@ -168,7 +217,7 @@ SpectralFieldDataRZ::FABZForwardTransform (amrex::MFIter const & mfi, amrex::Box
});
// Perform Fourier transform from `tempHTransformed` to `tmpSpectralField`.
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
// Perform Fast Fourier Transform on GPU using cuFFT.
// Make sure that this is done on the same
// GPU stream as the above copy.
@@ -181,9 +230,30 @@ SpectralFieldDataRZ::FABZForwardTransform (amrex::MFIter const & mfi, amrex::Box
reinterpret_cast<cuDoubleComplex*>(tmpSpectralField[mfi].dataPtr(mode)), // cuDoubleComplex *out
CUFFT_FORWARD);
if (result != CUFFT_SUCCESS) {
- amrex::Print() << " forward transform using cufftExecZ2Z failed ! \n";
+ amrex::AllPrint() << " forward transform using cufftExecZ2Z failed ! \n";
+ }
+ }
+#elif defined(AMREX_USE_HIP)
+ rocfft_execution_info execinfo = NULL;
+ rocfft_status result = rocfft_execution_info_create(&execinfo);
+ std::size_t buffersize = 0;
+ result = rocfft_plan_get_work_buffer_size(forward_plan[mfi], &buffersize);
+ void* buffer = amrex::The_Arena()->alloc(buffersize);
+ result = rocfft_execution_info_set_work_buffer(execinfo, buffer, buffersize);
+ result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());
+
+ for (int mode=0 ; mode < n_rz_azimuthal_modes ; mode++) {
+ void* in_array[] = {(void*)(tempHTransformed[mfi].dataPtr(mode))};
+ void* out_array[] = {(void*)(tmpSpectralField[mfi].dataPtr(mode))};
+ result = rocfft_execute(forward_plan[mfi], in_array, out_array, execinfo);
+ if (result != rocfft_status_success) {
+ amrex::AllPrint() << " forward transform using rocfft_execute failed ! \n";
}
}
+
+ amrex::Gpu::streamSynchronize();
+ amrex::The_Arena()->free(buffer);
+ result = rocfft_execution_info_destroy(execinfo);
#else
fftw_execute(forward_plan[mfi]);
#endif
@@ -252,7 +322,7 @@ SpectralFieldDataRZ::FABZBackwardTransform (amrex::MFIter const & mfi, amrex::Bo
});
// Perform Fourier transform from `tmpSpectralField` to `tempHTransformed`.
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
// Perform Fast Fourier Transform on GPU using cuFFT.
// Make sure that this is done on the same
// GPU stream as the above copy.
@@ -265,9 +335,30 @@ SpectralFieldDataRZ::FABZBackwardTransform (amrex::MFIter const & mfi, amrex::Bo
reinterpret_cast<cuDoubleComplex*>(tempHTransformed[mfi].dataPtr(mode)), // cuDoubleComplex *out
CUFFT_INVERSE);
if (result != CUFFT_SUCCESS) {
- amrex::Print() << " backwardtransform using cufftExecZ2Z failed ! \n";
+ amrex::AllPrint() << " backwardtransform using cufftExecZ2Z failed ! \n";
}
}
+#elif defined(AMREX_USE_HIP)
+ rocfft_execution_info execinfo = NULL;
+ rocfft_status result = rocfft_execution_info_create(&execinfo);
+ std::size_t buffersize = 0;
+ result = rocfft_plan_get_work_buffer_size(forward_plan[mfi], &buffersize);
+ void* buffer = amrex::The_Arena()->alloc(buffersize);
+ result = rocfft_execution_info_set_work_buffer(execinfo, buffer, buffersize);
+ result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());
+
+ for (int mode=0 ; mode < n_rz_azimuthal_modes ; mode++) {
+ void* in_array[] = {(void*)(tmpSpectralField[mfi].dataPtr(mode))};
+ void* out_array[] = {(void*)(tempHTransformed[mfi].dataPtr(mode))};
+ result = rocfft_execute(backward_plan[mfi], in_array, out_array, execinfo);
+ if (result != rocfft_status_success) {
+ amrex::AllPrint() << " forward transform using rocfft_execute failed ! \n";
+ }
+ }
+
+ amrex::Gpu::streamSynchronize();
+ amrex::The_Arena()->free(buffer);
+ result = rocfft_execution_info_destroy(execinfo);
#else
fftw_execute(backward_plan[mfi]);
#endif