diff options
Diffstat (limited to 'Source/Particles/Sorting/SortingUtils.H')
-rw-r--r-- | Source/Particles/Sorting/SortingUtils.H | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/Source/Particles/Sorting/SortingUtils.H b/Source/Particles/Sorting/SortingUtils.H index 1133ccab5..ede59b53b 100644 --- a/Source/Particles/Sorting/SortingUtils.H +++ b/Source/Particles/Sorting/SortingUtils.H @@ -6,7 +6,8 @@ #include <AMReX_Gpu.H> // TODO: Add documentation -void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v ) { +void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v ) +{ #ifdef AMREX_USE_GPU // On GPU: Use thrust thrust::sequence( v.begin(), v.end() ); @@ -17,6 +18,31 @@ void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v ) { } // TODO: Add documentation +template< typename ForwardIterator > +ForwardIterator stablePartition(ForwardIterator index_begin, + ForwardIterator index_end, + amrex::Gpu::ManagedDeviceVector<int>& predicate) +{ +#ifdef AMREX_USE_GPU + // On GPU: Use thrust + int* AMREX_RESTRICT predicate_ptr = predicate.dataPtr(); + ForwardIterator sep = thrust::stable_partition( + thrust::cuda::par(Cuda::The_ThrustCachedAllocator()), + index_begin, index_end, + AMREX_GPU_HOST_DEVICE + [predicate_ptr](long i) { return predicate_ptr[i]; } + ); +#else + // On CPU: Use std library + ForwardIterator sep = std::stable_partition( + index_begin, index_end, + [&predicate](long i) { return predicate[i]; } + ); +#endif + return sep; +} + +// TODO: Add documentation class fillBufferFlag { public: |