aboutsummaryrefslogtreecommitdiff
path: root/Source/Particles/Sorting
diff options
context:
space:
mode:
authorGravatar Remi Lehe <remi.lehe@normalesup.org> 2019-09-26 08:51:00 -0700
committerGravatar Remi Lehe <remi.lehe@normalesup.org> 2019-10-01 16:32:38 -0700
commit 8575478f370f3db2525a8d5444e6949776e9da29 (patch)
tree b9ff850e818f859ebb39027790822c4f46b6d06e /Source/Particles/Sorting
parent afde2a0f3cca52fde99ff59b72a6630cbea4c391 (diff)
downloadWarpX-8575478f370f3db2525a8d5444e6949776e9da29.tar.gz
WarpX-8575478f370f3db2525a8d5444e6949776e9da29.tar.zst
WarpX-8575478f370f3db2525a8d5444e6949776e9da29.zip
Implemented stable partition
Diffstat (limited to 'Source/Particles/Sorting')
-rw-r--r-- Source/Particles/Sorting/Partition.cpp 24
-rw-r--r-- Source/Particles/Sorting/SortingUtils.H 28
2 files changed, 42 insertions, 10 deletions
diff --git a/Source/Particles/Sorting/Partition.cpp b/Source/Particles/Sorting/Partition.cpp
index 7e247bd8c..3ee49cafe 100644
--- a/Source/Particles/Sorting/Partition.cpp
+++ b/Source/Particles/Sorting/Partition.cpp
@@ -43,24 +43,30 @@ PhysicalParticleContainer::PartitionParticlesInBuffers(
BL_PROFILE("PPC::Evolve::partition");
auto& aos = pti.GetArrayOfStructs();
+
+ // Initialize temporary arrays
Gpu::ManagedDeviceVector<int> inexflag;
inexflag.resize(np);
Gpu::ManagedDeviceVector<long> pid;
pid.resize(np);
- // Select the largest buffer first
+ // First, partition particles in the larger buffer
+
+ // - Select the larger buffer
iMultiFab const* bmasks =
(WarpX::n_field_gather_buffer >= WarpX::n_current_deposition_buffer) ?
gather_masks : current_masks;
-
- // For each particle, find whether it is in the large buffer,
- // by looking up the mask. Store the answer in `inexflag`
- amrex::ParallelFor( np, fillBufferFlag( pti, bmasks, inexflag, Geom(lev) ) );
-
- // Partition the particles according to whether they are in the large buffer or not
+ // - For each particle, find whether it is in the larger buffer,
+ // by looking up the mask. Store the answer in `inexflag`.
+ amrex::ParallelFor( np, fillBufferFlag(pti, bmasks, inexflag, Geom(lev)) );
+ // - Find the indices that reorder particles so that the last particles
+ // are in the larger buffer
fillWithConsecutiveIntegers( pid );
- auto sep = std::stable_partition(pid.begin(), pid.end(),
- [&inexflag](long id) { return inexflag[id]; });
+ auto sep = stablePartition( pid.begin(), pid.end(), inexflag );
+ // At the end of this step, `pid` contains the indices that should be used to
+ // reorder the particles, and `sep` is the position in the array that
+ // separates the particles that deposit/gather on the fine patch (first part)
+ // and the particles that deposit/gather in the buffers (last part)
if (WarpX::n_current_deposition_buffer == WarpX::n_field_gather_buffer) {
nfine_current = nfine_gather = std::distance(pid.begin(), sep);
diff --git a/Source/Particles/Sorting/SortingUtils.H b/Source/Particles/Sorting/SortingUtils.H
index 1133ccab5..ede59b53b 100644
--- a/Source/Particles/Sorting/SortingUtils.H
+++ b/Source/Particles/Sorting/SortingUtils.H
@@ -6,7 +6,8 @@
#include <AMReX_Gpu.H>
// TODO: Add documentation
-void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v ) {
+void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v )
+{
#ifdef AMREX_USE_GPU
// On GPU: Use thrust
thrust::sequence( v.begin(), v.end() );
@@ -17,6 +18,31 @@ void fillWithConsecutiveIntegers( amrex::Gpu::ManagedDeviceVector<long>& v ) {
}
// TODO: Add documentation
+/**
+ * \brief Stably partition a range of indices according to a flag array.
+ *
+ * Reorders the values in [index_begin, index_end) so that every value `i`
+ * for which `predicate[i]` is non-zero is placed before every value for
+ * which it is zero. The relative order inside each of the two groups is
+ * preserved (stable partition).
+ *
+ * \param index_begin Iterator to the first element of the range to reorder
+ * \param index_end   Iterator past the last element of the range to reorder
+ * \param predicate   Flag array; it is indexed with the values stored in the
+ *                    range, so each of them must be a valid position in it
+ * \return Iterator to the first element of the second group, i.e. the first
+ *         element whose flag is zero (`index_end` if there is none)
+ */
+template< typename ForwardIterator >
+ForwardIterator stablePartition(ForwardIterator index_begin,
+                                ForwardIterator index_end,
+                                amrex::Gpu::ManagedDeviceVector<int>& predicate)
+{
+#ifdef AMREX_USE_GPU
+    // On GPU: Use thrust
+    // The lambda captures a raw pointer (by value) rather than the vector,
+    // so that the flags can be read from device code.
+    int* AMREX_RESTRICT predicate_ptr = predicate.dataPtr();
+    // - `amrex::` is spelled out so that this header does not rely on a
+    //   `using namespace amrex;` being active in the including file.
+    // - The execution-space macro goes after the capture list, which is
+    //   where the CUDA extended-lambda syntax expects it.
+    ForwardIterator sep = thrust::stable_partition(
+        thrust::cuda::par(amrex::Cuda::The_ThrustCachedAllocator()),
+        index_begin, index_end,
+        [predicate_ptr] AMREX_GPU_HOST_DEVICE (long i) { return predicate_ptr[i]; }
+    );
+#else
+    // On CPU: Use std library
+    ForwardIterator sep = std::stable_partition(
+        index_begin, index_end,
+        [&predicate](long i) { return predicate[i]; }
+    );
+#endif
+    return sep;
+}
+
+// TODO: Add documentation
class fillBufferFlag
{
public: