aboutsummaryrefslogtreecommitdiff
path: root/Python/pywarpx/fields.py
diff options
context:
space:
mode:
authorGravatar WeiqunZhang <WeiqunZhang@lbl.gov> 2019-09-05 09:47:33 -0700
committerGravatar GitHub <noreply@github.com> 2019-09-05 09:47:33 -0700
commit5736e73df40b0251a63e1910a0366babee6f8a14 (patch)
treed010370b1a5d11d4f43d6ed90dc1d2479fd44336 /Python/pywarpx/fields.py
parentb93dce4d62b9a3014641f0df551fd7de453126ba (diff)
parentc4fdd59cf048d4563a0d8f1397f226b83fa4ba0a (diff)
downloadWarpX-5736e73df40b0251a63e1910a0366babee6f8a14.tar.gz
WarpX-5736e73df40b0251a63e1910a0366babee6f8a14.tar.zst
WarpX-5736e73df40b0251a63e1910a0366babee6f8a14.zip
Merge pull request #311 from ECP-WarpX/python_update
Python update
Diffstat (limited to 'Python/pywarpx/fields.py')
-rw-r--r--Python/pywarpx/fields.py40
1 files changed, 38 insertions, 2 deletions
diff --git a/Python/pywarpx/fields.py b/Python/pywarpx/fields.py
index 9068a22ea..fcb04f865 100644
--- a/Python/pywarpx/fields.py
+++ b/Python/pywarpx/fields.py
@@ -8,6 +8,13 @@ JxWrapper, JyWrapper, JzWrapper
"""
import numpy as np
+try:
+ from mpi4py import MPI as mpi
+ comm_world = mpi.COMM_WORLD
+ npes = comm_world.Get_size()
+except ImportError:
+ npes = 1
+
from . import _libwarpx
@@ -92,6 +99,11 @@ class _MultiFABWrapper(object):
ny = hivects[1,:].max() - self.nghosts
nz = hivects[2,:].max() - self.nghosts
+ if npes > 1:
+ nx = comm_world.allreduce(nx, op=mpi.MAX)
+ ny = comm_world.allreduce(ny, op=mpi.MAX)
+ nz = comm_world.allreduce(nz, op=mpi.MAX)
+
if isinstance(ix, slice):
ixstart = max(ix.start or -self.nghosts, -self.nghosts)
ixstop = min(ix.stop or nx + 1 + self.nghosts, nx + self.overlaps[0] + self.nghosts)
@@ -117,6 +129,7 @@ class _MultiFABWrapper(object):
max(0, izstop - izstart))
resultglobal = np.zeros(sss)
+ datalist = []
for i in range(len(fields)):
# --- The ix1, 2 etc are relative to global indexing
@@ -137,7 +150,16 @@ class _MultiFABWrapper(object):
slice(iy1 - iystart, iy2 - iystart),
slice(iz1 - izstart, iz2 - izstart))
- resultglobal[vslice] = fields[i][sss]
+ datalist.append((vslice, fields[i][sss]))
+
+ if npes == 1:
+ all_datalist = [datalist]
+ else:
+ all_datalist = comm_world.allgather(datalist)
+
+ for datalist in all_datalist:
+ for vslice, ff in datalist:
+ resultglobal[vslice] = ff
# --- Now remove any of the reduced dimensions.
sss = [slice(None), slice(None), slice(None)]
@@ -177,6 +199,10 @@ class _MultiFABWrapper(object):
nx = hivects[0,:].max() - self.nghosts
nz = hivects[1,:].max() - self.nghosts
+ if npes > 1:
+ nx = comm_world.allreduce(nx, op=mpi.MAX)
+ nz = comm_world.allreduce(nz, op=mpi.MAX)
+
if isinstance(ix, slice):
ixstart = max(ix.start or -self.nghosts, -self.nghosts)
ixstop = min(ix.stop or nx + 1 + self.nghosts, nx + self.overlaps[0] + self.nghosts)
@@ -198,6 +224,7 @@ class _MultiFABWrapper(object):
sss = tuple(list(sss) + [ncomps])
resultglobal = np.zeros(sss)
+ datalist = []
for i in range(len(fields)):
# --- The ix1, 2 etc are relative to global indexing
@@ -216,7 +243,16 @@ class _MultiFABWrapper(object):
vslice = (slice(ix1 - ixstart, ix2 - ixstart),
slice(iz1 - izstart, iz2 - izstart))
- resultglobal[vslice] = fields[i][sss]
+ datalist.append((vslice, fields[i][sss]))
+
+ if npes == 1:
+ all_datalist = [datalist]
+ else:
+ all_datalist = comm_world.allgather(datalist)
+
+ for datalist in all_datalist:
+ for vslice, ff in datalist:
+ resultglobal[vslice] = ff
# --- Now remove any of the reduced dimensions.
sss = [slice(None), slice(None)]