aboutsummaryrefslogtreecommitdiff
path: root/Python
diff options
context:
space:
mode:
authorGravatar Dave Grote <grote1@llnl.gov> 2019-12-12 12:54:54 -0800
committerGravatar Dave Grote <grote1@llnl.gov> 2019-12-12 12:54:54 -0800
commit76d61b6cd701d1fd283b6e3dae2fde729f875868 (patch)
tree99ee3589b1b3528ad8dba49d81989f937df9a858 /Python
parent84892de1590e0d856cbe19b9e269db8cae8ac9a2 (diff)
downloadWarpX-76d61b6cd701d1fd283b6e3dae2fde729f875868.tar.gz
WarpX-76d61b6cd701d1fd283b6e3dae2fde729f875868.tar.zst
WarpX-76d61b6cd701d1fd283b6e3dae2fde729f875868.zip
In Python fields wrapper, allowed multiple compoments in 3D
Diffstat (limited to 'Python')
-rw-r--r--Python/pywarpx/fields.py35
1 files changed, 32 insertions, 3 deletions
diff --git a/Python/pywarpx/fields.py b/Python/pywarpx/fields.py
index fa247db07..225725331 100644
--- a/Python/pywarpx/fields.py
+++ b/Python/pywarpx/fields.py
@@ -87,14 +87,28 @@ class _MultiFABWrapper(object):
def _getitem3d(self, index):
"""Returns slices of a 3D decomposed array,
"""
- ix = index[0]
- iy = index[1]
- iz = index[2]
lovects = self._getlovects()
hivects = self._gethivects()
fields = self._getfields()
+ ix = index[0]
+ iy = index[1]
+ iz = index[2]
+
+ if len(fields[0].shape) > self.dim:
+ ncomps = fields[0].shape[-1]
+ else:
+ ncomps = 1
+
+ if len(index) > self.dim:
+ if ncomps > 1:
+ ic = index[-1]
+ else:
+ raise Exception('Too many indices given')
+ else:
+ ic = None
+
nx = hivects[0,:].max() - self.nghosts
ny = hivects[1,:].max() - self.nghosts
nz = hivects[2,:].max() - self.nghosts
@@ -124,9 +138,12 @@ class _MultiFABWrapper(object):
izstop = iz + 1
# --- Setup the size of the array to be returned and create it.
+ # --- Space is added for multiple components if needed.
sss = (max(0, ixstop - ixstart),
max(0, iystop - iystart),
max(0, izstop - izstart))
+ if ncomps > 1 and ic is None:
+ sss = tuple(list(sss) + [ncomps])
resultglobal = np.zeros(sss, dtype=_libwarpx._numpy_real_dtype)
datalist = []
@@ -145,6 +162,8 @@ class _MultiFABWrapper(object):
sss = (slice(ix1 - lovects[0,i], ix2 - lovects[0,i]),
slice(iy1 - lovects[1,i], iy2 - lovects[1,i]),
slice(iz1 - lovects[2,i], iz2 - lovects[2,i]))
+ if ic is not None:
+ sss = tuple(list(sss) + [ic])
vslice = (slice(ix1 - ixstart, ix2 - ixstart),
slice(iy1 - iystart, iy2 - iystart),
@@ -295,6 +314,14 @@ class _MultiFABWrapper(object):
hivects = self._gethivects()
fields = self._getfields()
+ if len(index) > self.dim:
+ if ncomps > 1:
+ ic = index[-1]
+ else:
+ raise Exception('Too many indices given')
+ else:
+ ic = None
+
nx = hivects[0,:].max() - self.nghosts
ny = hivects[1,:].max() - self.nghosts
nz = hivects[2,:].max() - self.nghosts
@@ -343,6 +370,8 @@ class _MultiFABWrapper(object):
sss = (slice(ix1 - lovects[0,i], ix2 - lovects[0,i]),
slice(iy1 - lovects[1,i], iy2 - lovects[1,i]),
slice(iz1 - lovects[2,i], iz2 - lovects[2,i]))
+ if ic is not None:
+ sss = tuple(list(sss) + [ic])
if isinstance(value, np.ndarray):
vslice = (slice(ix1 - ixstart, ix2 - ixstart),