diff options
Diffstat (limited to 'Source/Parser/GpuParser.H')
-rw-r--r-- | Source/Parser/GpuParser.H | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/Source/Parser/GpuParser.H b/Source/Parser/GpuParser.H new file mode 100644 index 000000000..1533ee6b9 --- /dev/null +++ b/Source/Parser/GpuParser.H @@ -0,0 +1,72 @@ +#ifndef WARPX_GPU_PARSER_H_ +#define WARPX_GPU_PARSER_H_ + +#include <WarpXParser.H> +#include <AMReX_Gpu.H> + +// When compiled for CPU, wrap WarpXParser and enable threading. +// When compiled for GPU, store one copy of the parser in +// CUDA managed memory for __device__ code, and one copy of the parser +// in CUDA managed memory for __host__ code. This way, the parser can be +// efficiently called from both host and device. +class GpuParser +{ +public: + GpuParser (WarpXParser const& wp); + void clear (); + + AMREX_GPU_HOST_DEVICE + double + operator() (double x, double y, double z) const noexcept + { +#ifdef AMREX_USE_GPU + +#ifdef AMREX_DEVICE_COMPILE +// WarpX compiled for GPU, function compiled for __device__ + // the 3D position of each particle is stored in shared memory. + amrex::Gpu::SharedMemory<double> gsm; + double* p = gsm.dataPtr(); + int tid = threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*(blockDim.x*blockDim.y); + p[tid*3] = x; + p[tid*3+1] = y; + p[tid*3+2] = z; + return wp_ast_eval(m_gpu_parser.ast); +#else +// WarpX compiled for GPU, function compiled for __host__ + m_var.x = x; + m_var.y = y; + m_var.z = z; + return wp_ast_eval(m_cpu_parser.ast); +#endif + +#else +// WarpX compiled for CPU +#ifdef _OPENMP + int tid = omp_get_thread_num(); +#else + int tid = 0; +#endif + m_var[tid].x = x; + m_var[tid].y = y; + m_var[tid].z = z; + return wp_ast_eval(m_parser[tid]->ast); +#endif + } + +private: + +#ifdef AMREX_USE_GPU + // Copy of the parser running on __device__ + struct wp_parser m_gpu_parser; + // Copy of the parser running on __host__ + struct wp_parser m_cpu_parser; + mutable amrex::XDim3 m_var; +#else + // Only one parser + struct wp_parser** m_parser; + mutable amrex::XDim3* m_var; + int nthreads; +#endif +}; + +#endif |