diff options
-rw-r--r-- | Source/WarpX.H | 8 | ||||
-rw-r--r-- | Source/WarpXFFT.cpp | 163 | ||||
-rw-r--r-- | Source/WarpX_boosted_frame.F90 | 6 | ||||
-rw-r--r-- | Source/WarpX_f.H | 6 | ||||
-rw-r--r-- | Source/WarpX_fft.F90 | 131 | ||||
-rw-r--r-- | Source/WarpX_laser.F90 | 4 | ||||
-rw-r--r-- | Source/WarpX_picsar.F90 | 3 | ||||
-rw-r--r-- | Source/main.cpp | 11 |
8 files changed, 224 insertions, 108 deletions
diff --git a/Source/WarpX.H b/Source/WarpX.H index 97da4ea27..ebd5fe420 100644 --- a/Source/WarpX.H +++ b/Source/WarpX.H @@ -22,6 +22,10 @@ #include <WarpXPML.H> #include <WarpXBoostedFrameDiagnostic.H> +#ifdef WARPX_USE_PSATD +#include <fftw3.h> +#endif + class NoOpPhysBC : public amrex::PhysBCFunctBase { @@ -445,8 +449,8 @@ private: void* data[N] = { nullptr }; ~FFTData () { - for (int i = 0; i < N; ++i) { // The memory is allocated by Fortran routines - std::free(data[i]); + for (int i = 0; i < N; ++i) { // The memory is allocated with fftw_alloc. + fftw_free(data[i]); data[i] = nullptr; } } diff --git a/Source/WarpXFFT.cpp b/Source/WarpXFFT.cpp index 8d37b9307..f394d5f64 100644 --- a/Source/WarpXFFT.cpp +++ b/Source/WarpXFFT.cpp @@ -1,6 +1,8 @@ #include <WarpX.H> #include <WarpX_f.H> +#include <AMReX_BaseFab_f.H> +#include <AMReX_iMultiFab.H> using namespace amrex; @@ -8,21 +10,105 @@ constexpr int WarpX::FFTData::N; namespace { +/** \brief Returns an "owner mask" which 1 for all cells, except + * for the duplicated (physical) cells of a nodal grid. + * + * More precisely, for these cells (which are represented on several grids) + * the owner mask is 1 only if these cells are at the lower left end of + * the local grid - or if these cells are at the end of the physical domain + * Therefore, there for these cells, there will be only one grid for + * which the owner mask is non-zero. + */ +static iMultiFab +BuildFFTOwnerMask (const MultiFab& mf, const Geometry& geom) +{ + const BoxArray& ba = mf.boxArray(); + const DistributionMapping& dm = mf.DistributionMap(); + iMultiFab mask(ba, dm, 1, 0); + const int owner = 1; + const int nonowner = 0; + mask.setVal(owner); + + const Box& domain_box = amrex::convert(geom.Domain(), ba.ixType()); + + AMREX_ASSERT(ba.complementIn(domain_box).isEmpty()); + +#ifdef _OPENMP +#pragma omp parallel +#endif + for (MFIter mfi(mask); mfi.isValid(); ++mfi) + { + IArrayBox& fab = mask[mfi]; + const Box& bx = fab.box(); + Box bx2 = bx; + for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) { + // Detect nodal dimensions + if (bx2.type(idim) == IndexType::NODE) { + // Make sure that this grid does not touch the end of + // the physical domain. + if (bx2.bigEnd(idim) < domain_box.bigEnd(idim)) { + bx2.growHi(idim, -1); + } + } + } + const BoxList& bl = amrex::boxDiff(bx, bx2); + // Set owner mask in these cells + for (const auto& b : bl) { + fab.setVal(nonowner, b, 0, 1); + } + } + + return mask; +} + +/** \brief Copy the data from the FFT grid to the regular grid + * + * Because, for nodal grid, some cells are duplicated on several boxes, + * special care has to be taken in order to have consistent values on + * each boxes when copying this data. Here this is done by setting a + * mask, where, for these duplicated cells, the mask is non-zero on only + * one box. + */ static void -CopyDataFromFFTToValid (MultiFab& mf, const MultiFab& mf_fft, const BoxArray& ba_valid_fft) +CopyDataFromFFTToValid (MultiFab& mf, const MultiFab& mf_fft, const BoxArray& ba_valid_fft, const Geometry& geom) { auto idx_type = mf_fft.ixType(); MultiFab mftmp(amrex::convert(ba_valid_fft,idx_type), mf_fft.DistributionMap(), 1, 0); + + const iMultiFab& mask = BuildFFTOwnerMask(mftmp, geom); + + // Local copy: whenever an MPI rank owns both the data from the FFT + // grid and from the regular grid, for overlapping region, copy it locally +#ifdef _OPENMP +#pragma omp parallel +#endif for (MFIter mfi(mftmp,true); mfi.isValid(); ++mfi) { const Box& bx = mfi.tilebox(); - if (mf_fft[mfi].box().contains(bx)) + FArrayBox& dstfab = mftmp[mfi]; + + const FArrayBox& srcfab = mf_fft[mfi]; + const Box& srcbox = srcfab.box(); + + if (srcbox.contains(bx)) { - mftmp[mfi].copy(mf_fft[mfi], bx, 0, bx, 0, 1); + // Copy the interior region (without guard cells) + dstfab.copy(srcfab, bx, 0, bx, 0, 1); + // Set the value to 0 whenever the mask is 0 + // (i.e. for nodal duplicated cells, there is a single box + // for which the mask is different than 0) + amrex_fab_setval_ifnot (BL_TO_FORTRAN_BOX(bx), + BL_TO_FORTRAN_FAB(dstfab), + BL_TO_FORTRAN_ANYD(mask[mfi]), + 0.0); // if mask == 0, set value to zero } } - mf.ParallelCopy(mftmp); + // Global copy: Get the remaining the data from other procs + // Use ParallelAdd instead of ParallelCopy, so that the value from + // the cell that has non-zero mask is the one which is retained. + mf.setVal(0.0, 0); + mf.ParallelAdd(mftmp); } } @@ -146,7 +232,7 @@ WarpX::InitFFTComm (int lev) // # of processes in the subcommunicator int np_fft = nprocs / ngroups_fft; AMREX_ALWAYS_ASSERT_WITH_MESSAGE(np_fft*ngroups_fft == nprocs, - "Number of processes must be divisilbe by number of FFT groups"); + "Number of processes must be divisible by number of FFT groups"); int myproc = ParallelDescriptor::MyProc(); // my color in ngroups_fft subcommunicators. 0 <= color_fft < ngroups_fft @@ -180,57 +266,50 @@ void WarpX::FFTDomainDecompsition (int lev, BoxArray& ba_fft, DistributionMapping& dm_fft, BoxArray& ba_valid, Box& domain_fft, const Box& domain) { + AMREX_ALWAYS_ASSERT_WITH_MESSAGE(AMREX_SPACEDIM == 3, "PSATD only works in 3D"); + IntVect nguards_fft(AMREX_D_DECL(nox_fft/2,noy_fft/2,noz_fft/2)); int nprocs = ParallelDescriptor::NProcs(); - int np_fft; - MPI_Comm_size(comm_fft[lev], &np_fft); - - BoxList bl_fft; // List of boxes: will be filled by the boxes attributed to each proc - bl_fft.reserve(nprocs); - Vector<int> gid_fft; // List of group ID: will be filled with the FFT group ID of each box - gid_fft.reserve(nprocs); BoxList bl(domain, ngroups_fft); // This does a multi-D domain decomposition for groups AMREX_ALWAYS_ASSERT(bl.size() == ngroups_fft); const Vector<Box>& bldata = bl.data(); - // Fill bl_fft and gid_fft ; loop over FFT groups - for (int igroup = 0; igroup < ngroups_fft; ++igroup) - { - // Within the group, 1d domain decomposition is performed. - const Box& bx = amrex::grow(bldata[igroup], nguards_fft); - // chop in z-direction into np_fft for FFTW - BoxList tbl(bx, np_fft, Direction::z); - bl_fft.join(tbl); - for (int i = 0; i < np_fft; ++i) { - gid_fft.push_back(igroup); - } - // Determine the sub-domain associated with the FFT group of the local proc - if (igroup == color_fft[lev]) { - domain_fft = bx; - } + + // This is the domain for the FFT sub-group (including guard cells) + domain_fft = amrex::grow(bldata[color_fft[lev]], nguards_fft); + // Ask FFTW to chop the current FFT sub-group domain in the z-direction + // and give a chunk to each MPI rank in the current sub-group. + int nz_fft, z0_fft; + warpx_fft_domain_decomp(&nz_fft, &z0_fft, BL_TO_FORTRAN_BOX(domain_fft)); + // Each MPI rank adds a box with its chunk of the FFT grid + // (given by the above decomposition) to the list `bx_fft`, + // then list is shared among all MPI ranks via AllGather + Vector<Box> bx_fft; + if (nz_fft > 0) { + Box b = domain_fft; + b.setRange(2, z0_fft+domain_fft.smallEnd(2), nz_fft); + bx_fft.push_back(b); } + amrex::AllGatherBoxes(bx_fft); - // This BoxArray contains local FFT domains for each process - ba_fft.define(std::move(bl_fft)); + // Define the AMReX objects for the FFT grid: BoxArray and DistributionMapping + ba_fft.define(BoxList(std::move(bx_fft))); AMREX_ALWAYS_ASSERT(ba_fft.size() == ParallelDescriptor::NProcs()); - Vector<int> pmap(ba_fft.size()); std::iota(pmap.begin(), pmap.end(), 0); dm_fft.define(std::move(pmap)); - // // For communication between WarpX normal domain and FFT domain, we need to create a // special BoxArray ba_valid - // - const Box foobox(-nguards_fft-2, -nguards_fft-2); BoxList bl_valid; // List of boxes: will be filled by the valid part of the subdomains of ba_fft bl_valid.reserve(ba_fft.size()); + int np_fft = nprocs / ngroups_fft; for (int i = 0; i < ba_fft.size(); ++i) { - int igroup = gid_fft[i]; + int igroup = dm_fft[i] / np_fft; // This should be consistent with InitFFTComm const Box& bx = ba_fft[i] & bldata[igroup]; // Intersection with the domain of // the FFT group *without* guard cells if (bx.ok()) @@ -262,9 +341,7 @@ WarpX::InitFFTDataPlan (int lev) for (MFIter mfi(*Efield_fp_fft[lev][0]); mfi.isValid(); ++mfi) { const Box& local_domain = amrex::enclosedCells(mfi.fabbox()); - warpx_fft_dataplan_init(BL_TO_FORTRAN_BOX(domain_fp_fft[lev]), - BL_TO_FORTRAN_BOX(local_domain), - &nox_fft, &noy_fft, &noz_fft, + warpx_fft_dataplan_init(&nox_fft, &noy_fft, &noz_fft, (*dataptr_fp_fft[lev])[mfi].data, &FFTData::N, dx_fp.data(), &dt[lev], &fftw_plan_measure ); } @@ -335,12 +412,12 @@ WarpX::PushPSATD (int lev, amrex::Real /* dt */) BL_PROFILE_VAR_STOP(blp_push_eb); BL_PROFILE_VAR_START(blp_copy); - CopyDataFromFFTToValid(*Efield_fp[lev][0], *Efield_fp_fft[lev][0], ba_valid_fp_fft[lev]); - CopyDataFromFFTToValid(*Efield_fp[lev][1], *Efield_fp_fft[lev][1], ba_valid_fp_fft[lev]); - CopyDataFromFFTToValid(*Efield_fp[lev][2], *Efield_fp_fft[lev][2], ba_valid_fp_fft[lev]); - CopyDataFromFFTToValid(*Bfield_fp[lev][0], *Bfield_fp_fft[lev][0], ba_valid_fp_fft[lev]); - CopyDataFromFFTToValid(*Bfield_fp[lev][1], *Bfield_fp_fft[lev][1], ba_valid_fp_fft[lev]); - CopyDataFromFFTToValid(*Bfield_fp[lev][2], *Bfield_fp_fft[lev][2], ba_valid_fp_fft[lev]); + CopyDataFromFFTToValid(*Efield_fp[lev][0], *Efield_fp_fft[lev][0], ba_valid_fp_fft[lev], geom[lev]); + CopyDataFromFFTToValid(*Efield_fp[lev][1], *Efield_fp_fft[lev][1], ba_valid_fp_fft[lev], geom[lev]); + CopyDataFromFFTToValid(*Efield_fp[lev][2], *Efield_fp_fft[lev][2], ba_valid_fp_fft[lev], geom[lev]); + CopyDataFromFFTToValid(*Bfield_fp[lev][0], *Bfield_fp_fft[lev][0], ba_valid_fp_fft[lev], geom[lev]); + CopyDataFromFFTToValid(*Bfield_fp[lev][1], *Bfield_fp_fft[lev][1], ba_valid_fp_fft[lev], geom[lev]); + CopyDataFromFFTToValid(*Bfield_fp[lev][2], *Bfield_fp_fft[lev][2], ba_valid_fp_fft[lev], geom[lev]); BL_PROFILE_VAR_STOP(blp_copy); if (lev > 0) diff --git a/Source/WarpX_boosted_frame.F90 b/Source/WarpX_boosted_frame.F90 index 3f8f0607c..3bce1817e 100644 --- a/Source/WarpX_boosted_frame.F90 +++ b/Source/WarpX_boosted_frame.F90 @@ -3,7 +3,7 @@ module warpx_boosted_frame_module use iso_c_binding use amrex_fort_module, only : amrex_real - use constants + use constants, only : clight implicit none @@ -62,8 +62,6 @@ contains i_boost, i_lab) & bind(C, name="warpx_copy_slice_3d") - use amrex_fort_module, only : amrex_real - integer , intent(in) :: ncomp, i_boost, i_lab integer , intent(in) :: lo(3), hi(3) integer , intent(in) :: tlo(3), thi(3) @@ -88,8 +86,6 @@ contains i_boost, i_lab) & bind(C, name="warpx_copy_slice_2d") - use amrex_fort_module, only : amrex_real - integer , intent(in) :: ncomp, i_boost, i_lab integer , intent(in) :: lo(2), hi(2) integer , intent(in) :: tlo(2), thi(2) diff --git a/Source/WarpX_f.H b/Source/WarpX_f.H index 948588ca9..029c07377 100644 --- a/Source/WarpX_f.H +++ b/Source/WarpX_f.H @@ -420,9 +420,9 @@ extern "C" #ifdef WARPX_USE_PSATD void warpx_fft_mpi_init (int fcomm); - void warpx_fft_dataplan_init (const int* global_lo, const int* global_hi, - const int* local_lo, const int* local_hi, - const int* nox, const int* noy, const int* noz, + void warpx_fft_domain_decomp (int* warpx_local_nz, int* warpx_local_z0, + const int* global_lo, const int* global_hi); + void warpx_fft_dataplan_init (const int* nox, const int* noy, const int* noz, void* fft_data, const int* ndata, const amrex_real* dx_w, const amrex_real* dt_w, const int* fftw_plan_measure ); diff --git a/Source/WarpX_fft.F90 b/Source/WarpX_fft.F90 index 6f1a5491c..406f3f90f 100644 --- a/Source/WarpX_fft.F90 +++ b/Source/WarpX_fft.F90 @@ -1,12 +1,15 @@ module warpx_fft_module - use amrex_error_module - use amrex_fort_module + use amrex_error_module, only : amrex_error, amrex_abort + use amrex_fort_module, only : amrex_real use iso_c_binding implicit none + include 'fftw3-mpi.f03' + private - public :: warpx_fft_mpi_init, warpx_fft_dataplan_init, warpx_fft_nullify, warpx_fft_push_eb + public :: warpx_fft_mpi_init, warpx_fft_domain_decomp, warpx_fft_dataplan_init, warpx_fft_nullify, & + warpx_fft_push_eb contains @@ -25,21 +28,64 @@ contains call mpi_comm_rank(comm, lrank, ierr) rank = lrank + +#ifdef _OPENMP + ierr = fftw_init_threads() + if (ierr.eq.0) call amrex_error("fftw_init_threads failed") +#endif + call fftw_mpi_init() +#ifdef _OPENMP + call dfftw_init_threads(ierr) + if (ierr.eq.0) call amrex_error("dfftw_init_threads failed") +#endif end subroutine warpx_fft_mpi_init !> @brief +!! Ask FFTW to do domain decomposition. +! +! This is always a 1d domain decomposition along z ; it is typically +! done on the *FFT sub-groups*, not the all domain + subroutine warpx_fft_domain_decomp (warpx_local_nz, warpx_local_z0, global_lo, global_hi) & + bind(c,name='warpx_fft_domain_decomp') + use picsar_precision, only : idp + use shared_data, only : comm, & + nx_global, ny_global, nz_global, & ! size of global FFT + nx, ny, nz ! size of local subdomains + use mpi_fftw3, only : local_nz, local_z0, fftw_mpi_local_size_3d, alloc_local + + integer, intent(out) :: warpx_local_nz, warpx_local_z0 + integer, dimension(3), intent(in) :: global_lo, global_hi + + nx_global = INT(global_hi(1)-global_lo(1)+1,idp) + ny_global = INT(global_hi(2)-global_lo(2)+1,idp) + nz_global = INT(global_hi(3)-global_lo(3)+1,idp) + + alloc_local = fftw_mpi_local_size_3d( & + INT(nz_global,C_INTPTR_T), & + INT(ny_global,C_INTPTR_T), & + INT(nx_global,C_INTPTR_T)/2+1, & + comm, local_nz, local_z0) + + nx = nx_global + ny = ny_global + nz = local_nz + + warpx_local_nz = local_nz + warpx_local_z0 = local_z0 + end subroutine warpx_fft_domain_decomp + + +!> @brief !! Set all the flags and metadata of the PICSAR FFT module. !! Allocate the auxiliary arrays of `fft_data` !! !! Note: fft_data is a stuct containing 22 pointers to arrays !! 1-11: padded arrays in real space ; 12-22 arrays for the fields in Fourier space - subroutine warpx_fft_dataplan_init (global_lo, global_hi, local_lo, local_hi, & - nox, noy, noz, fft_data, ndata, dx_wrpx, dt_wrpx, fftw_measure) & + subroutine warpx_fft_dataplan_init (nox, noy, noz, fft_data, ndata, dx_wrpx, dt_wrpx, fftw_measure) & bind(c,name='warpx_fft_dataplan_init') USE picsar_precision, only: idp - use shared_data, only : comm, c_dim, p3dfft_flag, fftw_plan_measure, & + use shared_data, only : c_dim, p3dfft_flag, fftw_plan_measure, & fftw_with_mpi, fftw_threads_ok, fftw_hybrid, fftw_mpi_transpose, & - nx_global, ny_global, nz_global, & ! size of global FFT nx, ny, nz, & ! size of local subdomains nkx, nky, nkz, & ! size of local ffts iz_min_r, iz_max_r, iy_min_r, iy_max_r, ix_min_r, ix_max_r, & ! loop bounds @@ -50,13 +96,12 @@ contains exf, eyf, ezf, bxf, byf, bzf, & jxf, jyf, jzf, rhof, rhooldf, & l_spectral, l_staggered, norderx, nordery, norderz - use mpi_fftw3, only : local_nz, local_z0, fftw_mpi_local_size_3d, alloc_local + use mpi_fftw3, only : alloc_local use omp_lib, only: omp_get_max_threads USE gpstd_solver, only: init_gpstd USE fourier_psaotd, only: init_plans_fourier_mpi use params, only : dt - integer, dimension(3), intent(in) :: global_lo, global_hi, local_lo, local_hi integer, intent(in) :: nox, noy, noz, ndata integer, intent(in) :: fftw_measure type(c_ptr), intent(inout) :: fft_data(ndata) @@ -67,27 +112,12 @@ contains integer :: nx_padded integer, dimension(3) :: shp integer(kind=c_size_t) :: sz - real(c_double) :: realfoo - complex(c_double_complex) :: complexfoo - ! Define size of domains: necessary for the initialization of the global FFT - nx_global = INT(global_hi(1)-global_lo(1)+1,idp) - ny_global = INT(global_hi(2)-global_lo(2)+1,idp) - nz_global = INT(global_hi(3)-global_lo(3)+1,idp) - nx = INT(local_hi(1)-local_lo(1)+1,idp) - ny = INT(local_hi(2)-local_lo(2)+1,idp) - nz = INT(local_hi(3)-local_lo(3)+1,idp) ! No need to distinguish physical and guard cells for the global FFT; ! only nx+2*nxguards counts. Thus we declare 0 guard cells for simplicity nxguards = 0_idp nyguards = 0_idp nzguards = 0_idp - ! Find the decomposition that FFTW imposes in kspace - alloc_local = fftw_mpi_local_size_3d( & - INT(nz_global,C_INTPTR_T), & - INT(ny_global,C_INTPTR_T), & - INT(nx_global,C_INTPTR_T)/2+1, & - comm, local_nz, local_z0) ! For the calculation of the modified [k] vectors l_staggered = .TRUE. @@ -103,7 +133,6 @@ contains p3dfft_flag = .FALSE. l_spectral = .TRUE. ! Activate spectral Solver, using FFT #ifdef _OPENMP - CALL DFFTW_INIT_THREADS(iret) fftw_threads_ok = .TRUE. nopenmp = OMP_GET_MAX_THREADS() #else @@ -114,28 +143,28 @@ contains ! Allocate padded arrays for MPI FFTW nx_padded = 2*(nx/2 + 1) shp = [nx_padded, int(ny), int(nz)] - sz = c_sizeof(realfoo) * int(shp(1),c_size_t) * int(shp(2),c_size_t) * int(shp(3),c_size_t) - fft_data(1) = amrex_malloc(sz) + sz = 2*alloc_local + fft_data(1) = fftw_alloc_real(sz) call c_f_pointer(fft_data(1), ex_r, shp) - fft_data(2) = amrex_malloc(sz) + fft_data(2) = fftw_alloc_real(sz) call c_f_pointer(fft_data(2), ey_r, shp) - fft_data(3) = amrex_malloc(sz) + fft_data(3) = fftw_alloc_real(sz) call c_f_pointer(fft_data(3), ez_r, shp) - fft_data(4) = amrex_malloc(sz) + fft_data(4) = fftw_alloc_real(sz) call c_f_pointer(fft_data(4), bx_r, shp) - fft_data(5) = amrex_malloc(sz) + fft_data(5) = fftw_alloc_real(sz) call c_f_pointer(fft_data(5), by_r, shp) - fft_data(6) = amrex_malloc(sz) + fft_data(6) = fftw_alloc_real(sz) call c_f_pointer(fft_data(6), bz_r, shp) - fft_data(7) = amrex_malloc(sz) + fft_data(7) = fftw_alloc_real(sz) call c_f_pointer(fft_data(7), jx_r, shp) - fft_data(8) = amrex_malloc(sz) + fft_data(8) = fftw_alloc_real(sz) call c_f_pointer(fft_data(8), jy_r, shp) - fft_data(9) = amrex_malloc(sz) + fft_data(9) = fftw_alloc_real(sz) call c_f_pointer(fft_data(9), jz_r, shp) - fft_data(10) = amrex_malloc(sz) + fft_data(10) = fftw_alloc_real(sz) call c_f_pointer(fft_data(10), rho_r, shp) - fft_data(11) = amrex_malloc(sz) + fft_data(11) = fftw_alloc_real(sz) call c_f_pointer(fft_data(11), rhoold_r, shp) ! Set array bounds when copying ex to ex_r in PICSAR @@ -147,28 +176,28 @@ contains nky = ny nkz = nz shp = [int(nkx), int(nky), int(nkz)] - sz = c_sizeof(complexfoo) * int(shp(1),c_size_t) * int(shp(2),c_size_t) * int(shp(3),c_size_t) - fft_data(12) = amrex_malloc(sz) + sz = alloc_local + fft_data(12) = fftw_alloc_complex(sz) call c_f_pointer(fft_data(12), exf, shp) - fft_data(13) = amrex_malloc(sz) + fft_data(13) = fftw_alloc_complex(sz) call c_f_pointer(fft_data(13), eyf, shp) - fft_data(14) = amrex_malloc(sz) + fft_data(14) = fftw_alloc_complex(sz) call c_f_pointer(fft_data(14), ezf, shp) - fft_data(15) = amrex_malloc(sz) + fft_data(15) = fftw_alloc_complex(sz) call c_f_pointer(fft_data(15), bxf, shp) - fft_data(16) = amrex_malloc(sz) + fft_data(16) = fftw_alloc_complex(sz) call c_f_pointer(fft_data(16), byf, shp) - fft_data(17) = amrex_malloc(sz) + fft_data(17) = fftw_alloc_complex(sz) call c_f_pointer(fft_data(17), bzf, shp) - fft_data(18) = amrex_malloc(sz) + fft_data(18) = fftw_alloc_complex(sz) call c_f_pointer(fft_data(18), jxf, shp) - fft_data(19) = amrex_malloc(sz) + fft_data(19) = fftw_alloc_complex(sz) call c_f_pointer(fft_data(19), jyf, shp) - fft_data(20) = amrex_malloc(sz) + fft_data(20) = fftw_alloc_complex(sz) call c_f_pointer(fft_data(20), jzf, shp) - fft_data(21) = amrex_malloc(sz) + fft_data(21) = fftw_alloc_complex(sz) call c_f_pointer(fft_data(21), rhof, shp) - fft_data(22) = amrex_malloc(sz) + fft_data(22) = fftw_alloc_complex(sz) call c_f_pointer(fft_data(22), rhooldf, shp) if (ndata < 22) then @@ -193,6 +222,7 @@ contains jx_r, jy_r, jz_r, rho_r, rhoold_r, & exf, eyf, ezf, bxf, byf, bzf, & jxf, jyf, jzf, rhof, rhooldf + use mpi_fftw3, only : plan_r2c_mpi, plan_c2r_mpi nullify(ex_r) nullify(ey_r) nullify(ez_r) @@ -215,6 +245,9 @@ contains nullify(jzf) nullify(rhof) nullify(rhooldf) + call fftw_destroy_plan(plan_r2c_mpi) + call fftw_destroy_plan(plan_c2r_mpi) + call fftw_mpi_cleanup() end subroutine warpx_fft_nullify diff --git a/Source/WarpX_laser.F90 b/Source/WarpX_laser.F90 index fc1325d8d..77a0a508c 100644 --- a/Source/WarpX_laser.F90 +++ b/Source/WarpX_laser.F90 @@ -3,8 +3,8 @@ module warpx_laser_module use iso_c_binding use amrex_fort_module, only : amrex_real - use constants - use parser_wrapper + use constants, only : clight, pi + use parser_wrapper, only : parser_evaluate_function implicit none diff --git a/Source/WarpX_picsar.F90 b/Source/WarpX_picsar.F90 index 77860fd67..87336c9de 100644 --- a/Source/WarpX_picsar.F90 +++ b/Source/WarpX_picsar.F90 @@ -41,7 +41,6 @@ module warpx_to_pxr_module use iso_c_binding use amrex_fort_module, only : amrex_real - use constants implicit none @@ -153,8 +152,6 @@ subroutine warpx_charge_deposition(rho,np,xp,yp,zp,w,q,xmin,ymin,zmin,dx,dy,dz,n nxguard,nyguard,nzguard,nox,noy,noz,lvect,charge_depo_algo) & bind(C, name="warpx_charge_deposition") - use amrex_error_module - integer(c_long), intent(IN) :: np integer(c_long), intent(IN) :: nx,ny,nz integer(c_long), intent(IN) :: nxguard,nyguard,nzguard diff --git a/Source/main.cpp b/Source/main.cpp index d648d22a2..757c29e1f 100644 --- a/Source/main.cpp +++ b/Source/main.cpp @@ -11,7 +11,15 @@ using namespace amrex; int main(int argc, char* argv[]) { - amrex::Initialize(argc,argv); +#if defined(_OPENMP) && defined(WARPX_USE_PSATD) + int provided; + MPI_Init_thread(&argc, &argv, MPI_THREAD_FUNNELED, &provided); + assert(provided >= MPI_THREAD_FUNNELED); +#else + MPI_Init(&argc, &argv); +#endif + + amrex::Initialize(argc,argv,MPI_COMM_WORLD); BL_PROFILE_VAR("main()", pmain); @@ -35,4 +43,5 @@ int main(int argc, char* argv[]) BL_PROFILE_VAR_STOP(pmain); amrex::Finalize(); + MPI_Finalize(); } |