diff options
Diffstat (limited to 'Source/Particles/ParticleCreation/SmartCopy.H')
-rw-r--r-- | Source/Particles/ParticleCreation/SmartCopy.H | 67 |
1 files changed, 57 insertions, 10 deletions
diff --git a/Source/Particles/ParticleCreation/SmartCopy.H b/Source/Particles/ParticleCreation/SmartCopy.H index 2247d8109..827f23678 100644 --- a/Source/Particles/ParticleCreation/SmartCopy.H +++ b/Source/Particles/ParticleCreation/SmartCopy.H @@ -1,6 +1,8 @@ #ifndef SMART_COPY_H_ #define SMART_COPY_H_ +#include <DefaultInitialization.H> + #include <AMReX_GpuContainers.H> #include <map> @@ -17,25 +19,49 @@ struct SmartCopyTag int size () const noexcept { return common_names.size(); } }; +struct DefaultInitializationTag +{ + amrex::Gpu::DeviceVector<InitializationPolicy> m_policy_real; + amrex::Gpu::DeviceVector<InitializationPolicy> m_policy_int; + +}; + SmartCopyTag getSmartCopyTag (const NameMap& src, const NameMap& dst); struct SmartCopy { - int m_num_real; + int m_num_copy_real; const int* m_src_comps_r; const int* m_dst_comps_r; - int m_num_int; + int m_num_copy_int; const int* m_src_comps_i; const int* m_dst_comps_i; - template <typename DstData, typename SrcData> + const InitializationPolicy* m_policy_real; + const InitializationPolicy* m_policy_int; + + template <typename DstData, typename SrcData> AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void operator() (const DstData& dst, const SrcData& src, int i_src, int i_dst) const noexcept { + // the particle struct is always copied over dst.m_aos[i_dst] = src.m_aos[i_src]; - for (int j = 0; j < m_num_real; ++j) + // initialize the real components + for (int j = 0; j < DstData::NAR; ++j) + dst.m_rdata[j] = initializeRealValue(m_policy_real[j]); + for (int j = 0; j < dst.m_num_runtime_real; ++j) + dst.m_rdata[j] = initializeRealValue(m_policy_real[j+DstData::NAR]); + + // initialize the int components + for (int j = 0; j < DstData::NAI; ++j) + dst.m_idata[j] = initializeIntValue(m_policy_int[j]); + for (int j = 0; j < dst.m_num_runtime_int; ++j) + dst.m_idata[j] = initializeIntValue(m_policy_int[j+DstData::NAI]); + + // copy the shared real components + for (int j = 0; j < m_num_copy_real; ++j) { int src_comp = (m_src_comps_r[j] < SrcData::NAR) ? m_src_comps_r[j] : m_src_comps_r[j] - SrcData::NAR; int dst_comp = (m_dst_comps_r[j] < DstData::NAR) ? m_dst_comps_r[j] : m_dst_comps_r[j] - DstData::NAR; @@ -46,7 +72,8 @@ struct SmartCopy dst_data[dst_comp] = src_data[src_comp]; } - for (int j = 0; j < m_num_int; ++j) + // copy the shared int components + for (int j = 0; j < m_num_copy_int; ++j) { int src_comp = (m_src_comps_i[j] < SrcData::NAI) ? m_src_comps_i[j] : m_src_comps_i[j] - SrcData::NAI; int dst_comp = (m_dst_comps_i[j] < DstData::NAI) ? m_dst_comps_i[j] : m_dst_comps_i[j] - DstData::NAI; @@ -63,14 +90,32 @@ class SmartCopyFactory { SmartCopyTag m_tag_real; SmartCopyTag m_tag_int; + amrex::Gpu::DeviceVector<InitializationPolicy> m_policy_real; + amrex::Gpu::DeviceVector<InitializationPolicy> m_policy_int; bool m_defined; - template <class PC1, class PC2> - SmartCopyFactory (const PC1& pc1, const PC2& pc2) noexcept +public: + template <class SrcPC, class DstPC> + SmartCopyFactory (const SrcPC& src, const DstPC& dst) noexcept : m_defined(false) { - m_tag_real = getSmartCopyTag(pc1.getParticleComps(), pc2.getParticleComps()); - m_tag_int = getSmartCopyTag(pc1.getParticleiComps(), pc2.getParticleiComps()); + m_tag_real = getSmartCopyTag(src.getParticleComps(), dst.getParticleComps()); + m_tag_int = getSmartCopyTag(src.getParticleiComps(), dst.getParticleiComps()); + + auto real_comps = dst.getParticleComps(); + m_policy_real.resize(real_comps.size()); + for (const auto& kv : real_comps) + { + m_policy_real[kv.second] = initialization_policies[kv.first]; + } + + auto int_comps = dst.getParticleiComps(); + m_policy_int.resize(int_comps.size()); + for (const auto& kv : int_comps) + { + m_policy_int[kv.second] = initialization_policies[kv.first]; + } + m_defined = true; } @@ -82,7 +127,9 @@ class SmartCopyFactory m_tag_real.dst_comps.dataPtr(), m_tag_int.size(), m_tag_int. src_comps.dataPtr(), - m_tag_int. dst_comps.dataPtr()}; + m_tag_int. dst_comps.dataPtr(), + m_policy_real.dataPtr(), + m_policy_int.dataPtr()}; } bool isDefined () const noexcept { return m_defined; } |