diff options
author | 2020-10-07 09:23:53 -0700 | |
---|---|---|
committer | 2020-10-07 09:23:53 -0700 | |
commit | 530bbda9f58a7909f3dc2d0676a680e7669a4976 (patch) | |
tree | 6f3eddf0b55e81e6e4ffc2abe3a676a61fa6db2f /Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp | |
parent | 49ed40b5610705c7f587fdad7c33349df4f7a878 (diff) | |
download | WarpX-530bbda9f58a7909f3dc2d0676a680e7669a4976.tar.gz WarpX-530bbda9f58a7909f3dc2d0676a680e7669a4976.tar.zst WarpX-530bbda9f58a7909f3dc2d0676a680e7669a4976.zip |
rocFFT support (#1410)
* rocFFT support
* rocfft in 2d rz PSATD solver
Diffstat (limited to 'Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp')
-rw-r--r-- | Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp b/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp new file mode 100644 index 000000000..54c96762a --- /dev/null +++ b/Source/FieldSolver/SpectralSolver/WrapRocFFT.cpp @@ -0,0 +1,132 @@ +/* Copyright 2019-2020 + * + * This file is part of WarpX. + * + * License: BSD-3-Clause-LBNL + */ + +#include "AnyFFT.H" + +namespace AnyFFT +{ + + std::string rocfftErrorToString (const rocfft_status err); + + namespace { + void assert_rocfft_status (std::string const& name, rocfft_status status) + { + if (status != rocfft_status_success) { + amrex::Abort(name + " failed! Error: " + rocfftErrorToString(status)); + } + } + } + + FFTplan CreatePlan (const amrex::IntVect& real_size, amrex::Real * const real_array, + Complex * const complex_array, const direction dir, const int dim) + { + FFTplan fft_plan; + + const std::size_t lengths[] = {AMREX_D_DECL(std::size_t(real_size[0]), + std::size_t(real_size[1]), + std::size_t(real_size[2]))}; + + // Initialize fft_plan.m_plan with the vendor fft plan. + rocfft_status result = rocfft_plan_create(&(fft_plan.m_plan), + rocfft_placement_notinplace, + (dir == direction::R2C) + ? rocfft_transform_type_real_forward + : rocfft_transform_type_real_inverse, +#ifdef AMREX_USE_FLOAT + rocfft_precision_single, +#else + rocfft_precision_double, +#endif + dim, lengths, + 1, // number of transforms, + nullptr); + assert_rocfft_status("rocfft_plan_create", result); + + // Store meta-data in fft_plan + fft_plan.m_real_array = real_array; + fft_plan.m_complex_array = complex_array; + fft_plan.m_dir = dir; + fft_plan.m_dim = dim; + + return fft_plan; + } + + void DestroyPlan (FFTplan& fft_plan) + { + rocfft_plan_destroy( fft_plan.m_plan ); + } + + void Execute (FFTplan& fft_plan) + { + rocfft_execution_info execinfo = NULL; + rocfft_status result = rocfft_execution_info_create(&execinfo); + assert_rocfft_status("rocfft_execution_info_create", result); + + std::size_t buffersize = 0; + result = rocfft_plan_get_work_buffer_size(fft_plan.m_plan, &buffersize); + assert_rocfft_status("rocfft_plan_get_work_buffer_size", result); + + void* buffer = amrex::The_Arena()->alloc(buffersize); + result = rocfft_execution_info_set_work_buffer(execinfo, buffer, buffersize); + assert_rocfft_status("rocfft_execution_info_set_work_buffer", result); + + result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream()); + assert_rocfft_status("rocfft_execution_info_set_stream", result); + + if (fft_plan.m_dir == direction::R2C) { + result = rocfft_execute(fft_plan.m_plan, + (void**)&(fft_plan.m_real_array), // in + (void**)&(fft_plan.m_complex_array), // out + execinfo); + } else if (fft_plan.m_dir == direction::C2R) { + result = rocfft_execute(fft_plan.m_plan, + (void**)&(fft_plan.m_complex_array), // in + (void**)&(fft_plan.m_real_array), // out + execinfo); + } else { + amrex::Abort("direction must be AnyFFT::direction::R2C or AnyFFT::direction::C2R"); + } + + assert_rocfft_status("rocfft_execute", result); + + amrex::Gpu::streamSynchronize(); + + amrex::The_Arena()->free(buffer); + + result = rocfft_execution_info_destroy(execinfo); + assert_rocfft_status("rocfft_execution_info_destroy", result); + } + + /** \brief This method converts a rocfftResult + * into the corresponding string + * + * @param[in] err a rocfftResult + * @return an std::string + */ + std::string rocfftErrorToString (const rocfft_status err) + { + if (err == rocfft_status_success) { + return std::string("rocfft_status_success"); + } else if (err == rocfft_status_failure) { + return std::string("rocfft_status_failure"); + } else if (err == rocfft_status_invalid_arg_value) { + return std::string("rocfft_status_invalid_arg_value"); + } else if (err == rocfft_status_invalid_dimensions) { + return std::string("rocfft_status_invalid_dimensions"); + } else if (err == rocfft_status_invalid_array_type) { + return std::string("rocfft_status_invalid_array_type"); + } else if (err == rocfft_status_invalid_strides) { + return std::string("rocfft_status_invalid_strides"); + } else if (err == rocfft_status_invalid_distance) { + return std::string("rocfft_status_invalid_distance"); + } else if (err == rocfft_status_invalid_offset) { + return std::string("rocfft_status_invalid_offset"); + } else { + return std::to_string(err) + " (unknown error code)"; + } + } +} |