aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Source/FieldSolver/SpectralSolver/SpectralKSpace.cpp7
-rw-r--r--Source/Particles/Sorting/SortingUtils.H33
-rw-r--r--Source/Utils/WarpX_Complex.H14
3 files changed, 22 insertions, 32 deletions
diff --git a/Source/FieldSolver/SpectralSolver/SpectralKSpace.cpp b/Source/FieldSolver/SpectralSolver/SpectralKSpace.cpp
index aee131324..c21388aba 100644
--- a/Source/FieldSolver/SpectralSolver/SpectralKSpace.cpp
+++ b/Source/FieldSolver/SpectralSolver/SpectralKSpace.cpp
@@ -144,12 +144,7 @@ SpectralKSpace::getSpectralShiftFactor( const DistributionMapping& dm,
}
const Complex I{0,1};
for (int i=0; i<k.size(); i++ ){
-#ifdef AMREX_USE_GPU
- shift[i] = thrust::exp( I*sign*k[i]*0.5*dx[i_dim] );
-#else
- shift[i] = std::exp( I*sign*k[i]*0.5*dx[i_dim] );
-#endif
-
+ shift[i] = exp( I*sign*k[i]*0.5*dx[i_dim]);
}
}
return shift_factor;
diff --git a/Source/Particles/Sorting/SortingUtils.H b/Source/Particles/Sorting/SortingUtils.H
index 35bc059aa..f425c6c7b 100644
--- a/Source/Particles/Sorting/SortingUtils.H
+++ b/Source/Particles/Sorting/SortingUtils.H
@@ -3,10 +3,7 @@
#include <WarpXParticleContainer.H>
#include <AMReX_Gpu.H>
-#ifdef AMREX_USE_GPU
- #include <thrust/partition.h>
- #include <thrust/distance.h>
-#endif
+#include <AMReX_Partition.H>
/** \brief Fill the elements of the input vector with consecutive integer,
* starting from 0
@@ -16,8 +13,10 @@
void fillWithConsecutiveIntegers( amrex::Gpu::DeviceVector<long>& v )
{
#ifdef AMREX_USE_GPU
- // On GPU: Use thrust
- thrust::sequence( v.begin(), v.end() );
+ // On GPU: Use amrex
+ auto data = v.data();
+ auto N = v.size();
+ AMREX_FOR_1D( N, i, data[i] = i;);
#else
// On CPU: Use std library
std::iota( v.begin(), v.end(), 0L );
@@ -35,17 +34,18 @@ void fillWithConsecutiveIntegers( amrex::Gpu::DeviceVector<long>& v )
*/
template< typename ForwardIterator >
ForwardIterator stablePartition(ForwardIterator const index_begin,
- ForwardIterator const index_end,
- amrex::Gpu::DeviceVector<int> const& predicate)
+ ForwardIterator const index_end,
+ amrex::Gpu::DeviceVector<int> const& predicate)
{
#ifdef AMREX_USE_GPU
- // On GPU: Use thrust
+ // On GPU: Use amrex
int const* AMREX_RESTRICT predicate_ptr = predicate.dataPtr();
- ForwardIterator const sep = thrust::stable_partition(
- thrust::cuda::par(amrex::Gpu::The_ThrustCachedAllocator()),
- index_begin, index_end,
- [predicate_ptr] AMREX_GPU_DEVICE (long i) { return predicate_ptr[i]; }
- );
+ int N = static_cast<int>(std::distance(index_begin, index_end));
+ auto num_true = amrex::StablePartition(&(*index_begin), N,
+ [predicate_ptr] AMREX_GPU_DEVICE (long i) { return predicate_ptr[i]; });
+
+ ForwardIterator sep = index_begin;
+ std::advance(sep, num_true);
#else
// On CPU: Use std library
ForwardIterator const sep = std::stable_partition(
@@ -66,12 +66,7 @@ template< typename ForwardIterator >
int iteratorDistance(ForwardIterator const first,
ForwardIterator const last)
{
-#ifdef AMREX_USE_GPU
- // On GPU: Use thrust
- return thrust::distance( first, last );
-#else
return std::distance( first, last );
-#endif
}
/** \brief Functor that fills the elements of the particle array `inexflag`
diff --git a/Source/Utils/WarpX_Complex.H b/Source/Utils/WarpX_Complex.H
index 25ed1a53a..1f265d3c5 100644
--- a/Source/Utils/WarpX_Complex.H
+++ b/Source/Utils/WarpX_Complex.H
@@ -3,13 +3,14 @@
#include <AMReX_REAL.H>
#include <AMReX_Gpu.H>
+#include <AMReX_GpuComplex.H>
+
+#include <complex>
// Define complex type on GPU/CPU
#ifdef AMREX_USE_GPU
-#include <thrust/complex.h>
-
-using Complex = thrust::complex<amrex::Real>;
+using Complex = amrex::GpuComplex<amrex::Real>;
#ifdef WARPX_USE_PSATD
#include <cufft.h>
@@ -19,7 +20,6 @@ static_assert( sizeof(Complex) == sizeof(cuDoubleComplex),
#else
-#include <complex>
using Complex = std::complex<amrex::Real>;
#ifdef WARPX_USE_PSATD
@@ -39,7 +39,7 @@ namespace MathFunc
template<typename T>
AMREX_GPU_HOST_DEVICE T exp (const T& val){
#ifdef AMREX_USE_GPU
- return thrust::exp(val);
+ return amrex::exp(val);
#else
return std::exp(val);
#endif
@@ -49,7 +49,7 @@ namespace MathFunc
template<typename T>
AMREX_GPU_HOST_DEVICE T sqrt (const T& val){
#ifdef AMREX_USE_GPU
- return thrust::sqrt(val);
+ return amrex::sqrt(val);
#else
return std::sqrt(val);
#endif
@@ -59,7 +59,7 @@ namespace MathFunc
template<typename T1, typename T2>
AMREX_GPU_HOST_DEVICE T1 pow (const T1& val, const T2& power){
#ifdef AMREX_USE_GPU
- return thrust::pow(val, power);
+ return amrex::pow(val, power);
#else
return std::pow(val, power);
#endif