aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Source/WarpX.H8
-rw-r--r--Source/WarpXFFT.cpp163
-rw-r--r--Source/WarpX_boosted_frame.F906
-rw-r--r--Source/WarpX_f.H6
-rw-r--r--Source/WarpX_fft.F90131
-rw-r--r--Source/WarpX_laser.F904
-rw-r--r--Source/WarpX_picsar.F903
-rw-r--r--Source/main.cpp11
8 files changed, 224 insertions, 108 deletions
diff --git a/Source/WarpX.H b/Source/WarpX.H
index 97da4ea27..ebd5fe420 100644
--- a/Source/WarpX.H
+++ b/Source/WarpX.H
@@ -22,6 +22,10 @@
#include <WarpXPML.H>
#include <WarpXBoostedFrameDiagnostic.H>
+#ifdef WARPX_USE_PSATD
+#include <fftw3.h>
+#endif
+
class NoOpPhysBC
: public amrex::PhysBCFunctBase
{
@@ -445,8 +449,8 @@ private:
void* data[N] = { nullptr };
~FFTData () {
- for (int i = 0; i < N; ++i) { // The memory is allocated by Fortran routines
- std::free(data[i]);
+ for (int i = 0; i < N; ++i) { // The memory is allocated with fftw_alloc.
+ fftw_free(data[i]);
data[i] = nullptr;
}
}
diff --git a/Source/WarpXFFT.cpp b/Source/WarpXFFT.cpp
index 8d37b9307..f394d5f64 100644
--- a/Source/WarpXFFT.cpp
+++ b/Source/WarpXFFT.cpp
@@ -1,6 +1,8 @@
#include <WarpX.H>
#include <WarpX_f.H>
+#include <AMReX_BaseFab_f.H>
+#include <AMReX_iMultiFab.H>
using namespace amrex;
@@ -8,21 +10,105 @@ constexpr int WarpX::FFTData::N;
namespace {
+/** \brief Returns an "owner mask" which 1 for all cells, except
+ * for the duplicated (physical) cells of a nodal grid.
+ *
+ * More precisely, for these cells (which are represented on several grids)
+ * the owner mask is 1 only if these cells are at the lower left end of
+ * the local grid - or if these cells are at the end of the physical domain
+ * Therefore, there for these cells, there will be only one grid for
+ * which the owner mask is non-zero.
+ */
+static iMultiFab
+BuildFFTOwnerMask (const MultiFab& mf, const Geometry& geom)
+{
+ const BoxArray& ba = mf.boxArray();
+ const DistributionMapping& dm = mf.DistributionMap();
+ iMultiFab mask(ba, dm, 1, 0);
+ const int owner = 1;
+ const int nonowner = 0;
+ mask.setVal(owner);
+
+ const Box& domain_box = amrex::convert(geom.Domain(), ba.ixType());
+
+ AMREX_ASSERT(ba.complementIn(domain_box).isEmpty());
+
+#ifdef _OPENMP
+#pragma omp parallel
+#endif
+ for (MFIter mfi(mask); mfi.isValid(); ++mfi)
+ {
+ IArrayBox& fab = mask[mfi];
+ const Box& bx = fab.box();
+ Box bx2 = bx;
+ for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
+ // Detect nodal dimensions
+ if (bx2.type(idim) == IndexType::NODE) {
+ // Make sure that this grid does not touch the end of
+ // the physical domain.
+ if (bx2.bigEnd(idim) < domain_box.bigEnd(idim)) {
+ bx2.growHi(idim, -1);
+ }
+ }
+ }
+ const BoxList& bl = amrex::boxDiff(bx, bx2);
+ // Set owner mask in these cells
+ for (const auto& b : bl) {
+ fab.setVal(nonowner, b, 0, 1);
+ }
+ }
+
+ return mask;
+}
+
+/** \brief Copy the data from the FFT grid to the regular grid
+ *
+ * Because, for nodal grid, some cells are duplicated on several boxes,
+ * special care has to be taken in order to have consistent values on
+ * each boxes when copying this data. Here this is done by setting a
+ * mask, where, for these duplicated cells, the mask is non-zero on only
+ * one box.
+ */
static void
-CopyDataFromFFTToValid (MultiFab& mf, const MultiFab& mf_fft, const BoxArray& ba_valid_fft)
+CopyDataFromFFTToValid (MultiFab& mf, const MultiFab& mf_fft, const BoxArray& ba_valid_fft, const Geometry& geom)
{
auto idx_type = mf_fft.ixType();
MultiFab mftmp(amrex::convert(ba_valid_fft,idx_type), mf_fft.DistributionMap(), 1, 0);
+
+ const iMultiFab& mask = BuildFFTOwnerMask(mftmp, geom);
+
+ // Local copy: whenever an MPI rank owns both the data from the FFT
+ // grid and from the regular grid, for overlapping region, copy it locally
+#ifdef _OPENMP
+#pragma omp parallel
+#endif
for (MFIter mfi(mftmp,true); mfi.isValid(); ++mfi)
{
const Box& bx = mfi.tilebox();
- if (mf_fft[mfi].box().contains(bx))
+ FArrayBox& dstfab = mftmp[mfi];
+
+ const FArrayBox& srcfab = mf_fft[mfi];
+ const Box& srcbox = srcfab.box();
+
+ if (srcbox.contains(bx))
{
- mftmp[mfi].copy(mf_fft[mfi], bx, 0, bx, 0, 1);
+ // Copy the interior region (without guard cells)
+ dstfab.copy(srcfab, bx, 0, bx, 0, 1);
+ // Set the value to 0 whenever the mask is 0
+ // (i.e. for nodal duplicated cells, there is a single box
+ // for which the mask is different than 0)
+ amrex_fab_setval_ifnot (BL_TO_FORTRAN_BOX(bx),
+ BL_TO_FORTRAN_FAB(dstfab),
+ BL_TO_FORTRAN_ANYD(mask[mfi]),
+ 0.0); // if mask == 0, set value to zero
}
}
- mf.ParallelCopy(mftmp);
+ // Global copy: Get the remaining the data from other procs
+ // Use ParallelAdd instead of ParallelCopy, so that the value from
+ // the cell that has non-zero mask is the one which is retained.
+ mf.setVal(0.0, 0);
+ mf.ParallelAdd(mftmp);
}
}
@@ -146,7 +232,7 @@ WarpX::InitFFTComm (int lev)
// # of processes in the subcommunicator
int np_fft = nprocs / ngroups_fft;
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(np_fft*ngroups_fft == nprocs,
- "Number of processes must be divisilbe by number of FFT groups");
+ "Number of processes must be divisible by number of FFT groups");
int myproc = ParallelDescriptor::MyProc();
// my color in ngroups_fft subcommunicators. 0 <= color_fft < ngroups_fft
@@ -180,57 +266,50 @@ void
WarpX::FFTDomainDecompsition (int lev, BoxArray& ba_fft, DistributionMapping& dm_fft,
BoxArray& ba_valid, Box& domain_fft, const Box& domain)
{
+ AMREX_ALWAYS_ASSERT_WITH_MESSAGE(AMREX_SPACEDIM == 3, "PSATD only works in 3D");
+
IntVect nguards_fft(AMREX_D_DECL(nox_fft/2,noy_fft/2,noz_fft/2));
int nprocs = ParallelDescriptor::NProcs();
- int np_fft;
- MPI_Comm_size(comm_fft[lev], &np_fft);
-
- BoxList bl_fft; // List of boxes: will be filled by the boxes attributed to each proc
- bl_fft.reserve(nprocs);
- Vector<int> gid_fft; // List of group ID: will be filled with the FFT group ID of each box
- gid_fft.reserve(nprocs);
BoxList bl(domain, ngroups_fft); // This does a multi-D domain decomposition for groups
AMREX_ALWAYS_ASSERT(bl.size() == ngroups_fft);
const Vector<Box>& bldata = bl.data();
- // Fill bl_fft and gid_fft ; loop over FFT groups
- for (int igroup = 0; igroup < ngroups_fft; ++igroup)
- {
- // Within the group, 1d domain decomposition is performed.
- const Box& bx = amrex::grow(bldata[igroup], nguards_fft);
- // chop in z-direction into np_fft for FFTW
- BoxList tbl(bx, np_fft, Direction::z);
- bl_fft.join(tbl);
- for (int i = 0; i < np_fft; ++i) {
- gid_fft.push_back(igroup);
- }
- // Determine the sub-domain associated with the FFT group of the local proc
- if (igroup == color_fft[lev]) {
- domain_fft = bx;
- }
+
+ // This is the domain for the FFT sub-group (including guard cells)
+ domain_fft = amrex::grow(bldata[color_fft[lev]], nguards_fft);
+ // Ask FFTW to chop the current FFT sub-group domain in the z-direction
+ // and give a chunk to each MPI rank in the current sub-group.
+ int nz_fft, z0_fft;
+ warpx_fft_domain_decomp(&nz_fft, &z0_fft, BL_TO_FORTRAN_BOX(domain_fft));
+ // Each MPI rank adds a box with its chunk of the FFT grid
+ // (given by the above decomposition) to the list `bx_fft`,
+ // then list is shared among all MPI ranks via AllGather
+ Vector<Box> bx_fft;
+ if (nz_fft > 0) {
+ Box b = domain_fft;
+ b.setRange(2, z0_fft+domain_fft.smallEnd(2), nz_fft);
+ bx_fft.push_back(b);
}
+ amrex::AllGatherBoxes(bx_fft);
- // This BoxArray contains local FFT domains for each process
- ba_fft.define(std::move(bl_fft));
+ // Define the AMReX objects for the FFT grid: BoxArray and DistributionMapping
+ ba_fft.define(BoxList(std::move(bx_fft)));
AMREX_ALWAYS_ASSERT(ba_fft.size() == ParallelDescriptor::NProcs());
-
Vector<int> pmap(ba_fft.size());
std::iota(pmap.begin(), pmap.end(), 0);
dm_fft.define(std::move(pmap));
- //
// For communication between WarpX normal domain and FFT domain, we need to create a
// special BoxArray ba_valid
- //
-
const Box foobox(-nguards_fft-2, -nguards_fft-2);
BoxList bl_valid; // List of boxes: will be filled by the valid part of the subdomains of ba_fft
bl_valid.reserve(ba_fft.size());
+ int np_fft = nprocs / ngroups_fft;
for (int i = 0; i < ba_fft.size(); ++i)
{
- int igroup = gid_fft[i];
+ int igroup = dm_fft[i] / np_fft; // This should be consistent with InitFFTComm
const Box& bx = ba_fft[i] & bldata[igroup]; // Intersection with the domain of
// the FFT group *without* guard cells
if (bx.ok())
@@ -262,9 +341,7 @@ WarpX::InitFFTDataPlan (int lev)
for (MFIter mfi(*Efield_fp_fft[lev][0]); mfi.isValid(); ++mfi)
{
const Box& local_domain = amrex::enclosedCells(mfi.fabbox());
- warpx_fft_dataplan_init(BL_TO_FORTRAN_BOX(domain_fp_fft[lev]),
- BL_TO_FORTRAN_BOX(local_domain),
- &nox_fft, &noy_fft, &noz_fft,
+ warpx_fft_dataplan_init(&nox_fft, &noy_fft, &noz_fft,
(*dataptr_fp_fft[lev])[mfi].data, &FFTData::N,
dx_fp.data(), &dt[lev], &fftw_plan_measure );
}
@@ -335,12 +412,12 @@ WarpX::PushPSATD (int lev, amrex::Real /* dt */)
BL_PROFILE_VAR_STOP(blp_push_eb);
BL_PROFILE_VAR_START(blp_copy);
- CopyDataFromFFTToValid(*Efield_fp[lev][0], *Efield_fp_fft[lev][0], ba_valid_fp_fft[lev]);
- CopyDataFromFFTToValid(*Efield_fp[lev][1], *Efield_fp_fft[lev][1], ba_valid_fp_fft[lev]);
- CopyDataFromFFTToValid(*Efield_fp[lev][2], *Efield_fp_fft[lev][2], ba_valid_fp_fft[lev]);
- CopyDataFromFFTToValid(*Bfield_fp[lev][0], *Bfield_fp_fft[lev][0], ba_valid_fp_fft[lev]);
- CopyDataFromFFTToValid(*Bfield_fp[lev][1], *Bfield_fp_fft[lev][1], ba_valid_fp_fft[lev]);
- CopyDataFromFFTToValid(*Bfield_fp[lev][2], *Bfield_fp_fft[lev][2], ba_valid_fp_fft[lev]);
+ CopyDataFromFFTToValid(*Efield_fp[lev][0], *Efield_fp_fft[lev][0], ba_valid_fp_fft[lev], geom[lev]);
+ CopyDataFromFFTToValid(*Efield_fp[lev][1], *Efield_fp_fft[lev][1], ba_valid_fp_fft[lev], geom[lev]);
+ CopyDataFromFFTToValid(*Efield_fp[lev][2], *Efield_fp_fft[lev][2], ba_valid_fp_fft[lev], geom[lev]);
+ CopyDataFromFFTToValid(*Bfield_fp[lev][0], *Bfield_fp_fft[lev][0], ba_valid_fp_fft[lev], geom[lev]);
+ CopyDataFromFFTToValid(*Bfield_fp[lev][1], *Bfield_fp_fft[lev][1], ba_valid_fp_fft[lev], geom[lev]);
+ CopyDataFromFFTToValid(*Bfield_fp[lev][2], *Bfield_fp_fft[lev][2], ba_valid_fp_fft[lev], geom[lev]);
BL_PROFILE_VAR_STOP(blp_copy);
if (lev > 0)
diff --git a/Source/WarpX_boosted_frame.F90 b/Source/WarpX_boosted_frame.F90
index 3f8f0607c..3bce1817e 100644
--- a/Source/WarpX_boosted_frame.F90
+++ b/Source/WarpX_boosted_frame.F90
@@ -3,7 +3,7 @@ module warpx_boosted_frame_module
use iso_c_binding
use amrex_fort_module, only : amrex_real
- use constants
+ use constants, only : clight
implicit none
@@ -62,8 +62,6 @@ contains
i_boost, i_lab) &
bind(C, name="warpx_copy_slice_3d")
- use amrex_fort_module, only : amrex_real
-
integer , intent(in) :: ncomp, i_boost, i_lab
integer , intent(in) :: lo(3), hi(3)
integer , intent(in) :: tlo(3), thi(3)
@@ -88,8 +86,6 @@ contains
i_boost, i_lab) &
bind(C, name="warpx_copy_slice_2d")
- use amrex_fort_module, only : amrex_real
-
integer , intent(in) :: ncomp, i_boost, i_lab
integer , intent(in) :: lo(2), hi(2)
integer , intent(in) :: tlo(2), thi(2)
diff --git a/Source/WarpX_f.H b/Source/WarpX_f.H
index 948588ca9..029c07377 100644
--- a/Source/WarpX_f.H
+++ b/Source/WarpX_f.H
@@ -420,9 +420,9 @@ extern "C"
#ifdef WARPX_USE_PSATD
void warpx_fft_mpi_init (int fcomm);
- void warpx_fft_dataplan_init (const int* global_lo, const int* global_hi,
- const int* local_lo, const int* local_hi,
- const int* nox, const int* noy, const int* noz,
+ void warpx_fft_domain_decomp (int* warpx_local_nz, int* warpx_local_z0,
+ const int* global_lo, const int* global_hi);
+ void warpx_fft_dataplan_init (const int* nox, const int* noy, const int* noz,
void* fft_data, const int* ndata,
const amrex_real* dx_w, const amrex_real* dt_w,
const int* fftw_plan_measure );
diff --git a/Source/WarpX_fft.F90 b/Source/WarpX_fft.F90
index 6f1a5491c..406f3f90f 100644
--- a/Source/WarpX_fft.F90
+++ b/Source/WarpX_fft.F90
@@ -1,12 +1,15 @@
module warpx_fft_module
- use amrex_error_module
- use amrex_fort_module
+ use amrex_error_module, only : amrex_error, amrex_abort
+ use amrex_fort_module, only : amrex_real
use iso_c_binding
implicit none
+ include 'fftw3-mpi.f03'
+
private
- public :: warpx_fft_mpi_init, warpx_fft_dataplan_init, warpx_fft_nullify, warpx_fft_push_eb
+ public :: warpx_fft_mpi_init, warpx_fft_domain_decomp, warpx_fft_dataplan_init, warpx_fft_nullify, &
+ warpx_fft_push_eb
contains
@@ -25,21 +28,64 @@ contains
call mpi_comm_rank(comm, lrank, ierr)
rank = lrank
+
+#ifdef _OPENMP
+ ierr = fftw_init_threads()
+ if (ierr.eq.0) call amrex_error("fftw_init_threads failed")
+#endif
+ call fftw_mpi_init()
+#ifdef _OPENMP
+ call dfftw_init_threads(ierr)
+ if (ierr.eq.0) call amrex_error("dfftw_init_threads failed")
+#endif
end subroutine warpx_fft_mpi_init
!> @brief
+!! Ask FFTW to do domain decomposition.
+!
+! This is always a 1d domain decomposition along z ; it is typically
+! done on the *FFT sub-groups*, not the all domain
+ subroutine warpx_fft_domain_decomp (warpx_local_nz, warpx_local_z0, global_lo, global_hi) &
+ bind(c,name='warpx_fft_domain_decomp')
+ use picsar_precision, only : idp
+ use shared_data, only : comm, &
+ nx_global, ny_global, nz_global, & ! size of global FFT
+ nx, ny, nz ! size of local subdomains
+ use mpi_fftw3, only : local_nz, local_z0, fftw_mpi_local_size_3d, alloc_local
+
+ integer, intent(out) :: warpx_local_nz, warpx_local_z0
+ integer, dimension(3), intent(in) :: global_lo, global_hi
+
+ nx_global = INT(global_hi(1)-global_lo(1)+1,idp)
+ ny_global = INT(global_hi(2)-global_lo(2)+1,idp)
+ nz_global = INT(global_hi(3)-global_lo(3)+1,idp)
+
+ alloc_local = fftw_mpi_local_size_3d( &
+ INT(nz_global,C_INTPTR_T), &
+ INT(ny_global,C_INTPTR_T), &
+ INT(nx_global,C_INTPTR_T)/2+1, &
+ comm, local_nz, local_z0)
+
+ nx = nx_global
+ ny = ny_global
+ nz = local_nz
+
+ warpx_local_nz = local_nz
+ warpx_local_z0 = local_z0
+ end subroutine warpx_fft_domain_decomp
+
+
+!> @brief
!! Set all the flags and metadata of the PICSAR FFT module.
!! Allocate the auxiliary arrays of `fft_data`
!!
!! Note: fft_data is a stuct containing 22 pointers to arrays
!! 1-11: padded arrays in real space ; 12-22 arrays for the fields in Fourier space
- subroutine warpx_fft_dataplan_init (global_lo, global_hi, local_lo, local_hi, &
- nox, noy, noz, fft_data, ndata, dx_wrpx, dt_wrpx, fftw_measure) &
+ subroutine warpx_fft_dataplan_init (nox, noy, noz, fft_data, ndata, dx_wrpx, dt_wrpx, fftw_measure) &
bind(c,name='warpx_fft_dataplan_init')
USE picsar_precision, only: idp
- use shared_data, only : comm, c_dim, p3dfft_flag, fftw_plan_measure, &
+ use shared_data, only : c_dim, p3dfft_flag, fftw_plan_measure, &
fftw_with_mpi, fftw_threads_ok, fftw_hybrid, fftw_mpi_transpose, &
- nx_global, ny_global, nz_global, & ! size of global FFT
nx, ny, nz, & ! size of local subdomains
nkx, nky, nkz, & ! size of local ffts
iz_min_r, iz_max_r, iy_min_r, iy_max_r, ix_min_r, ix_max_r, & ! loop bounds
@@ -50,13 +96,12 @@ contains
exf, eyf, ezf, bxf, byf, bzf, &
jxf, jyf, jzf, rhof, rhooldf, &
l_spectral, l_staggered, norderx, nordery, norderz
- use mpi_fftw3, only : local_nz, local_z0, fftw_mpi_local_size_3d, alloc_local
+ use mpi_fftw3, only : alloc_local
use omp_lib, only: omp_get_max_threads
USE gpstd_solver, only: init_gpstd
USE fourier_psaotd, only: init_plans_fourier_mpi
use params, only : dt
- integer, dimension(3), intent(in) :: global_lo, global_hi, local_lo, local_hi
integer, intent(in) :: nox, noy, noz, ndata
integer, intent(in) :: fftw_measure
type(c_ptr), intent(inout) :: fft_data(ndata)
@@ -67,27 +112,12 @@ contains
integer :: nx_padded
integer, dimension(3) :: shp
integer(kind=c_size_t) :: sz
- real(c_double) :: realfoo
- complex(c_double_complex) :: complexfoo
- ! Define size of domains: necessary for the initialization of the global FFT
- nx_global = INT(global_hi(1)-global_lo(1)+1,idp)
- ny_global = INT(global_hi(2)-global_lo(2)+1,idp)
- nz_global = INT(global_hi(3)-global_lo(3)+1,idp)
- nx = INT(local_hi(1)-local_lo(1)+1,idp)
- ny = INT(local_hi(2)-local_lo(2)+1,idp)
- nz = INT(local_hi(3)-local_lo(3)+1,idp)
! No need to distinguish physical and guard cells for the global FFT;
! only nx+2*nxguards counts. Thus we declare 0 guard cells for simplicity
nxguards = 0_idp
nyguards = 0_idp
nzguards = 0_idp
- ! Find the decomposition that FFTW imposes in kspace
- alloc_local = fftw_mpi_local_size_3d( &
- INT(nz_global,C_INTPTR_T), &
- INT(ny_global,C_INTPTR_T), &
- INT(nx_global,C_INTPTR_T)/2+1, &
- comm, local_nz, local_z0)
! For the calculation of the modified [k] vectors
l_staggered = .TRUE.
@@ -103,7 +133,6 @@ contains
p3dfft_flag = .FALSE.
l_spectral = .TRUE. ! Activate spectral Solver, using FFT
#ifdef _OPENMP
- CALL DFFTW_INIT_THREADS(iret)
fftw_threads_ok = .TRUE.
nopenmp = OMP_GET_MAX_THREADS()
#else
@@ -114,28 +143,28 @@ contains
! Allocate padded arrays for MPI FFTW
nx_padded = 2*(nx/2 + 1)
shp = [nx_padded, int(ny), int(nz)]
- sz = c_sizeof(realfoo) * int(shp(1),c_size_t) * int(shp(2),c_size_t) * int(shp(3),c_size_t)
- fft_data(1) = amrex_malloc(sz)
+ sz = 2*alloc_local
+ fft_data(1) = fftw_alloc_real(sz)
call c_f_pointer(fft_data(1), ex_r, shp)
- fft_data(2) = amrex_malloc(sz)
+ fft_data(2) = fftw_alloc_real(sz)
call c_f_pointer(fft_data(2), ey_r, shp)
- fft_data(3) = amrex_malloc(sz)
+ fft_data(3) = fftw_alloc_real(sz)
call c_f_pointer(fft_data(3), ez_r, shp)
- fft_data(4) = amrex_malloc(sz)
+ fft_data(4) = fftw_alloc_real(sz)
call c_f_pointer(fft_data(4), bx_r, shp)
- fft_data(5) = amrex_malloc(sz)
+ fft_data(5) = fftw_alloc_real(sz)
call c_f_pointer(fft_data(5), by_r, shp)
- fft_data(6) = amrex_malloc(sz)
+ fft_data(6) = fftw_alloc_real(sz)
call c_f_pointer(fft_data(6), bz_r, shp)
- fft_data(7) = amrex_malloc(sz)
+ fft_data(7) = fftw_alloc_real(sz)
call c_f_pointer(fft_data(7), jx_r, shp)
- fft_data(8) = amrex_malloc(sz)
+ fft_data(8) = fftw_alloc_real(sz)
call c_f_pointer(fft_data(8), jy_r, shp)
- fft_data(9) = amrex_malloc(sz)
+ fft_data(9) = fftw_alloc_real(sz)
call c_f_pointer(fft_data(9), jz_r, shp)
- fft_data(10) = amrex_malloc(sz)
+ fft_data(10) = fftw_alloc_real(sz)
call c_f_pointer(fft_data(10), rho_r, shp)
- fft_data(11) = amrex_malloc(sz)
+ fft_data(11) = fftw_alloc_real(sz)
call c_f_pointer(fft_data(11), rhoold_r, shp)
! Set array bounds when copying ex to ex_r in PICSAR
@@ -147,28 +176,28 @@ contains
nky = ny
nkz = nz
shp = [int(nkx), int(nky), int(nkz)]
- sz = c_sizeof(complexfoo) * int(shp(1),c_size_t) * int(shp(2),c_size_t) * int(shp(3),c_size_t)
- fft_data(12) = amrex_malloc(sz)
+ sz = alloc_local
+ fft_data(12) = fftw_alloc_complex(sz)
call c_f_pointer(fft_data(12), exf, shp)
- fft_data(13) = amrex_malloc(sz)
+ fft_data(13) = fftw_alloc_complex(sz)
call c_f_pointer(fft_data(13), eyf, shp)
- fft_data(14) = amrex_malloc(sz)
+ fft_data(14) = fftw_alloc_complex(sz)
call c_f_pointer(fft_data(14), ezf, shp)
- fft_data(15) = amrex_malloc(sz)
+ fft_data(15) = fftw_alloc_complex(sz)
call c_f_pointer(fft_data(15), bxf, shp)
- fft_data(16) = amrex_malloc(sz)
+ fft_data(16) = fftw_alloc_complex(sz)
call c_f_pointer(fft_data(16), byf, shp)
- fft_data(17) = amrex_malloc(sz)
+ fft_data(17) = fftw_alloc_complex(sz)
call c_f_pointer(fft_data(17), bzf, shp)
- fft_data(18) = amrex_malloc(sz)
+ fft_data(18) = fftw_alloc_complex(sz)
call c_f_pointer(fft_data(18), jxf, shp)
- fft_data(19) = amrex_malloc(sz)
+ fft_data(19) = fftw_alloc_complex(sz)
call c_f_pointer(fft_data(19), jyf, shp)
- fft_data(20) = amrex_malloc(sz)
+ fft_data(20) = fftw_alloc_complex(sz)
call c_f_pointer(fft_data(20), jzf, shp)
- fft_data(21) = amrex_malloc(sz)
+ fft_data(21) = fftw_alloc_complex(sz)
call c_f_pointer(fft_data(21), rhof, shp)
- fft_data(22) = amrex_malloc(sz)
+ fft_data(22) = fftw_alloc_complex(sz)
call c_f_pointer(fft_data(22), rhooldf, shp)
if (ndata < 22) then
@@ -193,6 +222,7 @@ contains
jx_r, jy_r, jz_r, rho_r, rhoold_r, &
exf, eyf, ezf, bxf, byf, bzf, &
jxf, jyf, jzf, rhof, rhooldf
+ use mpi_fftw3, only : plan_r2c_mpi, plan_c2r_mpi
nullify(ex_r)
nullify(ey_r)
nullify(ez_r)
@@ -215,6 +245,9 @@ contains
nullify(jzf)
nullify(rhof)
nullify(rhooldf)
+ call fftw_destroy_plan(plan_r2c_mpi)
+ call fftw_destroy_plan(plan_c2r_mpi)
+ call fftw_mpi_cleanup()
end subroutine warpx_fft_nullify
diff --git a/Source/WarpX_laser.F90 b/Source/WarpX_laser.F90
index fc1325d8d..77a0a508c 100644
--- a/Source/WarpX_laser.F90
+++ b/Source/WarpX_laser.F90
@@ -3,8 +3,8 @@ module warpx_laser_module
use iso_c_binding
use amrex_fort_module, only : amrex_real
- use constants
- use parser_wrapper
+ use constants, only : clight, pi
+ use parser_wrapper, only : parser_evaluate_function
implicit none
diff --git a/Source/WarpX_picsar.F90 b/Source/WarpX_picsar.F90
index 77860fd67..87336c9de 100644
--- a/Source/WarpX_picsar.F90
+++ b/Source/WarpX_picsar.F90
@@ -41,7 +41,6 @@ module warpx_to_pxr_module
use iso_c_binding
use amrex_fort_module, only : amrex_real
- use constants
implicit none
@@ -153,8 +152,6 @@ subroutine warpx_charge_deposition(rho,np,xp,yp,zp,w,q,xmin,ymin,zmin,dx,dy,dz,n
nxguard,nyguard,nzguard,nox,noy,noz,lvect,charge_depo_algo) &
bind(C, name="warpx_charge_deposition")
- use amrex_error_module
-
integer(c_long), intent(IN) :: np
integer(c_long), intent(IN) :: nx,ny,nz
integer(c_long), intent(IN) :: nxguard,nyguard,nzguard
diff --git a/Source/main.cpp b/Source/main.cpp
index d648d22a2..757c29e1f 100644
--- a/Source/main.cpp
+++ b/Source/main.cpp
@@ -11,7 +11,15 @@ using namespace amrex;
int main(int argc, char* argv[])
{
- amrex::Initialize(argc,argv);
+#if defined(_OPENMP) && defined(WARPX_USE_PSATD)
+ int provided;
+ MPI_Init_thread(&argc, &argv, MPI_THREAD_FUNNELED, &provided);
+ assert(provided >= MPI_THREAD_FUNNELED);
+#else
+ MPI_Init(&argc, &argv);
+#endif
+
+ amrex::Initialize(argc,argv,MPI_COMM_WORLD);
BL_PROFILE_VAR("main()", pmain);
@@ -35,4 +43,5 @@ int main(int argc, char* argv[])
BL_PROFILE_VAR_STOP(pmain);
amrex::Finalize();
+ MPI_Finalize();
}