diff options
author | 2019-10-01 15:22:16 -0700 | |
---|---|---|
committer | 2019-10-01 15:22:16 -0700 | |
commit | d1b89179d89adc26037816afcf28ca2a61816a81 (patch) | |
tree | 9eb26d046b313c2bb88615ec7e4eda968387c5fa /Python | |
parent | 81e4a0f3407af2aad99cacaaba483de80189d198 (diff) | |
download | WarpX-d1b89179d89adc26037816afcf28ca2a61816a81.tar.gz WarpX-d1b89179d89adc26037816afcf28ca2a61816a81.tar.zst WarpX-d1b89179d89adc26037816afcf28ca2a61816a81.zip |
Additional fix for single precision in Python interface
Diffstat (limited to 'Python')
-rwxr-xr-x | Python/pywarpx/_libwarpx.py | 8 | ||||
-rw-r--r-- | Python/pywarpx/fields.py | 4 |
2 files changed, 8 insertions, 4 deletions
diff --git a/Python/pywarpx/_libwarpx.py b/Python/pywarpx/_libwarpx.py index cfb22257d..6c3e54619 100755 --- a/Python/pywarpx/_libwarpx.py +++ b/Python/pywarpx/_libwarpx.py @@ -66,22 +66,26 @@ _ParticleReal_size = libwarpx.warpx_ParticleReal_size() if _Real_size == 8: c_real = ctypes.c_double + _numpy_real_dtype = 'f8' else: c_real = ctypes.c_float + _numpy_real_dtype = 'f4' if _ParticleReal_size == 8: c_particlereal = ctypes.c_double + _numpy_particlereal_dtype = 'f8' else: c_particlereal = ctypes.c_float + _numpy_particlereal_dtype = 'f4' dim = libwarpx.warpx_SpaceDim() # our particle data type, depends on _ParticleReal_size -_p_struct = [(d, 'f%d'%_ParticleReal_size) for d in 'xyz'[:dim]] + [('id', 'i4'), ('cpu', 'i4')] +_p_struct = [(d, _numpy_particlereal_dtype) for d in 'xyz'[:dim]] + [('id', 'i4'), ('cpu', 'i4')] _p_dtype = np.dtype(_p_struct, align=True) _numpy_to_ctypes = {} -_numpy_to_ctypes['f%d'%_ParticleReal_size] = c_particlereal +_numpy_to_ctypes[_numpy_particlereal_dtype] = c_particlereal _numpy_to_ctypes['i4'] = ctypes.c_int class Particle(ctypes.Structure): diff --git a/Python/pywarpx/fields.py b/Python/pywarpx/fields.py index 8b7283a46..fa247db07 100644 --- a/Python/pywarpx/fields.py +++ b/Python/pywarpx/fields.py @@ -127,7 +127,7 @@ class _MultiFABWrapper(object): sss = (max(0, ixstop - ixstart), max(0, iystop - iystart), max(0, izstop - izstart)) - resultglobal = np.zeros(sss) + resultglobal = np.zeros(sss, dtype=_libwarpx._numpy_real_dtype) datalist = [] for i in range(len(fields)): @@ -222,7 +222,7 @@ class _MultiFABWrapper(object): max(0, izstop - izstart)) if ncomps > 1 and ic is None: sss = tuple(list(sss) + [ncomps]) - resultglobal = np.zeros(sss) + resultglobal = np.zeros(sss, dtype=_libwarpx._numpy_real_dtype) datalist = [] for i in range(len(fields)): |