aboutsummaryrefslogtreecommitdiff
path: root/Source/Particles/Sorting/SortingUtils.H
diff options
context:
space:
mode:
authorGravatar Remi Lehe <remi.lehe@normalesup.org> 2019-09-26 08:51:00 -0700
committerGravatar Remi Lehe <remi.lehe@normalesup.org> 2019-10-01 16:32:38 -0700
commit8575478f370f3db2525a8d5444e6949776e9da29 (patch)
treeb9ff850e818f859ebb39027790822c4f46b6d06e /Source/Particles/Sorting/SortingUtils.H
parentafde2a0f3cca52fde99ff59b72a6630cbea4c391 (diff)
downloadWarpX-8575478f370f3db2525a8d5444e6949776e9da29.tar.gz
WarpX-8575478f370f3db2525a8d5444e6949776e9da29.tar.zst
WarpX-8575478f370f3db2525a8d5444e6949776e9da29.zip
Implemented stable partition
Diffstat (limited to 'Source/Particles/Sorting/SortingUtils.H')
-rw-r--r--Source/Particles/Sorting/SortingUtils.H28
1 files changed, 27 insertions, 1 deletions
diff --git a/Source/Particles/Sorting/SortingUtils.H b/Source/Particles/Sorting/SortingUtils.H
index 1133ccab5..ede59b53b 100644
--- a/Source/Particles/Sorting/SortingUtils.H
+++ b/Source/Particles/Sorting/SortingUtils.H
@@ -6,7 +6,8 @@
#include <AMReX_Gpu.H>
// TODO: Add documentation
-void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v ) {
+void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v )
+{
#ifdef AMREX_USE_GPU
// On GPU: Use thrust
thrust::sequence( v.begin(), v.end() );
@@ -17,6 +18,31 @@ void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v ) {
}
// TODO: Add documentation
+template< typename ForwardIterator >
+ForwardIterator stablePartition(ForwardIterator index_begin,
+ ForwardIterator index_end,
+ amrex::Gpu::ManagedDeviceVector<int>& predicate)
+{
+#ifdef AMREX_USE_GPU
+ // On GPU: Use thrust
+ int* AMREX_RESTRICT predicate_ptr = predicate.dataPtr();
+ ForwardIterator sep = thrust::stable_partition(
+ thrust::cuda::par(Cuda::The_ThrustCachedAllocator()),
+ index_begin, index_end,
+ AMREX_GPU_HOST_DEVICE
+ [predicate_ptr](long i) { return predicate_ptr[i]; }
+ );
+#else
+ // On CPU: Use std library
+ ForwardIterator sep = std::stable_partition(
+ index_begin, index_end,
+ [&predicate](long i) { return predicate[i]; }
+ );
+#endif
+ return sep;
+}
+
+// TODO: Add documentation
class fillBufferFlag
{
public: