aboutsummaryrefslogtreecommitdiff
path: root/Python
diff options
context:
space:
mode:
authorGravatar Dave Grote <grote1@llnl.gov> 2019-10-01 15:22:16 -0700
committerGravatar Dave Grote <grote1@llnl.gov> 2019-10-01 15:22:16 -0700
commitd1b89179d89adc26037816afcf28ca2a61816a81 (patch)
tree9eb26d046b313c2bb88615ec7e4eda968387c5fa /Python
parent81e4a0f3407af2aad99cacaaba483de80189d198 (diff)
downloadWarpX-d1b89179d89adc26037816afcf28ca2a61816a81.tar.gz
WarpX-d1b89179d89adc26037816afcf28ca2a61816a81.tar.zst
WarpX-d1b89179d89adc26037816afcf28ca2a61816a81.zip
Additional fix for single precision in Python interface
Diffstat (limited to 'Python')
-rwxr-xr-xPython/pywarpx/_libwarpx.py8
-rw-r--r--Python/pywarpx/fields.py4
2 files changed, 8 insertions, 4 deletions
diff --git a/Python/pywarpx/_libwarpx.py b/Python/pywarpx/_libwarpx.py
index cfb22257d..6c3e54619 100755
--- a/Python/pywarpx/_libwarpx.py
+++ b/Python/pywarpx/_libwarpx.py
@@ -66,22 +66,26 @@ _ParticleReal_size = libwarpx.warpx_ParticleReal_size()
if _Real_size == 8:
c_real = ctypes.c_double
+ _numpy_real_dtype = 'f8'
else:
c_real = ctypes.c_float
+ _numpy_real_dtype = 'f4'
if _ParticleReal_size == 8:
c_particlereal = ctypes.c_double
+ _numpy_particlereal_dtype = 'f8'
else:
c_particlereal = ctypes.c_float
+ _numpy_particlereal_dtype = 'f4'
dim = libwarpx.warpx_SpaceDim()
# our particle data type, depends on _ParticleReal_size
-_p_struct = [(d, 'f%d'%_ParticleReal_size) for d in 'xyz'[:dim]] + [('id', 'i4'), ('cpu', 'i4')]
+_p_struct = [(d, _numpy_particlereal_dtype) for d in 'xyz'[:dim]] + [('id', 'i4'), ('cpu', 'i4')]
_p_dtype = np.dtype(_p_struct, align=True)
_numpy_to_ctypes = {}
-_numpy_to_ctypes['f%d'%_ParticleReal_size] = c_particlereal
+_numpy_to_ctypes[_numpy_particlereal_dtype] = c_particlereal
_numpy_to_ctypes['i4'] = ctypes.c_int
class Particle(ctypes.Structure):
diff --git a/Python/pywarpx/fields.py b/Python/pywarpx/fields.py
index 8b7283a46..fa247db07 100644
--- a/Python/pywarpx/fields.py
+++ b/Python/pywarpx/fields.py
@@ -127,7 +127,7 @@ class _MultiFABWrapper(object):
sss = (max(0, ixstop - ixstart),
max(0, iystop - iystart),
max(0, izstop - izstart))
- resultglobal = np.zeros(sss)
+ resultglobal = np.zeros(sss, dtype=_libwarpx._numpy_real_dtype)
datalist = []
for i in range(len(fields)):
@@ -222,7 +222,7 @@ class _MultiFABWrapper(object):
max(0, izstop - izstart))
if ncomps > 1 and ic is None:
sss = tuple(list(sss) + [ncomps])
- resultglobal = np.zeros(sss)
+ resultglobal = np.zeros(sss, dtype=_libwarpx._numpy_real_dtype)
datalist = []
for i in range(len(fields)):