diff options
Diffstat (limited to 'Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp')
-rw-r--r-- | Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp | 85 |
1 files changed, 37 insertions, 48 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp b/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp index 43b26f2ee..ddd07acad 100644 --- a/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp +++ b/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp @@ -11,6 +11,8 @@ #include "Utils/WarpXConst.H" #include "WarpX.H" +#include "Utils/WarpXProfilerWrapper.H" + #include <blas.hh> #include <lapack.hh> @@ -23,10 +25,20 @@ HankelTransform::HankelTransform (int const hankel_order, : m_nr(nr), m_nk(nr) { + WARPX_PROFILE("HankelTransform::HankelTransform"); + // Check that azimuthal_mode has a valid value WARPX_ALWAYS_ASSERT_WITH_MESSAGE(hankel_order-1 <= azimuthal_mode && azimuthal_mode <= hankel_order+1, "azimuthal_mode must be either hankel_order-1, hankel_order or hankel_order+1"); +#ifdef AMREX_USE_GPU + // BLAS setup + // SYCL note: we need to double check AMReX device ID conventions and + // BLAS++ device ID conventions are the same + int const device_id = amrex::Gpu::Device::deviceId(); + m_queue = std::make_unique<blas::Queue>( device_id, 0 ); +#endif + amrex::Vector<amrex::Real> alphas; amrex::Vector<int> alpha_errors; @@ -186,6 +198,8 @@ void HankelTransform::HankelForwardTransform (amrex::FArrayBox const& F, int const F_icomp, amrex::FArrayBox & G, int const G_icomp) { + WARPX_PROFILE("HankelTransform::HankelForwardTransform"); + amrex::Box const& F_box = F.box(); amrex::Box const& G_box = G.box(); @@ -198,37 +212,24 @@ HankelTransform::HankelForwardTransform (amrex::FArrayBox const& F, int const F_ AMREX_ALWAYS_ASSERT(ngr >= 0); AMREX_ALWAYS_ASSERT(F_box.bigEnd(0)+1 >= m_nr); -#ifndef AMREX_USE_GPU - // On CPU, the blas::gemm is significantly faster + // We perform stream synchronization since `gemm` may be running + // on a different stream. + amrex::Gpu::streamSynchronize(); // Note that M is flagged to be transposed since it has dimensions (m_nr, m_nk) blas::gemm(blas::Layout::ColMajor, blas::Op::Trans, blas::Op::NoTrans, m_nk, nz, m_nr, 1._rt, m_M.dataPtr(), m_nk, F.dataPtr(F_icomp)+ngr, nrF, 0._rt, - G.dataPtr(G_icomp), m_nk); - -#else - // On GPU, the explicit loop is significantly faster - // It is not clear if the GPU gemm wasn't build properly, it is cycling data out and back - // in to the device, or if it is because gemm is launching its own threads. - - amrex::Real const * M_arr = m_M.dataPtr(); - amrex::Array4<const amrex::Real> const & F_arr = F.array(); - amrex::Array4< amrex::Real> const & G_arr = G.array(); - - int const nr = m_nr; - - amrex::ParallelFor(G_box, - [=] AMREX_GPU_DEVICE(int ik, int iz, int k3d) noexcept { - G_arr(ik,iz,k3d,G_icomp) = 0.; - for (int ir=0 ; ir < nr ; ir++) { - int const ii = ir + ik*nr; - G_arr(ik,iz,k3d,G_icomp) += M_arr[ii]*F_arr(ir,iz,k3d,F_icomp); - } - }); - + G.dataPtr(G_icomp), m_nk +#ifdef AMREX_USE_GPU + , *m_queue // Calls the GPU version of blas::gemm #endif + ); + + // We perform stream synchronization since `gemm` may be running + // on a different stream. + amrex::Gpu::streamSynchronize(); } @@ -236,6 +237,8 @@ void HankelTransform::HankelInverseTransform (amrex::FArrayBox const& G, int const G_icomp, amrex::FArrayBox & F, int const F_icomp) { + WARPX_PROFILE("HankelTransform::HankelInverseTransform"); + amrex::Box const& G_box = G.box(); amrex::Box const& F_box = F.box(); @@ -248,36 +251,22 @@ HankelTransform::HankelInverseTransform (amrex::FArrayBox const& G, int const G_ AMREX_ALWAYS_ASSERT(ngr >= 0); AMREX_ALWAYS_ASSERT(F_box.bigEnd(0)+1 >= m_nr); -#ifndef AMREX_USE_GPU - // On CPU, the blas::gemm is significantly faster + // We perform stream synchronization since `gemm` may be running + // on a different stream. + amrex::Gpu::streamSynchronize(); // Note that m_invM is flagged to be transposed since it has dimensions (m_nk, m_nr) blas::gemm(blas::Layout::ColMajor, blas::Op::Trans, blas::Op::NoTrans, m_nr, nz, m_nk, 1._rt, m_invM.dataPtr(), m_nr, G.dataPtr(G_icomp), m_nk, 0._rt, - F.dataPtr(F_icomp)+ngr, nrF); - -#else - // On GPU, the explicit loop is significantly faster - // It is not clear if the GPU gemm wasn't build properly, it is cycling data out and back - // in to the device, or if it is because gemm is launching its own threads. - - amrex::Real const * invM_arr = m_invM.dataPtr(); - amrex::Array4<const amrex::Real> const & G_arr = G.array(); - amrex::Array4< amrex::Real> const & F_arr = F.array(); - - int const nk = m_nk; - - amrex::ParallelFor(G_box, - [=] AMREX_GPU_DEVICE(int ir, int iz, int k3d) noexcept { - F_arr(ir,iz,k3d,F_icomp) = 0.; - for (int ik=0 ; ik < nk ; ik++) { - int const ii = ik + ir*nk; - F_arr(ir,iz,k3d,F_icomp) += invM_arr[ii]*G_arr(ik,iz,k3d,G_icomp); - } - }); - + F.dataPtr(F_icomp)+ngr, nrF +#ifdef AMREX_USE_GPU + , *m_queue // Calls the GPU version of blas::gemm #endif + ); + // We perform stream synchronization since `gemm` may be running + // on a different stream. + amrex::Gpu::streamSynchronize(); } |