diff options
author | 2020-12-02 14:57:28 -0800 | |
---|---|---|
committer | 2020-12-02 14:57:28 -0800 | |
commit | 7fe38c8d13d9a511ab8f2c1243a4c58ca393c813 (patch) | |
tree | 306e64ca69c22077a5fdc9ff698c833431dfd1e1 /Python/pywarpx/fields.py | |
parent | 16c404ec6918e8264b1def78e0ba3969d96cafad (diff) | |
download | WarpX-7fe38c8d13d9a511ab8f2c1243a4c58ca393c813.tar.gz WarpX-7fe38c8d13d9a511ab8f2c1243a4c58ca393c813.tar.zst WarpX-7fe38c8d13d9a511ab8f2c1243a4c58ca393c813.zip |
Fix python wrapper (#1532)
* Fixes to the Python interface for accessing field and particle data
* In Python wrapper, cleaned up how ngrow is handled
Diffstat (limited to 'Python/pywarpx/fields.py')
-rw-r--r-- | Python/pywarpx/fields.py | 97 |
1 files changed, 49 insertions, 48 deletions
diff --git a/Python/pywarpx/fields.py b/Python/pywarpx/fields.py index eeed3283d..e36f4707a 100644 --- a/Python/pywarpx/fields.py +++ b/Python/pywarpx/fields.py @@ -45,26 +45,26 @@ class _MultiFABWrapper(object): self.dim = _libwarpx.dim # overlaps is one along the axes where the grid boundaries overlap the neighboring grid, - # which is the case with node centering + # which is the case with node centering. + # This presumably will never change during a calculation. self.overlaps = self.get_nodal_flag() def _getlovects(self): if self.direction is None: - lovects = self.get_lovects(self.level, self.include_ghosts) + lovects, ngrow = self.get_lovects(self.level, self.include_ghosts) else: - lovects = self.get_lovects(self.level, self.direction, self.include_ghosts) - self.nghosts = -lovects.min() - return lovects + lovects, ngrow = self.get_lovects(self.level, self.direction, self.include_ghosts) + return lovects, ngrow def _gethivects(self): - lovects = self._getlovects() + lovects, ngrow = self._getlovects() fields = self._getfields() hivects = np.zeros_like(lovects) for i in range(len(fields)): hivects[:,i] = lovects[:,i] + np.array(fields[i].shape[:self.dim]) - self.overlaps - return hivects + return hivects, ngrow def _getfields(self): if self.direction is None: @@ -73,7 +73,8 @@ class _MultiFABWrapper(object): return self.get_fabs(self.level, self.direction, self.include_ghosts) def __len__(self): - return lend(self._getlovects()) + lovects, ngrow = self._getlovects() + return lend(lovects) def mesh(self, direction): """Returns the mesh along the specified direction with the appropriate centering. @@ -96,8 +97,8 @@ class _MultiFABWrapper(object): raise Exception('Inappropriate direction given') # --- Get the total number of cells along the direction - hivects = self._gethivects() - nn = hivects[idir,:].max() - self.nghosts + self.overlaps[idir] + hivects, ngrow = self._gethivects() + nn = hivects[idir,:].max() - ngrow[idir] + self.overlaps[idir] if npes > 1: nn = comm_world.allreduce(nn, op=mpi.MAX) @@ -143,8 +144,8 @@ class _MultiFABWrapper(object): """Returns slices of a 3D decomposed array, """ - lovects = self._getlovects() - hivects = self._gethivects() + lovects, ngrow = self._getlovects() + hivects, ngrow = self._gethivects() fields = self._getfields() ix = index[0] @@ -164,9 +165,9 @@ class _MultiFABWrapper(object): else: ic = None - nx = hivects[0,:].max() - self.nghosts - ny = hivects[1,:].max() - self.nghosts - nz = hivects[2,:].max() - self.nghosts + nx = hivects[0,:].max() - ngrow[0] + ny = hivects[1,:].max() - ngrow[1] + nz = hivects[2,:].max() - ngrow[2] if npes > 1: nx = comm_world.allreduce(nx, op=mpi.MAX) @@ -174,20 +175,20 @@ class _MultiFABWrapper(object): nz = comm_world.allreduce(nz, op=mpi.MAX) if isinstance(ix, slice): - ixstart = max(ix.start or -self.nghosts, -self.nghosts) - ixstop = min(ix.stop or nx + 1 + self.nghosts, nx + self.overlaps[0] + self.nghosts) + ixstart = max(ix.start or -ngrow[0], -ngrow[0]) + ixstop = min(ix.stop or nx + 1 + ngrow[0], nx + self.overlaps[0] + ngrow[0]) else: ixstart = ix ixstop = ix + 1 if isinstance(iy, slice): - iystart = max(iy.start or -self.nghosts, -self.nghosts) - iystop = min(iy.stop or ny + 1 + self.nghosts, ny + self.overlaps[1] + self.nghosts) + iystart = max(iy.start or -ngrow[1], -ngrow[1]) + iystop = min(iy.stop or ny + 1 + ngrow[1], ny + self.overlaps[1] + ngrow[1]) else: iystart = iy iystop = iy + 1 if isinstance(iz, slice): - izstart = max(iz.start or -self.nghosts, -self.nghosts) - izstop = min(iz.stop or nz + 1 + self.nghosts, nz + self.overlaps[2] + self.nghosts) + izstart = max(iz.start or -ngrow[2], -ngrow[2]) + izstop = min(iz.stop or nz + 1 + ngrow[2], nz + self.overlaps[2] + ngrow[2]) else: izstart = iz izstop = iz + 1 @@ -250,8 +251,8 @@ class _MultiFABWrapper(object): """Returns slices of a 2D decomposed array, """ - lovects = self._getlovects() - hivects = self._gethivects() + lovects, ngrow = self._getlovects() + hivects, ngrow = self._gethivects() fields = self._getfields() ix = index[0] @@ -270,22 +271,22 @@ class _MultiFABWrapper(object): else: ic = None - nx = hivects[0,:].max() - self.nghosts - nz = hivects[1,:].max() - self.nghosts + nx = hivects[0,:].max() - ngrow[0] + nz = hivects[1,:].max() - ngrow[1] if npes > 1: nx = comm_world.allreduce(nx, op=mpi.MAX) nz = comm_world.allreduce(nz, op=mpi.MAX) if isinstance(ix, slice): - ixstart = max(ix.start or -self.nghosts, -self.nghosts) - ixstop = min(ix.stop or nx + 1 + self.nghosts, nx + self.overlaps[0] + self.nghosts) + ixstart = max(ix.start or -ngrow[0], -ngrow[0]) + ixstop = min(ix.stop or nx + 1 + ngrow[0], nx + self.overlaps[0] + ngrow[0]) else: ixstart = ix ixstop = ix + 1 if isinstance(iz, slice): - izstart = max(iz.start or -self.nghosts, -self.nghosts) - izstop = min(iz.stop or nz + 1 + self.nghosts, nz + self.overlaps[1] + self.nghosts) + izstart = max(iz.start or -ngrow[1], -ngrow[1]) + izstop = min(iz.stop or nz + 1 + ngrow[1], nz + self.overlaps[1] + ngrow[1]) else: izstart = iz izstop = iz + 1 @@ -365,8 +366,8 @@ class _MultiFABWrapper(object): iy = index[1] iz = index[2] - lovects = self._getlovects() - hivects = self._gethivects() + lovects, ngrow = self._getlovects() + hivects, ngrow = self._gethivects() fields = self._getfields() if len(index) > self.dim: @@ -377,9 +378,9 @@ class _MultiFABWrapper(object): else: ic = None - nx = hivects[0,:].max() - self.nghosts - ny = hivects[1,:].max() - self.nghosts - nz = hivects[2,:].max() - self.nghosts + nx = hivects[0,:].max() - ngrow[0] + ny = hivects[1,:].max() - ngrow[1] + nz = hivects[2,:].max() - ngrow[2] # --- Add extra dimensions so that the input has the same number of # --- dimensions as array. @@ -392,20 +393,20 @@ class _MultiFABWrapper(object): value3d.shape = sss if isinstance(ix, slice): - ixstart = max(ix.start or -self.nghosts, -self.nghosts) - ixstop = min(ix.stop or nx + 1 + self.nghosts, nx + self.overlaps[0] + self.nghosts) + ixstart = max(ix.start or -ngrow[0], -ngrow[0]) + ixstop = min(ix.stop or nx + 1 + ngrow[0], nx + self.overlaps[0] + ngrow[0]) else: ixstart = ix ixstop = ix + 1 if isinstance(iy, slice): - iystart = max(iy.start or -self.nghosts, -self.nghosts) - iystop = min(iy.stop or ny + 1 + self.nghosts, ny + self.overlaps[1] + self.nghosts) + iystart = max(iy.start or -ngrow[1], -ngrow[1]) + iystop = min(iy.stop or ny + 1 + ngrow[1], ny + self.overlaps[1] + ngrow[1]) else: iystart = iy iystop = iy + 1 if isinstance(iz, slice): - izstart = max(iz.start or -self.nghosts, -self.nghosts) - izstop = min(iz.stop or nz + 1 + self.nghosts, nz + self.overlaps[2] + self.nghosts) + izstart = max(iz.start or -ngrow[2], -ngrow[2]) + izstop = min(iz.stop or nz + 1 + ngrow[2], nz + self.overlaps[2] + ngrow[2]) else: izstart = iz izstop = iz + 1 @@ -442,8 +443,8 @@ class _MultiFABWrapper(object): ix = index[0] iz = index[2] - lovects = self._getlovects() - hivects = self._gethivects() + lovects, ngrow = self._getlovects() + hivects, ngrow = self._gethivects() fields = self._getfields() if len(fields[0].shape) > self.dim: @@ -459,8 +460,8 @@ class _MultiFABWrapper(object): else: ic = None - nx = hivects[0,:].max() - self.nghosts - nz = hivects[2,:].max() - self.nghosts + nx = hivects[0,:].max() - ngrow[0] + nz = hivects[2,:].max() - ngrow[1] # --- Add extra dimensions so that the input has the same number of # --- dimensions as array. @@ -472,14 +473,14 @@ class _MultiFABWrapper(object): value3d.shape = sss if isinstance(ix, slice): - ixstart = max(ix.start or -self.nghosts, -self.nghosts) - ixstop = min(ix.stop or nx + 1 + self.nghosts, nx + self.overlaps[0] + self.nghosts) + ixstart = max(ix.start or -ngrow[0], -ngrow[0]) + ixstop = min(ix.stop or nx + 1 + ngrow[0], nx + self.overlaps[0] + ngrow[0]) else: ixstart = ix ixstop = ix + 1 if isinstance(iz, slice): - izstart = max(iz.start or -self.nghosts, -self.nghosts) - izstop = min(iz.stop or nz + 1 + self.nghosts, nz + self.overlaps[2] + self.nghosts) + izstart = max(iz.start or -ngrow[1], -ngrow[1]) + izstop = min(iz.stop or nz + 1 + ngrow[1], nz + self.overlaps[2] + ngrow[1]) else: izstart = iz izstop = iz + 1 |