diff options
Diffstat (limited to 'Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp')
-rw-r--r-- | Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp | 51 |
1 files changed, 30 insertions, 21 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp b/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp index c5249d54f..24e9d7f81 100644 --- a/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp +++ b/Source/FieldSolver/SpectralSolver/SpectralHankelTransform/HankelTransform.cpp @@ -25,7 +25,7 @@ HankelTransform::HankelTransform (int const hankel_order, AMREX_ALWAYS_ASSERT_WITH_MESSAGE(hankel_order-1 <= azimuthal_mode && azimuthal_mode <= hankel_order+1, "azimuthal_mode must be either hankel_order-1, hankel_order or hankel_order+1"); - RealVector alphas; + amrex::Vector<amrex::Real> alphas; amrex::Vector<int> alpha_errors; GetBesselRoots(azimuthal_mode, m_nk, alphas, alpha_errors); @@ -33,13 +33,13 @@ HankelTransform::HankelTransform (int const hankel_order, AMREX_ALWAYS_ASSERT(std::all_of(alpha_errors.begin(), alpha_errors.end(), [](int i) { return i == 0; })); // Calculate the spectral grid - m_kr.resize(m_nk); + amrex::Vector<amrex::Real> kr(m_nk); for (int ik=0 ; ik < m_nk ; ik++) { - m_kr[ik] = alphas[ik]/rmax; + kr[ik] = alphas[ik]/rmax; } // Calculate the spatial grid (Uniform grid with a half-cell offset) - RealVector rmesh(m_nr); + amrex::Vector<amrex::Real> rmesh(m_nr); amrex::Real dr = rmax/m_nr; for (int ir=0 ; ir < m_nr ; ir++) { rmesh[ir] = dr*(ir + 0.5_rt); @@ -57,22 +57,22 @@ HankelTransform::HankelTransform (int const hankel_order, p_denom = hankel_order; } - RealVector denom(m_nk); + amrex::Vector<amrex::Real> denom(m_nk); for (int ik=0 ; ik < m_nk ; ik++) { const amrex::Real jna = jn(p_denom, alphas[ik]); denom[ik] = MathConst::pi*rmax*rmax*jna*jna; } - RealVector num(m_nk*m_nr); + amrex::Vector<amrex::Real> num(m_nk*m_nr); for (int ir=0 ; ir < m_nr ; ir++) { for (int ik=0 ; ik < m_nk ; ik++) { int const ii = ik + ir*m_nk; - num[ii] = jn(hankel_order, rmesh[ir]*m_kr[ik]); + num[ii] = jn(hankel_order, rmesh[ir]*kr[ik]); } } // Get the inverse matrix - 
invM.resize(m_nk*m_nr); + amrex::Vector<amrex::Real> invM(m_nk*m_nr); if (azimuthal_mode > 0) { for (int ir=0 ; ir < m_nr ; ir++) { for (int ik=1 ; ik < m_nk ; ik++) { @@ -107,18 +107,20 @@ HankelTransform::HankelTransform (int const hankel_order, } } + amrex::Vector<amrex::Real> M; + // Calculate the matrix M by inverting invM if (azimuthal_mode !=0 && hankel_order != azimuthal_mode-1) { // In this case, invM is singular, thus we calculate the pseudo-inverse. // The Moore-Penrose pseudo-inverse is calculated using the SVD method. M.resize(m_nk*m_nr, 0.); - RealVector invMcopy(invM); - RealVector sdiag(m_nk-1, 0.); - RealVector u((m_nk-1)*(m_nk-1), 0.); - RealVector vt((m_nr)*(m_nr), 0.); - RealVector sp((m_nr)*(m_nk-1), 0.); - RealVector temp((m_nr)*(m_nk-1), 0.); + amrex::Vector<amrex::Real> invMcopy(invM); + amrex::Vector<amrex::Real> sdiag(m_nk-1, 0.); + amrex::Vector<amrex::Real> u((m_nk-1)*(m_nk-1), 0.); + amrex::Vector<amrex::Real> vt((m_nr)*(m_nr), 0.); + amrex::Vector<amrex::Real> sp((m_nr)*(m_nk-1), 0.); + amrex::Vector<amrex::Real> temp((m_nr)*(m_nk-1), 0.); // Calculate the singular-value-decomposition of invM (leaving out the first row). 
// invM = u*sdiag*vt @@ -169,6 +171,13 @@ HankelTransform::HankelTransform (int const hankel_order, } + m_kr.resize(kr.size()); + m_invM.resize(invM.size()); + m_M.resize(M.size()); + amrex::Gpu::copyAsync(amrex::Gpu::hostToDevice, kr.begin(), kr.end(), m_kr.begin()); + amrex::Gpu::copyAsync(amrex::Gpu::hostToDevice, invM.begin(), invM.end(), m_invM.begin()); + amrex::Gpu::copyAsync(amrex::Gpu::hostToDevice, M.begin(), M.end(), m_M.begin()); + amrex::Gpu::synchronize(); } void @@ -193,7 +202,7 @@ HankelTransform::HankelForwardTransform (amrex::FArrayBox const& F, int const F_ // Note that M is flagged to be transposed since it has dimensions (m_nr, m_nk) blas::gemm(blas::Layout::ColMajor, blas::Op::Trans, blas::Op::NoTrans, m_nk, nz, m_nr, 1._rt, - M.dataPtr(), m_nk, + m_M.dataPtr(), m_nk, F.dataPtr(F_icomp)+ngr, nrF, 0._rt, G.dataPtr(G_icomp), m_nk); @@ -202,13 +211,13 @@ HankelTransform::HankelForwardTransform (amrex::FArrayBox const& F, int const F_ // It is not clear if the GPU gemm wasn't built properly, it is cycling data out and back // in to the device, or if it is because gemm is launching its own threads. 
- amrex::Real const * M_arr = M.dataPtr(); + amrex::Real const * M_arr = m_M.dataPtr(); amrex::Array4<const amrex::Real> const & F_arr = F.array(); amrex::Array4< amrex::Real> const & G_arr = G.array(); int const nr = m_nr; - ParallelFor(G_box, + amrex::ParallelFor(G_box, [=] AMREX_GPU_DEVICE(int ik, int iz, int inotused) noexcept { G_arr(ik,iz,G_icomp) = 0.; for (int ir=0 ; ir < nr ; ir++) { @@ -240,10 +249,10 @@ HankelTransform::HankelInverseTransform (amrex::FArrayBox const& G, int const G_ #ifndef AMREX_USE_GPU // On CPU, the blas::gemm is significantly faster - // Note that invM is flagged to be transposed since it has dimensions (m_nk, m_nr) + // Note that m_invM is flagged to be transposed since it has dimensions (m_nk, m_nr) blas::gemm(blas::Layout::ColMajor, blas::Op::Trans, blas::Op::NoTrans, m_nr, nz, m_nk, 1._rt, - invM.dataPtr(), m_nr, + m_invM.dataPtr(), m_nr, G.dataPtr(G_icomp), m_nk, 0._rt, F.dataPtr(F_icomp)+ngr, nrF); @@ -252,13 +261,13 @@ HankelTransform::HankelInverseTransform (amrex::FArrayBox const& G, int const G_ // It is not clear if the GPU gemm wasn't built properly, it is cycling data out and back // in to the device, or if it is because gemm is launching its own threads. - amrex::Real const * invM_arr = invM.dataPtr(); + amrex::Real const * invM_arr = m_invM.dataPtr(); amrex::Array4<const amrex::Real> const & G_arr = G.array(); amrex::Array4< amrex::Real> const & F_arr = F.array(); int const nk = m_nk; - ParallelFor(G_box, + amrex::ParallelFor(G_box, [=] AMREX_GPU_DEVICE(int ir, int iz, int inotused) noexcept { F_arr(ir,iz,F_icomp) = 0.; for (int ik=0 ; ik < nk ; ik++) { |