aboutsummaryrefslogtreecommitdiff
path: root/Source/Particles/Sorting/SortingUtils.H
diff options
context:
space:
mode:
Diffstat (limited to 'Source/Particles/Sorting/SortingUtils.H')
-rw-r--r--Source/Particles/Sorting/SortingUtils.H33
1 files changed, 14 insertions, 19 deletions
diff --git a/Source/Particles/Sorting/SortingUtils.H b/Source/Particles/Sorting/SortingUtils.H
index 35bc059aa..f425c6c7b 100644
--- a/Source/Particles/Sorting/SortingUtils.H
+++ b/Source/Particles/Sorting/SortingUtils.H
@@ -3,10 +3,7 @@
#include <WarpXParticleContainer.H>
#include <AMReX_Gpu.H>
-#ifdef AMREX_USE_GPU
- #include <thrust/partition.h>
- #include <thrust/distance.h>
-#endif
+#include <AMReX_Partition.H>
/** \brief Fill the elements of the input vector with consecutive integer,
* starting from 0
@@ -16,8 +13,10 @@
void fillWithConsecutiveIntegers( amrex::Gpu::DeviceVector<long>& v )
{
#ifdef AMREX_USE_GPU
- // On GPU: Use thrust
- thrust::sequence( v.begin(), v.end() );
+ // On GPU: Use amrex
+ auto data = v.data();
+ auto N = v.size();
+ AMREX_FOR_1D( N, i, data[i] = i;);
#else
// On CPU: Use std library
std::iota( v.begin(), v.end(), 0L );
@@ -35,17 +34,18 @@ void fillWithConsecutiveIntegers( amrex::Gpu::DeviceVector<long>& v )
*/
template< typename ForwardIterator >
ForwardIterator stablePartition(ForwardIterator const index_begin,
- ForwardIterator const index_end,
- amrex::Gpu::DeviceVector<int> const& predicate)
+ ForwardIterator const index_end,
+ amrex::Gpu::DeviceVector<int> const& predicate)
{
#ifdef AMREX_USE_GPU
- // On GPU: Use thrust
+ // On GPU: Use amrex
int const* AMREX_RESTRICT predicate_ptr = predicate.dataPtr();
- ForwardIterator const sep = thrust::stable_partition(
- thrust::cuda::par(amrex::Gpu::The_ThrustCachedAllocator()),
- index_begin, index_end,
- [predicate_ptr] AMREX_GPU_DEVICE (long i) { return predicate_ptr[i]; }
- );
+ int N = static_cast<int>(std::distance(index_begin, index_end));
+ auto num_true = amrex::StablePartition(&(*index_begin), N,
+ [predicate_ptr] AMREX_GPU_DEVICE (long i) { return predicate_ptr[i]; });
+
+ ForwardIterator sep = index_begin;
+ std::advance(sep, num_true);
#else
// On CPU: Use std library
ForwardIterator const sep = std::stable_partition(
@@ -66,12 +66,7 @@ template< typename ForwardIterator >
int iteratorDistance(ForwardIterator const first,
ForwardIterator const last)
{
-#ifdef AMREX_USE_GPU
- // On GPU: Use thrust
- return thrust::distance( first, last );
-#else
return std::distance( first, last );
-#endif
}
/** \brief Functor that fills the elements of the particle array `inexflag`