diff options
Diffstat (limited to 'Python')
-rw-r--r-- | Python/pywarpx/WarpX.py | 4 | ||||
-rwxr-xr-x | Python/pywarpx/_libwarpx.py | 26 | ||||
-rw-r--r-- | Python/pywarpx/picmi.py | 8 |
3 files changed, 26 insertions, 12 deletions
diff --git a/Python/pywarpx/WarpX.py b/Python/pywarpx/WarpX.py index dc09ba241..73106cf62 100644 --- a/Python/pywarpx/WarpX.py +++ b/Python/pywarpx/WarpX.py @@ -68,10 +68,10 @@ class WarpX(Bucket): return argv - def init(self): + def init(self, mpi_comm=None): from . import wx argv = ['warpx'] + self.create_argv_list() - wx.initialize(argv) + wx.initialize(argv, mpi_comm=mpi_comm) def evolve(self, nsteps=-1): from . import wx diff --git a/Python/pywarpx/_libwarpx.py b/Python/pywarpx/_libwarpx.py index d51bdfa9d..f1cc6564d 100755 --- a/Python/pywarpx/_libwarpx.py +++ b/Python/pywarpx/_libwarpx.py @@ -18,10 +18,16 @@ from .Geometry import geometry try: # --- If mpi4py is going to be used, this needs to be imported - # --- before libwarpx is loaded (though don't know why) + # --- before libwarpx is loaded, because mpi4py calls MPI_Init from mpi4py import MPI + # --- Change MPI Comm type depending on MPICH (int) or OpenMPI (void*) + if MPI._sizeof(MPI.Comm) == ctypes.sizeof(ctypes.c_int): + _MPI_Comm_type = ctypes.c_int + else: + _MPI_Comm_type = ctypes.c_void_p except ImportError: - pass + MPI = None + _MPI_Comm_type = ctypes.c_void_p # --- Is there a better way of handling constants? clight = 2.99792458e+8 # m/s @@ -136,6 +142,7 @@ def _array1d_from_pointer(pointer, dtype, size): # set the arg and return types of the wrapped functions libwarpx.amrex_init.argtypes = (ctypes.c_int, _LP_LP_c_char) +libwarpx.amrex_init_with_inited_mpi.argtypes = (ctypes.c_int, _LP_LP_c_char, _MPI_Comm_type) libwarpx.warpx_getParticleStructs.restype = _LP_particle_p libwarpx.warpx_getParticleArrays.restype = _LP_LP_c_particlereal libwarpx.warpx_getEfield.restype = _LP_LP_c_real @@ -208,6 +215,8 @@ libwarpx.warpx_getdt.restype = c_real libwarpx.warpx_maxStep.restype = ctypes.c_int libwarpx.warpx_stopTime.restype = c_real libwarpx.warpx_finestLevel.restype = ctypes.c_int +libwarpx.warpx_getMyProc.restype = ctypes.c_int +libwarpx.warpx_getNProcs.restype = ctypes.c_int libwarpx.warpx_EvolveE.argtypes = [c_real] libwarpx.warpx_EvolveB.argtypes = [c_real] @@ -234,7 +243,7 @@ def get_nattr(): # --- The -3 is because the comps include the velocites return libwarpx.warpx_nComps() - 3 -def amrex_init(argv): +def amrex_init(argv, mpi_comm=None): # --- Construct the ctype list of strings to pass in argc = len(argv) argvC = (_LP_c_char * (argc+1))() @@ -242,9 +251,14 @@ def amrex_init(argv): enc_arg = arg.encode('utf-8') argvC[i] = ctypes.create_string_buffer(enc_arg) - libwarpx.amrex_init(argc, argvC) + if mpi_comm is None or MPI is None: + libwarpx.amrex_init(argc, argvC) + else: + comm_ptr = MPI._addressof(mpi_comm) + comm_val = _MPI_Comm_type.from_address(comm_ptr) + libwarpx.amrex_init_with_inited_mpi(argc, argvC, comm_val) -def initialize(argv=None): +def initialize(argv=None, mpi_comm=None): ''' Initialize WarpX and AMReX. Must be called before @@ -253,7 +267,7 @@ def initialize(argv=None): ''' if argv is None: argv = sys.argv - amrex_init(argv) + amrex_init(argv, mpi_comm) libwarpx.warpx_ConvertLabParamsToBoost() libwarpx.warpx_ReadBCParams() if geometry_dim == 'rz': diff --git a/Python/pywarpx/picmi.py b/Python/pywarpx/picmi.py index fdc2ce1e0..257cadb9a 100644 --- a/Python/pywarpx/picmi.py +++ b/Python/pywarpx/picmi.py @@ -804,12 +804,12 @@ class Simulation(picmistandard.PICMI_Simulation): for diagnostic in self.diagnostics: diagnostic.initialize_inputs() - def initialize_warpx(self): + def initialize_warpx(self, mpi_comm=None): if self.warpx_initialized: return self.warpx_initialized = True - pywarpx.warpx.init() + pywarpx.warpx.init(mpi_comm) def write_input_file(self, file_name='inputs'): self.initialize_inputs() @@ -820,9 +820,9 @@ class Simulation(picmistandard.PICMI_Simulation): kw['stop_time'] = self.max_time pywarpx.warpx.write_inputs(file_name, **kw) - def step(self, nsteps=None): + def step(self, nsteps=None, mpi_comm=None): self.initialize_inputs() - self.initialize_warpx() + self.initialize_warpx(mpi_comm) if nsteps is None: if self.max_steps is not None: nsteps = self.max_steps |