diff options
-rw-r--r-- | Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp | 4 | ||||
-rw-r--r-- | Source/FieldSolver/SpectralSolver/WrapFFTW.cpp | 12 | ||||
-rw-r--r-- | cmake/dependencies/FFT.cmake | 105 |
3 files changed, 90 insertions, 31 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp index bdb631063..056c030c0 100644 --- a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp +++ b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp @@ -145,6 +145,8 @@ SpectralFieldData::ForwardTransform (const int lev, #endif // Loop over boxes + // Note: we do NOT OpenMP parallelize here, since we use OpenMP threads for + // the FFTs on each box! for ( MFIter mfi(mf); mfi.isValid(); ++mfi ){ if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers) { @@ -247,6 +249,8 @@ SpectralFieldData::BackwardTransform( const int lev, #endif // Loop over boxes + // Note: we do NOT OpenMP parallelize here, since we use OpenMP threads for + // the iFFTs on each box! for ( MFIter mfi(mf); mfi.isValid(); ++mfi ){ if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers) { diff --git a/Source/FieldSolver/SpectralSolver/WrapFFTW.cpp b/Source/FieldSolver/SpectralSolver/WrapFFTW.cpp index a4dfc8b29..57a0fad04 100644 --- a/Source/FieldSolver/SpectralSolver/WrapFFTW.cpp +++ b/Source/FieldSolver/SpectralSolver/WrapFFTW.cpp @@ -1,4 +1,4 @@ -/* Copyright 2019-2020 +/* Copyright 2019-2021 * * This file is part of WarpX. * @@ -32,6 +32,16 @@ namespace AnyFFT { FFTplan fft_plan; +#if defined(AMREX_USE_OMP) && defined(WarpX_FFTW_OMP) +# ifdef AMREX_USE_FLOAT + fftwf_init_threads(); + fftwf_plan_with_nthreads(omp_get_max_threads()); +# else + fftw_init_threads(); + fftw_plan_with_nthreads(omp_get_max_threads()); +# endif +#endif + // Initialize fft_plan.m_plan with the vendor fft plan. // Swap dimensions: AMReX FAB are Fortran-order but FFTW is C-order if (dir == direction::R2C){ diff --git a/cmake/dependencies/FFT.cmake b/cmake/dependencies/FFT.cmake index 37b922c91..feb7b48ea 100644 --- a/cmake/dependencies/FFT.cmake +++ b/cmake/dependencies/FFT.cmake @@ -1,4 +1,54 @@ if(WarpX_PSATD) + # Helper Functions ############################################################ + # + option(WarpX_FFTW_IGNORE_OMP "Ignore FFTW3 OpenMP support, even if found" OFF) + mark_as_advanced(WarpX_FFTW_IGNORE_OMP) + + # Set the WarpX_FFTW_OMP=1 define on WarpX::thirdparty::FFT if TRUE and print + # a message + # + function(fftw_add_define HAS_FFTW_OMP_LIB) + if(HAS_FFTW_OMP_LIB) + message(STATUS "FFTW: Found OpenMP support") + target_compile_definitions(WarpX::thirdparty::FFT INTERFACE WarpX_FFTW_OMP=1) + else() + message(STATUS "FFTW: Could NOT find OpenMP support") + endif() + endfunction() + + # Check if the PkgConfig target location has an _omp library, e.g., + # libfftw3(f)_omp.a shipped and if yes, set the WarpX_FFTW_OMP=1 define. + # + function(fftw_check_omp library_paths fftw_precision_suffix) + if(WarpX_FFTW_IGNORE_OMP) + fftw_add_define(FALSE) + return() + endif() + + find_library(HAS_FFTW_OMP_LIB fftw3${fftw_precision_suffix}_omp + PATHS ${library_paths} + NO_DEFAULT_PATH + NO_PACKAGE_ROOT_PATH + NO_CMAKE_PATH + NO_CMAKE_ENVIRONMENT_PATH + NO_SYSTEM_ENVIRONMENT_PATH + NO_CMAKE_SYSTEM_PATH + NO_CMAKE_FIND_ROOT_PATH + ) + if(HAS_FFTW_OMP_LIB) + # the .pc files here forget to link the _omp.a/so files + # explicitly - we add those manually to avoid any trouble, + # e.g., in static builds. + target_link_libraries(WarpX::thirdparty::FFT INTERFACE ${HAS_FFTW_OMP_LIB}) + endif() + + fftw_add_define("${HAS_FFTW_OMP_LIB}") + endfunction() + + + # Various FFT implementations that we want to use ############################# + # + # cuFFT (CUDA) # TODO: check if `find_package` search works @@ -29,20 +79,18 @@ if(WarpX_PSATD) endif() mark_as_advanced(WarpX_FFTW_SEARCH) + # floating point precision suffixes: float, double and quad precision + if(WarpX_PRECISION STREQUAL "DOUBLE") + set(HFFTWp "") + else() + set(HFFTWp "f") + endif() + if(WarpX_FFTW_SEARCH STREQUAL CMAKE) - if(WarpX_PRECISION STREQUAL "DOUBLE") - find_package(FFTW3 CONFIG REQUIRED) - else() - find_package(FFTW3f CONFIG REQUIRED) - endif() + find_package(FFTW3${HFFTWp} CONFIG REQUIRED) else() - if(WarpX_PRECISION STREQUAL "DOUBLE") - find_package(PkgConfig REQUIRED QUIET) - pkg_check_modules(fftw3 REQUIRED IMPORTED_TARGET fftw3) - else() - find_package(PkgConfig REQUIRED QUIET) - pkg_check_modules(fftw3f REQUIRED IMPORTED_TARGET fftw3f) - endif() + find_package(PkgConfig REQUIRED QUIET) + pkg_check_modules(fftw3${HFFTWp} REQUIRED IMPORTED_TARGET fftw3${HFFTWp}) endif() endif() @@ -53,28 +101,25 @@ if(WarpX_PSATD) elseif(WarpX_COMPUTE STREQUAL HIP) make_third_party_includes_system(roc::rocfft FFT) else() - if(WarpX_PRECISION STREQUAL "DOUBLE") - if(FFTW3_FOUND) - # subtargets: fftw3, fftw3_threads, fftw3_omp - if(WarpX_COMPUTE STREQUAL OMP AND TARGET FFTW3::fftw3_omp) - make_third_party_includes_system(FFTW3::fftw3_omp FFT) - else() - make_third_party_includes_system(FFTW3::fftw3 FFT) - endif() + if(FFTW3_FOUND) + # subtargets: fftw3(p), fftw3(p)_threads, fftw3(p)_omp + if(WarpX_COMPUTE STREQUAL OMP AND + TARGET FFTW3::fftw3${HFFTWp}_omp AND + NOT WarpX_FFTW_IGNORE_OMP) + make_third_party_includes_system(FFTW3::fftw3${HFFTWp}_omp FFT) + fftw_add_define(TRUE) else() - make_third_party_includes_system(PkgConfig::fftw3 FFT) + make_third_party_includes_system(FFTW3::fftw3${HFFTWp} FFT) + fftw_add_define(FALSE) endif() else() - if(FFTW3f_FOUND) - # subtargets: fftw3f, fftw3f_threads, fftw3f_omp - if(WarpX_COMPUTE STREQUAL OMP AND TARGET FFTW3::fftw3f_omp) - make_third_party_includes_system(FFTW3::fftw3f_omp FFT) - else() - make_third_party_includes_system(FFTW3::fftw3f FFT) - endif() + make_third_party_includes_system(PkgConfig::fftw3${HFFTWp} FFT) + if(WarpX_COMPUTE STREQUAL OMP AND + NOT WarpX_FFTW_IGNORE_OMP) + fftw_check_omp("${fftw3${HFFTWp}_LIBRARY_DIRS}" "${HFFTWp}") else() - make_third_party_includes_system(PkgConfig::fftw3f FFT) + fftw_add_define(FALSE) endif() endif() endif() -endif() +endif(WarpX_PSATD) |