aboutsummaryrefslogtreecommitdiff
path: root/Source/Utils/WarpX_Complex.H
diff options
context:
space:
mode:
Diffstat (limited to 'Source/Utils/WarpX_Complex.H')
-rw-r--r--Source/Utils/WarpX_Complex.H14
1 files changed, 7 insertions, 7 deletions
diff --git a/Source/Utils/WarpX_Complex.H b/Source/Utils/WarpX_Complex.H
index 25ed1a53a..1f265d3c5 100644
--- a/Source/Utils/WarpX_Complex.H
+++ b/Source/Utils/WarpX_Complex.H
@@ -3,13 +3,14 @@
#include <AMReX_REAL.H>
#include <AMReX_Gpu.H>
+#include <AMReX_GpuComplex.H>
+
+#include <complex>
// Define complex type on GPU/CPU
#ifdef AMREX_USE_GPU
-#include <thrust/complex.h>
-
-using Complex = thrust::complex<amrex::Real>;
+using Complex = amrex::GpuComplex<amrex::Real>;
#ifdef WARPX_USE_PSATD
#include <cufft.h>
@@ -19,7 +20,6 @@ static_assert( sizeof(Complex) == sizeof(cuDoubleComplex),
#else
-#include <complex>
using Complex = std::complex<amrex::Real>;
#ifdef WARPX_USE_PSATD
@@ -39,7 +39,7 @@ namespace MathFunc
template<typename T>
AMREX_GPU_HOST_DEVICE T exp (const T& val){
#ifdef AMREX_USE_GPU
- return thrust::exp(val);
+ return amrex::exp(val);
#else
return std::exp(val);
#endif
@@ -49,7 +49,7 @@ namespace MathFunc
template<typename T>
AMREX_GPU_HOST_DEVICE T sqrt (const T& val){
#ifdef AMREX_USE_GPU
- return thrust::sqrt(val);
+ return amrex::sqrt(val);
#else
return std::sqrt(val);
#endif
@@ -59,7 +59,7 @@ namespace MathFunc
template<typename T1, typename T2>
AMREX_GPU_HOST_DEVICE T1 pow (const T1& val, const T2& power){
#ifdef AMREX_USE_GPU
- return thrust::pow(val, power);
+ return amrex::pow(val, power);
#else
return std::pow(val, power);
#endif