aboutsummaryrefslogtreecommitdiff
path: root/Source/FieldSolver/SpectralSolver
diff options
context:
space:
mode:
author: Weiqun Zhang <WeiqunZhang@lbl.gov> 2020-10-07 09:23:53 -0700
committer: GitHub <noreply@github.com> 2020-10-07 09:23:53 -0700
commit: 530bbda9f58a7909f3dc2d0676a680e7669a4976 (patch)
tree: 6f3eddf0b55e81e6e4ffc2abe3a676a61fa6db2f /Source/FieldSolver/SpectralSolver
parent: 49ed40b5610705c7f587fdad7c33349df4f7a878 (diff)
download: WarpX-530bbda9f58a7909f3dc2d0676a680e7669a4976.tar.gz
WarpX-530bbda9f58a7909f3dc2d0676a680e7669a4976.tar.zst
WarpX-530bbda9f58a7909f3dc2d0676a680e7669a4976.zip
rocFFT support (#1410)
* rocFFT support * rocfft in 2d rz PSATD solver
Diffstat (limited to 'Source/FieldSolver/SpectralSolver')
-rw-r--r-- Source/FieldSolver/SpectralSolver/AnyFFT.H | 16
-rw-r--r-- Source/FieldSolver/SpectralSolver/Make.package | 2
-rw-r--r-- Source/FieldSolver/SpectralSolver/SpectralFieldData.H | 10
-rw-r--r-- Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H | 4
-rw-r--r-- Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp | 109
-rw-r--r-- Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp | 132
6 files changed, 250 insertions, 23 deletions
diff --git a/Source/FieldSolver/SpectralSolver/AnyFFT.H b/Source/FieldSolver/SpectralSolver/AnyFFT.H
index fbdc1520f..1a551c167 100644
--- a/Source/FieldSolver/SpectralSolver/AnyFFT.H
+++ b/Source/FieldSolver/SpectralSolver/AnyFFT.H
@@ -8,8 +8,10 @@
#ifndef ANYFFT_H_
#define ANYFFT_H_
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
# include <cufft.h>
+#elif defined(AMREX_USE_HIP)
+# include <rocfft.h>
#else
# include <fftw3.h>
#endif
@@ -27,12 +29,18 @@ namespace AnyFFT
// First, define library-dependent types (complex, FFT plan)
/** Complex type for FFT, depends on FFT library */
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
# ifdef AMREX_USE_FLOAT
using Complex = cuComplex;
# else
using Complex = cuDoubleComplex;
# endif
+#elif defined(AMREX_USE_HIP)
+# ifdef AMREX_USE_FLOAT
+ using Complex = float2;
+# else
+ using Complex = double2;
+# endif
#else
# ifdef AMREX_USE_FLOAT
using Complex = fftwf_complex;
@@ -44,8 +52,10 @@ namespace AnyFFT
/** Library-dependent FFT plans type, which holds one fft plan per box
* (plans are only initialized for the boxes that are owned by the local MPI rank).
*/
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
using VendorFFTPlan = cufftHandle;
+#elif defined(AMREX_USE_HIP)
+ using VendorFFTPlan = rocfft_plan;
#else
# ifdef AMREX_USE_FLOAT
using VendorFFTPlan = fftwf_plan;
diff --git a/Source/FieldSolver/SpectralSolver/Make.package b/Source/FieldSolver/SpectralSolver/Make.package
index 04ae352be..8be3a6812 100644
--- a/Source/FieldSolver/SpectralSolver/Make.package
+++ b/Source/FieldSolver/SpectralSolver/Make.package
@@ -3,6 +3,8 @@ CEXE_sources += SpectralFieldData.cpp
CEXE_sources += SpectralKSpace.cpp
ifeq ($(USE_CUDA),TRUE)
CEXE_sources += WrapCuFFT.cpp
+else ifeq ($(USE_HIP),TRUE)
+ CEXE_sources += WrapRocFFT.cpp
else
CEXE_sources += WrapFFTW.cpp
endif
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldData.H b/Source/FieldSolver/SpectralSolver/SpectralFieldData.H
index f48272744..4990f9926 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldData.H
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldData.H
@@ -83,16 +83,6 @@ class SpectralFieldData
SpectralShiftFactor yshift_FFTfromCell, yshift_FFTtoCell;
#endif
-#ifdef AMREX_USE_GPU
- /** \brief This method converts a cufftResult
- * into the corresponding string
- *
- * @param[in] err a cufftResult
- * @return an std::string
- */
- std::string cufftErrorToString (const cufftResult& err);
-#endif
-
bool m_periodic_single_box;
};
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H
index b89c0106e..0dff6da7d 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.H
@@ -24,8 +24,10 @@ class SpectralFieldDataRZ
// Define the FFTplans type, which holds one fft plan per box
// (plans are only initialized for the boxes that are owned by
// the local MPI rank)
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
using FFTplans = amrex::LayoutData<cufftHandle>;
+#elif defined(AMREX_USE_HIP)
+ using FFTplans = amrex::LayoutData<rocfft_plan>;
#else
using FFTplans = amrex::LayoutData<fftw_plan>;
#endif
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
index 736c2dcfa..dc7f58f48 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldDataRZ.cpp
@@ -53,8 +53,8 @@ SpectralFieldDataRZ::SpectralFieldDataRZ (amrex::BoxArray const & realspace_ba,
// Allocate and initialize the FFT plans and Hankel transformer.
forward_plan = FFTplans(spectralspace_ba, dm);
-#ifndef AMREX_USE_GPU
- // The backward plan is not needed with GPU since it would be the same
+#ifndef AMREX_USE_CUDA
+ // The backward plan is not needed with CUDA since it would be the same
// as the forward plan anyway.
backward_plan = FFTplans(spectralspace_ba, dm);
#endif
@@ -64,7 +64,7 @@ SpectralFieldDataRZ::SpectralFieldDataRZ (amrex::BoxArray const & realspace_ba,
// for each box owned by the local MPI proc.
for (amrex::MFIter mfi(spectralspace_ba, dm); mfi.isValid(); ++mfi){
amrex::IntVect grid_size = realspace_ba[mfi].length();
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
// Create cuFFT plan.
// This is always complex to complex.
// This plan is for one azimuthal mode only.
@@ -80,10 +80,56 @@ SpectralFieldDataRZ::SpectralFieldDataRZ (amrex::BoxArray const & realspace_ba,
result = cufftPlanMany(&forward_plan[mfi], 1, fft_length, inembed, istride, idist,
onembed, ostride, odist, CUFFT_Z2Z, batch);
if (result != CUFFT_SUCCESS) {
- amrex::Print() << " cufftPlanMany failed! \n";
+ amrex::AllPrint() << " cufftPlanMany failed! \n";
}
// The backward plan is the same as the forward since the direction is passed when executed.
+#elif defined(AMREX_USE_HIP)
+ // Create rocFFT plans.
+ // These are complex to complex, out of place, one forward and one inverse.
+ // Each plan runs grid_size[0] transforms of length grid_size[1]; the data
+ // layout below uses an element stride of grid_size[0] and a distance of 1
+ // between consecutive transforms, for both input and output.
+ const std::size_t fft_length[] = {static_cast<std::size_t>(grid_size[1])};
+ const std::size_t stride[] = {static_cast<std::size_t>(grid_size[0])};
+ rocfft_plan_description description = nullptr;
+ rocfft_status result;
+ result = rocfft_plan_description_create(&description);
+ // Report a failed description setup instead of silently overwriting the status.
+ if (result != rocfft_status_success) {
+     amrex::AllPrint() << " rocfft_plan_description_create failed! \n";
+ }
+ result = rocfft_plan_description_set_data_layout(description,
+                                                  rocfft_array_type_complex_interleaved,
+                                                  rocfft_array_type_complex_interleaved,
+                                                  nullptr, nullptr,
+                                                  1, stride, 1,
+                                                  1, stride, 1);
+ if (result != rocfft_status_success) {
+     amrex::AllPrint() << " rocfft_plan_description_set_data_layout failed! \n";
+ }
+
+ // Forward plan.
+ result = rocfft_plan_create(&(forward_plan[mfi]),
+                             rocfft_placement_notinplace,
+                             rocfft_transform_type_complex_forward,
+#ifdef AMREX_USE_FLOAT
+                             rocfft_precision_single,
+#else
+                             rocfft_precision_double,
+#endif
+                             1, fft_length,
+                             grid_size[0], // number of transforms
+                             description);
+ if (result != rocfft_status_success) {
+     amrex::AllPrint() << " rocfft_plan_create failed! \n";
+ }
+ // Backward plan: unlike the cuFFT branch, the direction is fixed at plan
+ // creation here, so a separate inverse plan is created.
+ result = rocfft_plan_create(&(backward_plan[mfi]),
+                             rocfft_placement_notinplace,
+                             rocfft_transform_type_complex_inverse,
+#ifdef AMREX_USE_FLOAT
+                             rocfft_precision_single,
+#else
+                             rocfft_precision_double,
+#endif
+                             1, fft_length,
+                             grid_size[0], // number of transforms
+                             description);
+ if (result != rocfft_status_success) {
+     amrex::AllPrint() << " rocfft_plan_create failed! \n";
+ }
+
+ // The description is only needed while the plans are being created.
+ result = rocfft_plan_description_destroy(description);
+ if (result != rocfft_status_success) {
+     amrex::AllPrint() << " rocfft_plan_description_destroy failed! \n";
+ }
#else
// Create FFTW plans.
fftw_iodim dims[1];
@@ -129,10 +175,13 @@ SpectralFieldDataRZ::~SpectralFieldDataRZ()
{
if (fields.size() > 0){
for (amrex::MFIter mfi(fields); mfi.isValid(); ++mfi){
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
// Destroy cuFFT plans.
cufftDestroy(forward_plan[mfi]);
// cufftDestroy(backward_plan[mfi]); // This was never allocated.
+#elif defined(AMREX_USE_HIP)
+ rocfft_plan_destroy(forward_plan[mfi]);
+ rocfft_plan_destroy(backward_plan[mfi]);
#else
// Destroy FFTW plans.
fftw_destroy_plan(forward_plan[mfi]);
@@ -168,7 +217,7 @@ SpectralFieldDataRZ::FABZForwardTransform (amrex::MFIter const & mfi, amrex::Box
});
// Perform Fourier transform from `tempHTransformed` to `tmpSpectralField`.
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
// Perform Fast Fourier Transform on GPU using cuFFT.
// Make sure that this is done on the same
// GPU stream as the above copy.
@@ -181,9 +230,30 @@ SpectralFieldDataRZ::FABZForwardTransform (amrex::MFIter const & mfi, amrex::Box
reinterpret_cast<cuDoubleComplex*>(tmpSpectralField[mfi].dataPtr(mode)), // cuDoubleComplex *out
CUFFT_FORWARD);
if (result != CUFFT_SUCCESS) {
- amrex::Print() << " forward transform using cufftExecZ2Z failed ! \n";
+ amrex::AllPrint() << " forward transform using cufftExecZ2Z failed ! \n";
+ }
+ }
+#elif defined(AMREX_USE_HIP)
+ // Perform the forward Fourier transform on GPU using rocFFT.
+ // An execution-info object carries the work buffer and the GPU stream
+ // that rocfft_execute uses for every azimuthal mode below.
+ // NOTE(review): the statuses returned by the setup calls are stored in
+ // `result` but never checked; only the rocfft_execute status is reported.
+ rocfft_execution_info execinfo = NULL;
+ rocfft_status result = rocfft_execution_info_create(&execinfo);
+ std::size_t buffersize = 0;
+ result = rocfft_plan_get_work_buffer_size(forward_plan[mfi], &buffersize);
+ void* buffer = amrex::The_Arena()->alloc(buffersize);
+ result = rocfft_execution_info_set_work_buffer(execinfo, buffer, buffersize);
+ result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());
+
+ // One transform per azimuthal mode, from `tempHTransformed` to `tmpSpectralField`.
+ for (int mode=0 ; mode < n_rz_azimuthal_modes ; mode++) {
+ void* in_array[] = {(void*)(tempHTransformed[mfi].dataPtr(mode))};
+ void* out_array[] = {(void*)(tmpSpectralField[mfi].dataPtr(mode))};
+ result = rocfft_execute(forward_plan[mfi], in_array, out_array, execinfo);
+ if (result != rocfft_status_success) {
+ amrex::AllPrint() << " forward transform using rocfft_execute failed ! \n";
}
}
+
+ // Synchronize the stream before releasing the work buffer; presumably so
+ // the buffer is not freed while transforms may still be in flight.
+ amrex::Gpu::streamSynchronize();
+ amrex::The_Arena()->free(buffer);
+ result = rocfft_execution_info_destroy(execinfo);
#else
fftw_execute(forward_plan[mfi]);
#endif
@@ -252,7 +322,7 @@ SpectralFieldDataRZ::FABZBackwardTransform (amrex::MFIter const & mfi, amrex::Bo
});
// Perform Fourier transform from `tmpSpectralField` to `tempHTransformed`.
-#ifdef AMREX_USE_GPU
+#if defined(AMREX_USE_CUDA)
// Perform Fast Fourier Transform on GPU using cuFFT.
// Make sure that this is done on the same
// GPU stream as the above copy.
@@ -265,9 +335,30 @@ SpectralFieldDataRZ::FABZBackwardTransform (amrex::MFIter const & mfi, amrex::Bo
reinterpret_cast<cuDoubleComplex*>(tempHTransformed[mfi].dataPtr(mode)), // cuDoubleComplex *out
CUFFT_INVERSE);
if (result != CUFFT_SUCCESS) {
- amrex::Print() << " backwardtransform using cufftExecZ2Z failed ! \n";
+ amrex::AllPrint() << " backwardtransform using cufftExecZ2Z failed ! \n";
}
}
+#elif defined(AMREX_USE_HIP)
+ // Perform the backward Fourier transform on GPU using rocFFT.
+ // An execution-info object carries the work buffer and the GPU stream
+ // that rocfft_execute uses for every azimuthal mode below.
+ // NOTE(review): the statuses returned by the setup calls are stored in
+ // `result` but never checked; only the rocfft_execute status is reported.
+ rocfft_execution_info execinfo = nullptr;
+ rocfft_status result = rocfft_execution_info_create(&execinfo);
+ std::size_t buffersize = 0;
+ // Size the work buffer from the plan that is actually executed below
+ // (the backward plan, not the forward one).
+ result = rocfft_plan_get_work_buffer_size(backward_plan[mfi], &buffersize);
+ void* buffer = amrex::The_Arena()->alloc(buffersize);
+ result = rocfft_execution_info_set_work_buffer(execinfo, buffer, buffersize);
+ result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());
+
+ // One transform per azimuthal mode, from `tmpSpectralField` to `tempHTransformed`.
+ for (int mode=0 ; mode < n_rz_azimuthal_modes ; mode++) {
+     void* in_array[] = {(void*)(tmpSpectralField[mfi].dataPtr(mode))};
+     void* out_array[] = {(void*)(tempHTransformed[mfi].dataPtr(mode))};
+     result = rocfft_execute(backward_plan[mfi], in_array, out_array, execinfo);
+     if (result != rocfft_status_success) {
+         amrex::AllPrint() << " backward transform using rocfft_execute failed ! \n";
+     }
+ }
+
+ // Synchronize the stream before releasing the work buffer; presumably so
+ // the buffer is not freed while transforms may still be in flight.
+ amrex::Gpu::streamSynchronize();
+ amrex::The_Arena()->free(buffer);
+ result = rocfft_execution_info_destroy(execinfo);
#else
fftw_execute(backward_plan[mfi]);
#endif
diff --git a/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp b/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp
new file mode 100644
index 000000000..54c96762a
--- /dev/null
+++ b/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp
@@ -0,0 +1,132 @@
+/* Copyright 2019-2020
+ *
+ * This file is part of WarpX.
+ *
+ * License: BSD-3-Clause-LBNL
+ */
+
+#include "AnyFFT.H"
+
+namespace AnyFFT
+{
+
+ std::string rocfftErrorToString (const rocfft_status err);
+
+ namespace {
+     /** \brief Abort the run when a rocFFT call did not succeed
+      *
+      * @param[in] name label of the rocFFT call, used as prefix of the abort message
+      * @param[in] status status code returned by that call
+      */
+     void assert_rocfft_status (std::string const& name, rocfft_status status)
+     {
+         // Nothing to do on success.
+         if (status == rocfft_status_success) { return; }
+
+         amrex::Abort(name + " failed! Error: " + rocfftErrorToString(status));
+     }
+ }
+
+ /** \brief Create an out-of-place rocFFT plan for a single real <-> complex transform
+  *
+  * @param[in] real_size extent of the real-space array, one entry per dimension
+  * @param[in] real_array pointer to the real data, stored in the returned plan
+  * @param[in] complex_array pointer to the complex data, stored in the returned plan
+  * @param[in] dir direction::R2C selects the real forward transform,
+  *                any other value the real inverse transform
+  * @param[in] dim number of dimensions handed to rocFFT
+  * @return the plan, with the vendor plan and the meta-data filled in
+  */
+ FFTplan CreatePlan (const amrex::IntVect& real_size, amrex::Real * const real_array,
+                     Complex * const complex_array, const direction dir, const int dim)
+ {
+     FFTplan fft_plan;
+
+     // Transform lengths, one per space dimension.
+     const std::size_t lengths[] = {AMREX_D_DECL(std::size_t(real_size[0]),
+                                                 std::size_t(real_size[1]),
+                                                 std::size_t(real_size[2]))};
+
+     // Transform type follows the requested direction.
+     const auto transform_type = (dir == direction::R2C)
+         ? rocfft_transform_type_real_forward
+         : rocfft_transform_type_real_inverse;
+
+     // Precision follows the precision AMReX was built with.
+#ifdef AMREX_USE_FLOAT
+     const auto precision = rocfft_precision_single;
+#else
+     const auto precision = rocfft_precision_double;
+#endif
+
+     // Initialize fft_plan.m_plan with the vendor fft plan.
+     const rocfft_status result = rocfft_plan_create(&(fft_plan.m_plan),
+                                                     rocfft_placement_notinplace,
+                                                     transform_type,
+                                                     precision,
+                                                     dim, lengths,
+                                                     1, // number of transforms,
+                                                     nullptr);
+     assert_rocfft_status("rocfft_plan_create", result);
+
+     // Store meta-data in fft_plan
+     fft_plan.m_real_array = real_array;
+     fft_plan.m_complex_array = complex_array;
+     fft_plan.m_dir = dir;
+     fft_plan.m_dim = dim;
+
+     return fft_plan;
+ }
+
+ /** \brief Release the vendor (rocFFT) plan held by an FFT plan
+  *
+  * The status returned by rocfft_plan_destroy is not inspected.
+  *
+  * @param[in,out] fft_plan plan whose m_plan handle is destroyed
+  */
+ void DestroyPlan (FFTplan& fft_plan)
+ {
+     auto& vendor_plan = fft_plan.m_plan;
+     rocfft_plan_destroy(vendor_plan);
+ }
+
+ /** \brief Run the transform described by an FFT plan
+  *
+  * Creates a temporary execution-info object with a work buffer from
+  * amrex::The_Arena() and the current AMReX GPU stream, executes the plan
+  * in the direction stored in it, synchronizes the stream and releases
+  * the temporaries. Any failing rocFFT call aborts the run.
+  *
+  * @param[in,out] fft_plan plan to execute; its real and complex arrays are
+  *                the input/output of the transform, depending on m_dir
+  */
+ void Execute (FFTplan& fft_plan)
+ {
+     rocfft_execution_info exec_info = nullptr;
+     rocfft_status status = rocfft_execution_info_create(&exec_info);
+     assert_rocfft_status("rocfft_execution_info_create", status);
+
+     // Work buffer of the size requested by the plan.
+     std::size_t work_bytes = 0;
+     status = rocfft_plan_get_work_buffer_size(fft_plan.m_plan, &work_bytes);
+     assert_rocfft_status("rocfft_plan_get_work_buffer_size", status);
+
+     void* const work_buffer = amrex::The_Arena()->alloc(work_bytes);
+     status = rocfft_execution_info_set_work_buffer(exec_info, work_buffer, work_bytes);
+     assert_rocfft_status("rocfft_execution_info_set_work_buffer", status);
+
+     status = rocfft_execution_info_set_stream(exec_info, amrex::Gpu::gpuStream());
+     assert_rocfft_status("rocfft_execution_info_set_stream", status);
+
+     // rocfft_execute takes arrays of buffer pointers; each side has one buffer.
+     void** const real_side = reinterpret_cast<void**>(&(fft_plan.m_real_array));
+     void** const complex_side = reinterpret_cast<void**>(&(fft_plan.m_complex_array));
+
+     const bool real_to_complex = (fft_plan.m_dir == direction::R2C);
+     const bool complex_to_real = (fft_plan.m_dir == direction::C2R);
+     if (real_to_complex || complex_to_real) {
+         void** const input  = real_to_complex ? real_side : complex_side;
+         void** const output = real_to_complex ? complex_side : real_side;
+         status = rocfft_execute(fft_plan.m_plan, input, output, exec_info);
+     } else {
+         amrex::Abort("direction must be AnyFFT::direction::R2C or AnyFFT::direction::C2R");
+     }
+
+     assert_rocfft_status("rocfft_execute", status);
+
+     // Synchronize the stream before the work buffer is released.
+     amrex::Gpu::streamSynchronize();
+
+     amrex::The_Arena()->free(work_buffer);
+
+     status = rocfft_execution_info_destroy(exec_info);
+     assert_rocfft_status("rocfft_execution_info_destroy", status);
+ }
+
+ /** \brief Convert a rocfft_status code into its name
+  *
+  * @param[in] err a rocfft_status
+  * @return the enumerator name as an std::string, or the numeric value
+  *         followed by " (unknown error code)" for any other value
+  */
+ std::string rocfftErrorToString (const rocfft_status err)
+ {
+     switch (err) {
+         case rocfft_status_success:
+             return std::string("rocfft_status_success");
+         case rocfft_status_failure:
+             return std::string("rocfft_status_failure");
+         case rocfft_status_invalid_arg_value:
+             return std::string("rocfft_status_invalid_arg_value");
+         case rocfft_status_invalid_dimensions:
+             return std::string("rocfft_status_invalid_dimensions");
+         case rocfft_status_invalid_array_type:
+             return std::string("rocfft_status_invalid_array_type");
+         case rocfft_status_invalid_strides:
+             return std::string("rocfft_status_invalid_strides");
+         case rocfft_status_invalid_distance:
+             return std::string("rocfft_status_invalid_distance");
+         case rocfft_status_invalid_offset:
+             return std::string("rocfft_status_invalid_offset");
+         default:
+             return std::to_string(err) + " (unknown error code)";
+     }
+ }
+}