aboutsummaryrefslogtreecommitdiff
path: root/Python
diff options
context:
space:
mode:
Diffstat (limited to 'Python')
-rwxr-xr-xPython/pywarpx/_libwarpx.py40
1 files changed, 27 insertions, 13 deletions
diff --git a/Python/pywarpx/_libwarpx.py b/Python/pywarpx/_libwarpx.py
index 629cecd2a..65a4d3cd5 100755
--- a/Python/pywarpx/_libwarpx.py
+++ b/Python/pywarpx/_libwarpx.py
@@ -133,19 +133,22 @@ class LibWarpX():
if _ParticleReal_size == 8:
c_particlereal = ctypes.c_double
- _numpy_particlereal_dtype = 'f8'
+ self._numpy_particlereal_dtype = 'f8'
else:
c_particlereal = ctypes.c_float
- _numpy_particlereal_dtype = 'f4'
+ self._numpy_particlereal_dtype = 'f4'
self.dim = self.libwarpx_so.warpx_SpaceDim()
# our particle data type, depends on _ParticleReal_size
- _p_struct = [(d, _numpy_particlereal_dtype) for d in 'xyz'[:self.dim]] + [('id', 'i4'), ('cpu', 'i4')]
+ _p_struct = (
+ [(d, self._numpy_particlereal_dtype) for d in 'xyz'[:self.dim]]
+ + [('id', 'i4'), ('cpu', 'i4')]
+ )
self._p_dtype = np.dtype(_p_struct, align=True)
_numpy_to_ctypes = {}
- _numpy_to_ctypes[_numpy_particlereal_dtype] = c_particlereal
+ _numpy_to_ctypes[self._numpy_particlereal_dtype] = c_particlereal
_numpy_to_ctypes['i4'] = ctypes.c_int
class Particle(ctypes.Structure):
@@ -594,32 +597,43 @@ class LibWarpX():
# --- Broadcast scalars into appropriate length arrays
# --- If the parameter was not supplied, use the default value
if lenx == 1:
- x = np.full(maxlen, (x or 0.), float)
+ x = np.full(maxlen, (x or 0.), self._numpy_particlereal_dtype)
if leny == 1:
- y = np.full(maxlen, (y or 0.), float)
+ y = np.full(maxlen, (y or 0.), self._numpy_particlereal_dtype)
if lenz == 1:
- z = np.full(maxlen, (z or 0.), float)
+ z = np.full(maxlen, (z or 0.), self._numpy_particlereal_dtype)
if lenux == 1:
- ux = np.full(maxlen, (ux or 0.), float)
+ ux = np.full(maxlen, (ux or 0.), self._numpy_particlereal_dtype)
if lenuy == 1:
- uy = np.full(maxlen, (uy or 0.), float)
+ uy = np.full(maxlen, (uy or 0.), self._numpy_particlereal_dtype)
if lenuz == 1:
- uz = np.full(maxlen, (uz or 0.), float)
+ uz = np.full(maxlen, (uz or 0.), self._numpy_particlereal_dtype)
if lenw == 1:
- w = np.full(maxlen, (w or 0.), float)
+ w = np.full(maxlen, (w or 0.), self._numpy_particlereal_dtype)
for key, val in kwargs.items():
if np.size(val) == 1:
- kwargs[key] = np.full(maxlen, val, float)
+ kwargs[key] = np.full(
+ maxlen, val, self._numpy_particlereal_dtype
+ )
# --- The -3 is because the comps include the velocites
nattr = self.get_nattr_species(species_name) - 3
- attr = np.zeros((maxlen, nattr))
+ attr = np.zeros((maxlen, nattr), self._numpy_particlereal_dtype)
attr[:,0] = w
for key, vals in kwargs.items():
# --- The -3 is because components 1 to 3 are velocities
attr[:,self.get_particle_comp_index(species_name, key)-3] = vals
+ # Iff x/y/z/ux/uy/uz are not numpy arrays of the correct dtype, new
+ # array copies are made with the correct dtype
+ x = x.astype(self._numpy_particlereal_dtype, copy=False)
+ y = y.astype(self._numpy_particlereal_dtype, copy=False)
+ z = z.astype(self._numpy_particlereal_dtype, copy=False)
+ ux = ux.astype(self._numpy_particlereal_dtype, copy=False)
+ uy = uy.astype(self._numpy_particlereal_dtype, copy=False)
+ uz = uz.astype(self._numpy_particlereal_dtype, copy=False)
+
self.libwarpx_so.warpx_addNParticles(
ctypes.c_char_p(species_name.encode('utf-8')), x.size,
x, y, z, ux, uy, uz, nattr, attr, unique_particles