aboutsummaryrefslogtreecommitdiff
path: root/Source/FieldSolver/SpectralSolver/SpectralHankelTransform
diff options
context:
space:
mode:
authorGravatar Remi Lehe <remi.lehe@normalesup.org> 2022-09-15 15:15:09 -0700
committerGravatar GitHub <noreply@github.com> 2022-09-15 15:15:09 -0700
commitac2521aa4c23e999715931331a817c41e416ddd2 (patch)
tree059810af6a77e9e3e2a65b670e58368748b36e81 /Source/FieldSolver/SpectralSolver/SpectralHankelTransform
parent04b6f67caab8ee95428a569c529110a12b2527f3 (diff)
downloadWarpX-ac2521aa4c23e999715931331a817c41e416ddd2.tar.gz
WarpX-ac2521aa4c23e999715931331a817c41e416ddd2.tar.zst
WarpX-ac2521aa4c23e999715931331a817c41e416ddd2.zip
Use blaspp::gemm on GPU for Hankel transform (#3383)
* Use gemm on GPU for Hankel transform * Add stream synchronization * Add `amrex` * blas::gemm call: add `queue` with device id * CMake: BLAS++ Missing Deps * Update installation instructions for Summit * CMake: BLAS++ should not need curand * Add paths to blaspp/lapackpp * Move Queue Constructor to Constructor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Correct name of profiled area * Use gemm for inverse Hankel transform * Add missing comma * Clean up the code so that it compiles for CPU * Clean up code ; update documentation * Update Comment Co-authored-by: Remi Lehe <remi.lehe@normalesup.org> * Update Tools/machines/summit-olcf/summit_warpx.profile.example Co-authored-by: Axel Huebl <axel.huebl@plasma.ninja> * Add stream synchronization * Switch to streamsynchronize * Update comments Co-authored-by: Axel Huebl <axel.huebl@plasma.ninja> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Diffstat (limited to 'Source/FieldSolver/SpectralSolver/SpectralHankelTransform')
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.H12
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp85
2 files changed, 49 insertions, 48 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.H b/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.H
index 5a87031c4..3f74f5d82 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.H
+++ b/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.H
@@ -8,6 +8,14 @@
#define WARPX_HANKEL_TRANSFORM_H_
#include <AMReX_FArrayBox.H>
+#include <AMReX_REAL.H>
+#include <AMReX_GpuContainers.H>
+
+#ifdef AMREX_USE_GPU
+# include <blas.hh>
+#endif
+
+#include <memory>
/* \brief This defines the class that performs the Hankel transform.
* Original authors: Remi Lehe, Manuel Kirchen
@@ -45,6 +53,10 @@ class HankelTransform
RealVector m_invM;
RealVector m_M;
+
+#ifdef AMREX_USE_GPU
+ std::unique_ptr<blas::Queue> m_queue;
+#endif
};
#endif
diff --git a/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp b/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp
index 43b26f2ee..ddd07acad 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp
@@ -11,6 +11,8 @@
#include "Utils/WarpXConst.H"
#include "WarpX.H"
+#include "Utils/WarpXProfilerWrapper.H"
+
#include <blas.hh>
#include <lapack.hh>
@@ -23,10 +25,20 @@ HankelTransform::HankelTransform (int const hankel_order,
: m_nr(nr), m_nk(nr)
{
+ WARPX_PROFILE("HankelTransform::HankelTransform");
+
// Check that azimuthal_mode has a valid value
WARPX_ALWAYS_ASSERT_WITH_MESSAGE(hankel_order-1 <= azimuthal_mode && azimuthal_mode <= hankel_order+1,
"azimuthal_mode must be either hankel_order-1, hankel_order or hankel_order+1");
+#ifdef AMREX_USE_GPU
+ // BLAS setup
+ // SYCL note: we need to double check AMReX device ID conventions and
+ // BLAS++ device ID conventions are the same
+ int const device_id = amrex::Gpu::Device::deviceId();
+ m_queue = std::make_unique<blas::Queue>( device_id, 0 );
+#endif
+
amrex::Vector<amrex::Real> alphas;
amrex::Vector<int> alpha_errors;
@@ -186,6 +198,8 @@ void
HankelTransform::HankelForwardTransform (amrex::FArrayBox const& F, int const F_icomp,
amrex::FArrayBox & G, int const G_icomp)
{
+ WARPX_PROFILE("HankelTransform::HankelForwardTransform");
+
amrex::Box const& F_box = F.box();
amrex::Box const& G_box = G.box();
@@ -198,37 +212,24 @@ HankelTransform::HankelForwardTransform (amrex::FArrayBox const& F, int const F_
AMREX_ALWAYS_ASSERT(ngr >= 0);
AMREX_ALWAYS_ASSERT(F_box.bigEnd(0)+1 >= m_nr);
-#ifndef AMREX_USE_GPU
- // On CPU, the blas::gemm is significantly faster
+ // We perform stream synchronization since `gemm` may be running
+ // on a different stream.
+ amrex::Gpu::streamSynchronize();
// Note that M is flagged to be transposed since it has dimensions (m_nr, m_nk)
blas::gemm(blas::Layout::ColMajor, blas::Op::Trans, blas::Op::NoTrans,
m_nk, nz, m_nr, 1._rt,
m_M.dataPtr(), m_nk,
F.dataPtr(F_icomp)+ngr, nrF, 0._rt,
- G.dataPtr(G_icomp), m_nk);
-
-#else
- // On GPU, the explicit loop is significantly faster
- // It is not clear if the GPU gemm wasn't build properly, it is cycling data out and back
- // in to the device, or if it is because gemm is launching its own threads.
-
- amrex::Real const * M_arr = m_M.dataPtr();
- amrex::Array4<const amrex::Real> const & F_arr = F.array();
- amrex::Array4< amrex::Real> const & G_arr = G.array();
-
- int const nr = m_nr;
-
- amrex::ParallelFor(G_box,
- [=] AMREX_GPU_DEVICE(int ik, int iz, int k3d) noexcept {
- G_arr(ik,iz,k3d,G_icomp) = 0.;
- for (int ir=0 ; ir < nr ; ir++) {
- int const ii = ir + ik*nr;
- G_arr(ik,iz,k3d,G_icomp) += M_arr[ii]*F_arr(ir,iz,k3d,F_icomp);
- }
- });
-
+ G.dataPtr(G_icomp), m_nk
+#ifdef AMREX_USE_GPU
+ , *m_queue // Calls the GPU version of blas::gemm
#endif
+ );
+
+ // We perform stream synchronization since `gemm` may be running
+ // on a different stream.
+ amrex::Gpu::streamSynchronize();
}
@@ -236,6 +237,8 @@ void
HankelTransform::HankelInverseTransform (amrex::FArrayBox const& G, int const G_icomp,
amrex::FArrayBox & F, int const F_icomp)
{
+ WARPX_PROFILE("HankelTransform::HankelInverseTransform");
+
amrex::Box const& G_box = G.box();
amrex::Box const& F_box = F.box();
@@ -248,36 +251,22 @@ HankelTransform::HankelInverseTransform (amrex::FArrayBox const& G, int const G_
AMREX_ALWAYS_ASSERT(ngr >= 0);
AMREX_ALWAYS_ASSERT(F_box.bigEnd(0)+1 >= m_nr);
-#ifndef AMREX_USE_GPU
- // On CPU, the blas::gemm is significantly faster
+ // We perform stream synchronization since `gemm` may be running
+ // on a different stream.
+ amrex::Gpu::streamSynchronize();
// Note that m_invM is flagged to be transposed since it has dimensions (m_nk, m_nr)
blas::gemm(blas::Layout::ColMajor, blas::Op::Trans, blas::Op::NoTrans,
m_nr, nz, m_nk, 1._rt,
m_invM.dataPtr(), m_nr,
G.dataPtr(G_icomp), m_nk, 0._rt,
- F.dataPtr(F_icomp)+ngr, nrF);
-
-#else
- // On GPU, the explicit loop is significantly faster
- // It is not clear if the GPU gemm wasn't build properly, it is cycling data out and back
- // in to the device, or if it is because gemm is launching its own threads.
-
- amrex::Real const * invM_arr = m_invM.dataPtr();
- amrex::Array4<const amrex::Real> const & G_arr = G.array();
- amrex::Array4< amrex::Real> const & F_arr = F.array();
-
- int const nk = m_nk;
-
- amrex::ParallelFor(G_box,
- [=] AMREX_GPU_DEVICE(int ir, int iz, int k3d) noexcept {
- F_arr(ir,iz,k3d,F_icomp) = 0.;
- for (int ik=0 ; ik < nk ; ik++) {
- int const ii = ik + ir*nk;
- F_arr(ir,iz,k3d,F_icomp) += invM_arr[ii]*G_arr(ik,iz,k3d,G_icomp);
- }
- });
-
+ F.dataPtr(F_icomp)+ngr, nrF
+#ifdef AMREX_USE_GPU
+ , *m_queue // Calls the GPU version of blas::gemm
#endif
+ );
+ // We perform stream synchronization since `gemm` may be running
+ // on a different stream.
+ amrex::Gpu::streamSynchronize();
}