diff options
Diffstat (limited to 'Python')
-rwxr-xr-x | Python/pywarpx/_libwarpx.py | 40 |
1 files changed, 27 insertions, 13 deletions
diff --git a/Python/pywarpx/_libwarpx.py b/Python/pywarpx/_libwarpx.py index 629cecd2a..65a4d3cd5 100755 --- a/Python/pywarpx/_libwarpx.py +++ b/Python/pywarpx/_libwarpx.py @@ -133,19 +133,22 @@ class LibWarpX(): if _ParticleReal_size == 8: c_particlereal = ctypes.c_double - _numpy_particlereal_dtype = 'f8' + self._numpy_particlereal_dtype = 'f8' else: c_particlereal = ctypes.c_float - _numpy_particlereal_dtype = 'f4' + self._numpy_particlereal_dtype = 'f4' self.dim = self.libwarpx_so.warpx_SpaceDim() # our particle data type, depends on _ParticleReal_size - _p_struct = [(d, _numpy_particlereal_dtype) for d in 'xyz'[:self.dim]] + [('id', 'i4'), ('cpu', 'i4')] + _p_struct = ( + [(d, self._numpy_particlereal_dtype) for d in 'xyz'[:self.dim]] + + [('id', 'i4'), ('cpu', 'i4')] + ) self._p_dtype = np.dtype(_p_struct, align=True) _numpy_to_ctypes = {} - _numpy_to_ctypes[_numpy_particlereal_dtype] = c_particlereal + _numpy_to_ctypes[self._numpy_particlereal_dtype] = c_particlereal _numpy_to_ctypes['i4'] = ctypes.c_int class Particle(ctypes.Structure): @@ -594,32 +597,43 @@ class LibWarpX(): # --- Broadcast scalars into appropriate length arrays # --- If the parameter was not supplied, use the default value if lenx == 1: - x = np.full(maxlen, (x or 0.), float) + x = np.full(maxlen, (x or 0.), self._numpy_particlereal_dtype) if leny == 1: - y = np.full(maxlen, (y or 0.), float) + y = np.full(maxlen, (y or 0.), self._numpy_particlereal_dtype) if lenz == 1: - z = np.full(maxlen, (z or 0.), float) + z = np.full(maxlen, (z or 0.), self._numpy_particlereal_dtype) if lenux == 1: - ux = np.full(maxlen, (ux or 0.), float) + ux = np.full(maxlen, (ux or 0.), self._numpy_particlereal_dtype) if lenuy == 1: - uy = np.full(maxlen, (uy or 0.), float) + uy = np.full(maxlen, (uy or 0.), self._numpy_particlereal_dtype) if lenuz == 1: - uz = np.full(maxlen, (uz or 0.), float) + uz = np.full(maxlen, (uz or 0.), self._numpy_particlereal_dtype) if lenw == 1: - w = np.full(maxlen, (w or 0.), float) + w = np.full(maxlen, (w or 0.), self._numpy_particlereal_dtype) for key, val in kwargs.items(): if np.size(val) == 1: - kwargs[key] = np.full(maxlen, val, float) + kwargs[key] = np.full( + maxlen, val, self._numpy_particlereal_dtype + ) # --- The -3 is because the comps include the velocites nattr = self.get_nattr_species(species_name) - 3 - attr = np.zeros((maxlen, nattr)) + attr = np.zeros((maxlen, nattr), self._numpy_particlereal_dtype) attr[:,0] = w for key, vals in kwargs.items(): # --- The -3 is because components 1 to 3 are velocities attr[:,self.get_particle_comp_index(species_name, key)-3] = vals + # Iff x/y/z/ux/uy/uz are not numpy arrays of the correct dtype, new + # array copies are made with the correct dtype + x = x.astype(self._numpy_particlereal_dtype, copy=False) + y = y.astype(self._numpy_particlereal_dtype, copy=False) + z = z.astype(self._numpy_particlereal_dtype, copy=False) + ux = ux.astype(self._numpy_particlereal_dtype, copy=False) + uy = uy.astype(self._numpy_particlereal_dtype, copy=False) + uz = uz.astype(self._numpy_particlereal_dtype, copy=False) + self.libwarpx_so.warpx_addNParticles( ctypes.c_char_p(species_name.encode('utf-8')), x.size, x, y, z, ux, uy, uz, nattr, attr, unique_particles |