aboutsummaryrefslogtreecommitdiff
path: root/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp')
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp85
1 files changed, 37 insertions, 48 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp b/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp
index 43b26f2ee..ddd07acad 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp
@@ -11,6 +11,8 @@
#include "Utils/WarpXConst.H"
#include "WarpX.H"
+#include "Utils/WarpXProfilerWrapper.H"
+
#include <blas.hh>
#include <lapack.hh>
@@ -23,10 +25,20 @@ HankelTransform::HankelTransform (int const hankel_order,
: m_nr(nr), m_nk(nr)
{
+ WARPX_PROFILE("HankelTransform::HankelTransform");
+
// Check that azimuthal_mode has a valid value
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(hankel_order-1 <= azimuthal_mode && azimuthal_mode <= hankel_order+1,
"azimuthal_mode must be either hankel_order-1, hankel_order or hankel_order+1");
+#ifdef AMREX_USE_GPU
+ // BLAS setup
+ // SYCL note: we need to double check AMReX device ID conventions and
+ // BLAS++ device ID conventions are the same
+ int const device_id = amrex::Gpu::Device::deviceId();
+ m_queue = std::make_unique<blas::Queue>( device_id, 0 );
+#endif
+
amrex::Vector<amrex::Real> alphas;
amrex::Vector<int> alpha_errors;
@@ -186,6 +198,8 @@ void
HankelTransform::HankelForwardTransform (amrex::FArrayBox const& F, int const F_icomp,
amrex::FArrayBox & G, int const G_icomp)
{
+ WARPX_PROFILE("HankelTransform::HankelForwardTransform");
+
amrex::Box const& F_box = F.box();
amrex::Box const& G_box = G.box();
@@ -198,37 +212,24 @@ HankelTransform::HankelForwardTransform (amrex::FArrayBox const& F, int const F_
AMREX_ALWAYS_ASSERT(ngr >= 0);
AMREX_ALWAYS_ASSERT(F_box.bigEnd(0)+1 >= m_nr);
-#ifndef AMREX_USE_GPU
- // On CPU, the blas::gemm is significantly faster
+ // We perform stream synchronization since `gemm` may be running
+ // on a different stream.
+ amrex::Gpu::streamSynchronize();
// Note that M is flagged to be transposed since it has dimensions (m_nr, m_nk)
blas::gemm(blas::Layout::ColMajor, blas::Op::Trans, blas::Op::NoTrans,
m_nk, nz, m_nr, 1._rt,
m_M.dataPtr(), m_nk,
F.dataPtr(F_icomp)+ngr, nrF, 0._rt,
- G.dataPtr(G_icomp), m_nk);
-
-#else
- // On GPU, the explicit loop is significantly faster
- // It is not clear if the GPU gemm wasn't build properly, it is cycling data out and back
- // in to the device, or if it is because gemm is launching its own threads.
-
- amrex::Real const * M_arr = m_M.dataPtr();
- amrex::Array4<const amrex::Real> const & F_arr = F.array();
- amrex::Array4< amrex::Real> const & G_arr = G.array();
-
- int const nr = m_nr;
-
- amrex::ParallelFor(G_box,
- [=] AMREX_GPU_DEVICE(int ik, int iz, int k3d) noexcept {
- G_arr(ik,iz,k3d,G_icomp) = 0.;
- for (int ir=0 ; ir < nr ; ir++) {
- int const ii = ir + ik*nr;
- G_arr(ik,iz,k3d,G_icomp) += M_arr[ii]*F_arr(ir,iz,k3d,F_icomp);
- }
- });
-
+ G.dataPtr(G_icomp), m_nk
+#ifdef AMREX_USE_GPU
+ , *m_queue // Calls the GPU version of blas::gemm
#endif
+ );
+
+ // We perform stream synchronization since `gemm` may be running
+ // on a different stream.
+ amrex::Gpu::streamSynchronize();
}
@@ -236,6 +237,8 @@ void
HankelTransform::HankelInverseTransform (amrex::FArrayBox const& G, int const G_icomp,
amrex::FArrayBox & F, int const F_icomp)
{
+ WARPX_PROFILE("HankelTransform::HankelInverseTransform");
+
amrex::Box const& G_box = G.box();
amrex::Box const& F_box = F.box();
@@ -248,36 +251,22 @@ HankelTransform::HankelInverseTransform (amrex::FArrayBox const& G, int const G_
AMREX_ALWAYS_ASSERT(ngr >= 0);
AMREX_ALWAYS_ASSERT(F_box.bigEnd(0)+1 >= m_nr);
-#ifndef AMREX_USE_GPU
- // On CPU, the blas::gemm is significantly faster
+ // We perform stream synchronization since `gemm` may be running
+ // on a different stream.
+ amrex::Gpu::streamSynchronize();
// Note that m_invM is flagged to be transposed since it has dimensions (m_nk, m_nr)
blas::gemm(blas::Layout::ColMajor, blas::Op::Trans, blas::Op::NoTrans,
m_nr, nz, m_nk, 1._rt,
m_invM.dataPtr(), m_nr,
G.dataPtr(G_icomp), m_nk, 0._rt,
- F.dataPtr(F_icomp)+ngr, nrF);
-
-#else
- // On GPU, the explicit loop is significantly faster
- // It is not clear if the GPU gemm wasn't build properly, it is cycling data out and back
- // in to the device, or if it is because gemm is launching its own threads.
-
- amrex::Real const * invM_arr = m_invM.dataPtr();
- amrex::Array4<const amrex::Real> const & G_arr = G.array();
- amrex::Array4< amrex::Real> const & F_arr = F.array();
-
- int const nk = m_nk;
-
- amrex::ParallelFor(G_box,
- [=] AMREX_GPU_DEVICE(int ir, int iz, int k3d) noexcept {
- F_arr(ir,iz,k3d,F_icomp) = 0.;
- for (int ik=0 ; ik < nk ; ik++) {
- int const ii = ik + ir*nk;
- F_arr(ir,iz,k3d,F_icomp) += invM_arr[ii]*G_arr(ik,iz,k3d,G_icomp);
- }
- });
-
+ F.dataPtr(F_icomp)+ngr, nrF
+#ifdef AMREX_USE_GPU
+ , *m_queue // Calls the GPU version of blas::gemm
#endif
+ );
+ // We perform stream synchronization since `gemm` may be running
+ // on a different stream.
+ amrex::Gpu::streamSynchronize();
}