aboutsummaryrefslogtreecommitdiff
path: root/Python
diff options
context:
space:
mode:
Diffstat (limited to 'Python')
-rw-r--r--Python/pywarpx/WarpX.py4
-rwxr-xr-xPython/pywarpx/_libwarpx.py26
-rw-r--r--Python/pywarpx/picmi.py8
3 files changed, 26 insertions, 12 deletions
diff --git a/Python/pywarpx/WarpX.py b/Python/pywarpx/WarpX.py
index dc09ba241..73106cf62 100644
--- a/Python/pywarpx/WarpX.py
+++ b/Python/pywarpx/WarpX.py
@@ -68,10 +68,10 @@ class WarpX(Bucket):
return argv
- def init(self):
+ def init(self, mpi_comm=None):
from . import wx
argv = ['warpx'] + self.create_argv_list()
- wx.initialize(argv)
+ wx.initialize(argv, mpi_comm=mpi_comm)
def evolve(self, nsteps=-1):
from . import wx
diff --git a/Python/pywarpx/_libwarpx.py b/Python/pywarpx/_libwarpx.py
index d51bdfa9d..f1cc6564d 100755
--- a/Python/pywarpx/_libwarpx.py
+++ b/Python/pywarpx/_libwarpx.py
@@ -18,10 +18,16 @@ from .Geometry import geometry
try:
# --- If mpi4py is going to be used, this needs to be imported
- # --- before libwarpx is loaded (though don't know why)
+ # --- before libwarpx is loaded, because mpi4py calls MPI_Init
from mpi4py import MPI
+ # --- Change MPI Comm type depending on MPICH (int) or OpenMPI (void*)
+ if MPI._sizeof(MPI.Comm) == ctypes.sizeof(ctypes.c_int):
+ _MPI_Comm_type = ctypes.c_int
+ else:
+ _MPI_Comm_type = ctypes.c_void_p
except ImportError:
- pass
+ MPI = None
+ _MPI_Comm_type = ctypes.c_void_p
# --- Is there a better way of handling constants?
clight = 2.99792458e+8 # m/s
@@ -136,6 +142,7 @@ def _array1d_from_pointer(pointer, dtype, size):
# set the arg and return types of the wrapped functions
libwarpx.amrex_init.argtypes = (ctypes.c_int, _LP_LP_c_char)
+libwarpx.amrex_init_with_inited_mpi.argtypes = (ctypes.c_int, _LP_LP_c_char, _MPI_Comm_type)
libwarpx.warpx_getParticleStructs.restype = _LP_particle_p
libwarpx.warpx_getParticleArrays.restype = _LP_LP_c_particlereal
libwarpx.warpx_getEfield.restype = _LP_LP_c_real
@@ -208,6 +215,8 @@ libwarpx.warpx_getdt.restype = c_real
libwarpx.warpx_maxStep.restype = ctypes.c_int
libwarpx.warpx_stopTime.restype = c_real
libwarpx.warpx_finestLevel.restype = ctypes.c_int
+libwarpx.warpx_getMyProc.restype = ctypes.c_int
+libwarpx.warpx_getNProcs.restype = ctypes.c_int
libwarpx.warpx_EvolveE.argtypes = [c_real]
libwarpx.warpx_EvolveB.argtypes = [c_real]
@@ -234,7 +243,7 @@ def get_nattr():
# --- The -3 is because the comps include the velocites
return libwarpx.warpx_nComps() - 3
-def amrex_init(argv):
+def amrex_init(argv, mpi_comm=None):
# --- Construct the ctype list of strings to pass in
argc = len(argv)
argvC = (_LP_c_char * (argc+1))()
@@ -242,9 +251,14 @@ def amrex_init(argv):
enc_arg = arg.encode('utf-8')
argvC[i] = ctypes.create_string_buffer(enc_arg)
- libwarpx.amrex_init(argc, argvC)
+ if mpi_comm is None or MPI is None:
+ libwarpx.amrex_init(argc, argvC)
+ else:
+ comm_ptr = MPI._addressof(mpi_comm)
+ comm_val = _MPI_Comm_type.from_address(comm_ptr)
+ libwarpx.amrex_init_with_inited_mpi(argc, argvC, comm_val)
-def initialize(argv=None):
+def initialize(argv=None, mpi_comm=None):
'''
Initialize WarpX and AMReX. Must be called before
@@ -253,7 +267,7 @@ def initialize(argv=None):
'''
if argv is None:
argv = sys.argv
- amrex_init(argv)
+ amrex_init(argv, mpi_comm)
libwarpx.warpx_ConvertLabParamsToBoost()
libwarpx.warpx_ReadBCParams()
if geometry_dim == 'rz':
diff --git a/Python/pywarpx/picmi.py b/Python/pywarpx/picmi.py
index fdc2ce1e0..257cadb9a 100644
--- a/Python/pywarpx/picmi.py
+++ b/Python/pywarpx/picmi.py
@@ -804,12 +804,12 @@ class Simulation(picmistandard.PICMI_Simulation):
for diagnostic in self.diagnostics:
diagnostic.initialize_inputs()
- def initialize_warpx(self):
+ def initialize_warpx(self, mpi_comm=None):
if self.warpx_initialized:
return
self.warpx_initialized = True
- pywarpx.warpx.init()
+ pywarpx.warpx.init(mpi_comm)
def write_input_file(self, file_name='inputs'):
self.initialize_inputs()
@@ -820,9 +820,9 @@ class Simulation(picmistandard.PICMI_Simulation):
kw['stop_time'] = self.max_time
pywarpx.warpx.write_inputs(file_name, **kw)
- def step(self, nsteps=None):
+ def step(self, nsteps=None, mpi_comm=None):
self.initialize_inputs()
- self.initialize_warpx()
+ self.initialize_warpx(mpi_comm)
if nsteps is None:
if self.max_steps is not None:
nsteps = self.max_steps