aboutsummaryrefslogtreecommitdiff
path: root/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp')
-rw-r--r--Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp132
1 files changed, 132 insertions, 0 deletions
diff --git a/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp b/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp
new file mode 100644
index 000000000..54c96762a
--- /dev/null
+++ b/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp
@@ -0,0 +1,132 @@
+/* Copyright 2019-2020
+ *
+ * This file is part of WarpX.
+ *
+ * License: BSD-3-Clause-LBNL
+ */
+
+#include "AnyFFT.H"
+
+namespace AnyFFT
+{
+
+ std::string rocfftErrorToString (const rocfft_status err);
+
+ namespace {
+ void assert_rocfft_status (std::string const& name, rocfft_status status)
+ {
+ if (status != rocfft_status_success) {
+ amrex::Abort(name + " failed! Error: " + rocfftErrorToString(status));
+ }
+ }
+ }
+
+ FFTplan CreatePlan (const amrex::IntVect& real_size, amrex::Real * const real_array,
+ Complex * const complex_array, const direction dir, const int dim)
+ {
+ FFTplan fft_plan;
+
+ const std::size_t lengths[] = {AMREX_D_DECL(std::size_t(real_size[0]),
+ std::size_t(real_size[1]),
+ std::size_t(real_size[2]))};
+
+ // Initialize fft_plan.m_plan with the vendor fft plan.
+ rocfft_status result = rocfft_plan_create(&(fft_plan.m_plan),
+ rocfft_placement_notinplace,
+ (dir == direction::R2C)
+ ? rocfft_transform_type_real_forward
+ : rocfft_transform_type_real_inverse,
+#ifdef AMREX_USE_FLOAT
+ rocfft_precision_single,
+#else
+ rocfft_precision_double,
+#endif
+ dim, lengths,
+ 1, // number of transforms,
+ nullptr);
+ assert_rocfft_status("rocfft_plan_create", result);
+
+ // Store meta-data in fft_plan
+ fft_plan.m_real_array = real_array;
+ fft_plan.m_complex_array = complex_array;
+ fft_plan.m_dir = dir;
+ fft_plan.m_dim = dim;
+
+ return fft_plan;
+ }
+
+ void DestroyPlan (FFTplan& fft_plan)
+ {
+ rocfft_plan_destroy( fft_plan.m_plan );
+ }
+
+ void Execute (FFTplan& fft_plan)
+ {
+ rocfft_execution_info execinfo = NULL;
+ rocfft_status result = rocfft_execution_info_create(&execinfo);
+ assert_rocfft_status("rocfft_execution_info_create", result);
+
+ std::size_t buffersize = 0;
+ result = rocfft_plan_get_work_buffer_size(fft_plan.m_plan, &buffersize);
+ assert_rocfft_status("rocfft_plan_get_work_buffer_size", result);
+
+ void* buffer = amrex::The_Arena()->alloc(buffersize);
+ result = rocfft_execution_info_set_work_buffer(execinfo, buffer, buffersize);
+ assert_rocfft_status("rocfft_execution_info_set_work_buffer", result);
+
+ result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());
+ assert_rocfft_status("rocfft_execution_info_set_stream", result);
+
+ if (fft_plan.m_dir == direction::R2C) {
+ result = rocfft_execute(fft_plan.m_plan,
+ (void**)&(fft_plan.m_real_array), // in
+ (void**)&(fft_plan.m_complex_array), // out
+ execinfo);
+ } else if (fft_plan.m_dir == direction::C2R) {
+ result = rocfft_execute(fft_plan.m_plan,
+ (void**)&(fft_plan.m_complex_array), // in
+ (void**)&(fft_plan.m_real_array), // out
+ execinfo);
+ } else {
+ amrex::Abort("direction must be AnyFFT::direction::R2C or AnyFFT::direction::C2R");
+ }
+
+ assert_rocfft_status("rocfft_execute", result);
+
+ amrex::Gpu::streamSynchronize();
+
+ amrex::The_Arena()->free(buffer);
+
+ result = rocfft_execution_info_destroy(execinfo);
+ assert_rocfft_status("rocfft_execution_info_destroy", result);
+ }
+
+ /** \brief This method converts a rocfftResult
+ * into the corresponding string
+ *
+ * @param[in] err a rocfftResult
+ * @return an std::string
+ */
+ std::string rocfftErrorToString (const rocfft_status err)
+ {
+ if (err == rocfft_status_success) {
+ return std::string("rocfft_status_success");
+ } else if (err == rocfft_status_failure) {
+ return std::string("rocfft_status_failure");
+ } else if (err == rocfft_status_invalid_arg_value) {
+ return std::string("rocfft_status_invalid_arg_value");
+ } else if (err == rocfft_status_invalid_dimensions) {
+ return std::string("rocfft_status_invalid_dimensions");
+ } else if (err == rocfft_status_invalid_array_type) {
+ return std::string("rocfft_status_invalid_array_type");
+ } else if (err == rocfft_status_invalid_strides) {
+ return std::string("rocfft_status_invalid_strides");
+ } else if (err == rocfft_status_invalid_distance) {
+ return std::string("rocfft_status_invalid_distance");
+ } else if (err == rocfft_status_invalid_offset) {
+ return std::string("rocfft_status_invalid_offset");
+ } else {
+ return std::to_string(err) + " (unknown error code)";
+ }
+ }
+}