aboutsummaryrefslogtreecommitdiff
path: root/Source/Particles/Sorting/SortingUtils.H
diff options
context:
space:
mode:
Diffstat (limited to 'Source/Particles/Sorting/SortingUtils.H')
-rw-r--r--Source/Particles/Sorting/SortingUtils.H28
1 files changed, 27 insertions, 1 deletions
diff --git a/Source/Particles/Sorting/SortingUtils.H b/Source/Particles/Sorting/SortingUtils.H
index 1133ccab5..ede59b53b 100644
--- a/Source/Particles/Sorting/SortingUtils.H
+++ b/Source/Particles/Sorting/SortingUtils.H
@@ -6,7 +6,8 @@
#include <AMReX_Gpu.H>
// TODO: Add documentation
-void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v ) {
+void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v )
+{
#ifdef AMREX_USE_GPU
// On GPU: Use thrust
thrust::sequence( v.begin(), v.end() );
@@ -17,6 +18,31 @@ void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v ) {
}
// TODO: Add documentation
+template< typename ForwardIterator >
+ForwardIterator stablePartition(ForwardIterator index_begin,
+ ForwardIterator index_end,
+ amrex::Gpu::ManagedDeviceVector<int>& predicate)
+{
+#ifdef AMREX_USE_GPU
+ // On GPU: Use thrust
+ int* AMREX_RESTRICT predicate_ptr = predicate.dataPtr();
+ ForwardIterator sep = thrust::stable_partition(
+ thrust::cuda::par(Cuda::The_ThrustCachedAllocator()),
+ index_begin, index_end,
+ AMREX_GPU_HOST_DEVICE
+ [predicate_ptr](long i) { return predicate_ptr[i]; }
+ );
+#else
+ // On CPU: Use std library
+ ForwardIterator sep = std::stable_partition(
+ index_begin, index_end,
+ [&predicate](long i) { return predicate[i]; }
+ );
+#endif
+ return sep;
+}
+
+// TODO: Add documentation
class fillBufferFlag
{
public: