diff options
Diffstat (limited to 'Python/pywarpx/fields.py')
-rw-r--r-- | Python/pywarpx/fields.py | 40 |
1 files changed, 38 insertions, 2 deletions
diff --git a/Python/pywarpx/fields.py b/Python/pywarpx/fields.py index 9068a22ea..fcb04f865 100644 --- a/Python/pywarpx/fields.py +++ b/Python/pywarpx/fields.py @@ -8,6 +8,13 @@ JxWrapper, JyWrapper, JzWrapper """ import numpy as np +try: + from mpi4py import MPI as mpi + comm_world = mpi.COMM_WORLD + npes = comm_world.Get_size() +except ImportError: + npes = 1 + from . import _libwarpx @@ -92,6 +99,11 @@ class _MultiFABWrapper(object): ny = hivects[1,:].max() - self.nghosts nz = hivects[2,:].max() - self.nghosts + if npes > 1: + nx = comm_world.allreduce(nx, op=mpi.MAX) + ny = comm_world.allreduce(ny, op=mpi.MAX) + nz = comm_world.allreduce(nz, op=mpi.MAX) + if isinstance(ix, slice): ixstart = max(ix.start or -self.nghosts, -self.nghosts) ixstop = min(ix.stop or nx + 1 + self.nghosts, nx + self.overlaps[0] + self.nghosts) @@ -117,6 +129,7 @@ class _MultiFABWrapper(object): max(0, izstop - izstart)) resultglobal = np.zeros(sss) + datalist = [] for i in range(len(fields)): # --- The ix1, 2 etc are relative to global indexing @@ -137,7 +150,16 @@ class _MultiFABWrapper(object): slice(iy1 - iystart, iy2 - iystart), slice(iz1 - izstart, iz2 - izstart)) - resultglobal[vslice] = fields[i][sss] + datalist.append((vslice, fields[i][sss])) + + if npes == 1: + all_datalist = [datalist] + else: + all_datalist = comm_world.allgather(datalist) + + for datalist in all_datalist: + for vslice, ff in datalist: + resultglobal[vslice] = ff # --- Now remove any of the reduced dimensions. sss = [slice(None), slice(None), slice(None)] @@ -177,6 +199,10 @@ class _MultiFABWrapper(object): nx = hivects[0,:].max() - self.nghosts nz = hivects[1,:].max() - self.nghosts + if npes > 1: + nx = comm_world.allreduce(nx, op=mpi.MAX) + nz = comm_world.allreduce(nz, op=mpi.MAX) + if isinstance(ix, slice): ixstart = max(ix.start or -self.nghosts, -self.nghosts) ixstop = min(ix.stop or nx + 1 + self.nghosts, nx + self.overlaps[0] + self.nghosts) @@ -198,6 +224,7 @@ class _MultiFABWrapper(object): sss = tuple(list(sss) + [ncomps]) resultglobal = np.zeros(sss) + datalist = [] for i in range(len(fields)): # --- The ix1, 2 etc are relative to global indexing @@ -216,7 +243,16 @@ class _MultiFABWrapper(object): vslice = (slice(ix1 - ixstart, ix2 - ixstart), slice(iz1 - izstart, iz2 - izstart)) - resultglobal[vslice] = fields[i][sss] + datalist.append((vslice, fields[i][sss])) + + if npes == 1: + all_datalist = [datalist] + else: + all_datalist = comm_world.allgather(datalist) + + for datalist in all_datalist: + for vslice, ff in datalist: + resultglobal[vslice] = ff # --- Now remove any of the reduced dimensions. sss = [slice(None), slice(None)] |