aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp4
-rw-r--r--Source/FieldSolver/SpectralSolver/WrapFFTW.cpp12
-rw-r--r--cmake/dependencies/FFT.cmake105
3 files changed, 90 insertions, 31 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
index bdb631063..056c030c0 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
@@ -145,6 +145,8 @@ SpectralFieldData::ForwardTransform (const int lev,
#endif
// Loop over boxes
+ // Note: we do NOT OpenMP parallelize here, since we use OpenMP threads for
+ // the FFTs on each box!
for ( MFIter mfi(mf); mfi.isValid(); ++mfi ){
if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
{
@@ -247,6 +249,8 @@ SpectralFieldData::BackwardTransform( const int lev,
#endif
// Loop over boxes
+ // Note: we do NOT OpenMP parallelize here, since we use OpenMP threads for
+ // the iFFTs on each box!
for ( MFIter mfi(mf); mfi.isValid(); ++mfi ){
if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
{
diff --git a/Source/FieldSolver/SpectralSolver/WrapFFTW.cpp b/Source/FieldSolver/SpectralSolver/WrapFFTW.cpp
index a4dfc8b29..57a0fad04 100644
--- a/Source/FieldSolver/SpectralSolver/WrapFFTW.cpp
+++ b/Source/FieldSolver/SpectralSolver/WrapFFTW.cpp
@@ -1,4 +1,4 @@
-/* Copyright 2019-2020
+/* Copyright 2019-2021
*
* This file is part of WarpX.
*
@@ -32,6 +32,16 @@ namespace AnyFFT
{
FFTplan fft_plan;
+#if defined(AMREX_USE_OMP) && defined(WarpX_FFTW_OMP)
+# ifdef AMREX_USE_FLOAT
+ fftwf_init_threads();
+ fftwf_plan_with_nthreads(omp_get_max_threads());
+# else
+ fftw_init_threads();
+ fftw_plan_with_nthreads(omp_get_max_threads());
+# endif
+#endif
+
// Initialize fft_plan.m_plan with the vendor fft plan.
// Swap dimensions: AMReX FAB are Fortran-order but FFTW is C-order
if (dir == direction::R2C){
diff --git a/cmake/dependencies/FFT.cmake b/cmake/dependencies/FFT.cmake
index 37b922c91..feb7b48ea 100644
--- a/cmake/dependencies/FFT.cmake
+++ b/cmake/dependencies/FFT.cmake
@@ -1,4 +1,54 @@
if(WarpX_PSATD)
+ # Helper Functions ############################################################
+ #
+ option(WarpX_FFTW_IGNORE_OMP "Ignore FFTW3 OpenMP support, even if found" OFF)
+ mark_as_advanced(WarpX_FFTW_IGNORE_OMP)
+
+ # Set the WarpX_FFTW_OMP=1 define on WarpX::thirdparty::FFT if TRUE and print
+ # a message
+ #
+ function(fftw_add_define HAS_FFTW_OMP_LIB)
+ if(HAS_FFTW_OMP_LIB)
+ message(STATUS "FFTW: Found OpenMP support")
+ target_compile_definitions(WarpX::thirdparty::FFT INTERFACE WarpX_FFTW_OMP=1)
+ else()
+ message(STATUS "FFTW: Could NOT find OpenMP support")
+ endif()
+ endfunction()
+
+ # Check if the PkgConfig target location has an _omp library, e.g.,
+ # libfftw3(f)_omp.a shipped and if yes, set the WarpX_FFTW_OMP=1 define.
+ #
+ function(fftw_check_omp library_paths fftw_precision_suffix)
+ if(WarpX_FFTW_IGNORE_OMP)
+ fftw_add_define(FALSE)
+ return()
+ endif()
+
+ find_library(HAS_FFTW_OMP_LIB fftw3${fftw_precision_suffix}_omp
+ PATHS ${library_paths}
+ NO_DEFAULT_PATH
+ NO_PACKAGE_ROOT_PATH
+ NO_CMAKE_PATH
+ NO_CMAKE_ENVIRONMENT_PATH
+ NO_SYSTEM_ENVIRONMENT_PATH
+ NO_CMAKE_SYSTEM_PATH
+ NO_CMAKE_FIND_ROOT_PATH
+ )
+ if(HAS_FFTW_OMP_LIB)
+ # the .pc files here forget to link the _omp.a/so files
+ # explicitly - we add those manually to avoid any trouble,
+ # e.g., in static builds.
+ target_link_libraries(WarpX::thirdparty::FFT INTERFACE ${HAS_FFTW_OMP_LIB})
+ endif()
+
+ fftw_add_define("${HAS_FFTW_OMP_LIB}")
+ endfunction()
+
+
+ # Various FFT implementations that we want to use #############################
+ #
+
# cuFFT (CUDA)
# TODO: check if `find_package` search works
@@ -29,20 +79,18 @@ if(WarpX_PSATD)
endif()
mark_as_advanced(WarpX_FFTW_SEARCH)
+ # floating point precision suffixes: float, double and quad precision
+ if(WarpX_PRECISION STREQUAL "DOUBLE")
+ set(HFFTWp "")
+ else()
+ set(HFFTWp "f")
+ endif()
+
if(WarpX_FFTW_SEARCH STREQUAL CMAKE)
- if(WarpX_PRECISION STREQUAL "DOUBLE")
- find_package(FFTW3 CONFIG REQUIRED)
- else()
- find_package(FFTW3f CONFIG REQUIRED)
- endif()
+ find_package(FFTW3${HFFTWp} CONFIG REQUIRED)
else()
- if(WarpX_PRECISION STREQUAL "DOUBLE")
- find_package(PkgConfig REQUIRED QUIET)
- pkg_check_modules(fftw3 REQUIRED IMPORTED_TARGET fftw3)
- else()
- find_package(PkgConfig REQUIRED QUIET)
- pkg_check_modules(fftw3f REQUIRED IMPORTED_TARGET fftw3f)
- endif()
+ find_package(PkgConfig REQUIRED QUIET)
+ pkg_check_modules(fftw3${HFFTWp} REQUIRED IMPORTED_TARGET fftw3${HFFTWp})
endif()
endif()
@@ -53,28 +101,25 @@ if(WarpX_PSATD)
elseif(WarpX_COMPUTE STREQUAL HIP)
make_third_party_includes_system(roc::rocfft FFT)
else()
- if(WarpX_PRECISION STREQUAL "DOUBLE")
- if(FFTW3_FOUND)
- # subtargets: fftw3, fftw3_threads, fftw3_omp
- if(WarpX_COMPUTE STREQUAL OMP AND TARGET FFTW3::fftw3_omp)
- make_third_party_includes_system(FFTW3::fftw3_omp FFT)
- else()
- make_third_party_includes_system(FFTW3::fftw3 FFT)
- endif()
+ if(FFTW3_FOUND)
+ # subtargets: fftw3(p), fftw3(p)_threads, fftw3(p)_omp
+ if(WarpX_COMPUTE STREQUAL OMP AND
+ TARGET FFTW3::fftw3${HFFTWp}_omp AND
+ NOT WarpX_FFTW_IGNORE_OMP)
+ make_third_party_includes_system(FFTW3::fftw3${HFFTWp}_omp FFT)
+ fftw_add_define(TRUE)
else()
- make_third_party_includes_system(PkgConfig::fftw3 FFT)
+ make_third_party_includes_system(FFTW3::fftw3${HFFTWp} FFT)
+ fftw_add_define(FALSE)
endif()
else()
- if(FFTW3f_FOUND)
- # subtargets: fftw3f, fftw3f_threads, fftw3f_omp
- if(WarpX_COMPUTE STREQUAL OMP AND TARGET FFTW3::fftw3f_omp)
- make_third_party_includes_system(FFTW3::fftw3f_omp FFT)
- else()
- make_third_party_includes_system(FFTW3::fftw3f FFT)
- endif()
+ make_third_party_includes_system(PkgConfig::fftw3${HFFTWp} FFT)
+ if(WarpX_COMPUTE STREQUAL OMP AND
+ NOT WarpX_FFTW_IGNORE_OMP)
+ fftw_check_omp("${fftw3${HFFTWp}_LIBRARY_DIRS}" "${HFFTWp}")
else()
- make_third_party_includes_system(PkgConfig::fftw3f FFT)
+ fftw_add_define(FALSE)
endif()
endif()
endif()
-endif()
+endif(WarpX_PSATD)