diff options
Diffstat (limited to 'Source/Utils/MsgLogger')
-rw-r--r-- | Source/Utils/MsgLogger/CMakeLists.txt | 4 | ||||
-rw-r--r-- | Source/Utils/MsgLogger/Make.package | 3 | ||||
-rw-r--r-- | Source/Utils/MsgLogger/MsgLogger.H | 293 | ||||
-rw-r--r-- | Source/Utils/MsgLogger/MsgLogger.cpp | 653 | ||||
-rw-r--r-- | Source/Utils/MsgLogger/MsgLoggerSerialization.H | 189 | ||||
-rw-r--r-- | Source/Utils/MsgLogger/MsgLogger_fwd.H | 24 |
6 files changed, 1166 insertions, 0 deletions
diff --git a/Source/Utils/MsgLogger/CMakeLists.txt b/Source/Utils/MsgLogger/CMakeLists.txt new file mode 100644 index 000000000..bd167d0e4 --- /dev/null +++ b/Source/Utils/MsgLogger/CMakeLists.txt @@ -0,0 +1,4 @@ +target_sources(WarpX + PRIVATE + MsgLogger.cpp +) diff --git a/Source/Utils/MsgLogger/Make.package b/Source/Utils/MsgLogger/Make.package new file mode 100644 index 000000000..e02476222 --- /dev/null +++ b/Source/Utils/MsgLogger/Make.package @@ -0,0 +1,3 @@ +CEXE_sources += MsgLogger.cpp + +VPATH_LOCATIONS += $(WARPX_HOME)/Source/Utils/MsgLogger diff --git a/Source/Utils/MsgLogger/MsgLogger.H b/Source/Utils/MsgLogger/MsgLogger.H new file mode 100644 index 000000000..ca55289b2 --- /dev/null +++ b/Source/Utils/MsgLogger/MsgLogger.H @@ -0,0 +1,293 @@ +/* Copyright 2021 Luca Fedeli + * + * This file is part of WarpX. + * + * License: BSD-3-Clause-LBNL + */ + +#ifndef WARPX_MSG_LOGGER_H_ +#define WARPX_MSG_LOGGER_H_ + +#include <AMReX.H> + +#include <cstdint> +#include <map> +#include <string> +#include <utility> +#include <vector> + +namespace Utils{ +namespace MsgLogger{ + + /** Priority is recorded together with messages. It influences + * the display order and the appearance of a message. + */ + enum class Priority + { + /** Low priority message */ + low, + /** Medium priority message */ + medium, + /** High priority message */ + high + }; + + /** + * \brief This function converts a Priority into the corresponding + * string (e.g, Priority::low --> "low") + * + * @param[in] priority the priority + * @return the corresponding string + */ + std::string PriorityToString(const Priority& priority); + + /** + * \brief This function converts a string into the corresponding + * priority (e.g, "low" --> Priority::low) + * + * @param[in] priority_string the priority string + * @return the corresponding priority + */ + Priority StringToPriority(const std::string& priority_string); + + /** + * This struct represents a message, which is composed by + * a topic, a text and a priority. It also provides methods for + * serialization and deserialization. + */ + struct Msg + { + std::string topic /*! The message topic*/; + std::string text /*! The message text*/; + Priority priority /*! The priority of the message*/; + + /** + * \brief This function returns a byte representation of the struct + * + * @return a byte vector + */ + std::vector<char> serialize() const; + + /** + * \brief This function generates a Msg struct from a byte vector + * + * @param[in] it iterator of a byte array + * @return a Msg struct + */ + static Msg deserialize(std::vector<char>::const_iterator& it); + + /** + * \brief Same as static Msg deserialize(std::vector<char>::const_iterator& it) + * but accepting an rvalue as an argument + * + * @param[in] it iterator of a byte array + * @return a Msg struct + */ + static Msg deserialize(std::vector<char>::const_iterator&& it); + }; + + /** + * This struct represents a message with counter, which is composed + * by a message and a counter. The latter is intended to store the + * number of times a message is recorded. The struct also provides + * methods for serialization and deserialization. + */ + struct MsgWithCounter + { + Msg msg /*! A message*/; + std::int64_t counter /*! The counter*/; + + /** + * \brief This function returns a byte representation of the struct + * + * @return a byte vector + */ + std::vector<char> serialize() const; + + /** + * \brief This function generates a MsgWithCounter struct from a byte vector + * + * @param[in] it iterator of a byte array + * @return a MsgWithCounter struct + */ + static MsgWithCounter deserialize(std::vector<char>::const_iterator& it); + + /** + * \brief Same as static Msg MsgWithCounter(std::vector<char>::const_iterator& it) + * but accepting an rvalue as an argument + * + * @param[in] it iterator of a byte array + * @return a MsgWithCounter struct + */ + static MsgWithCounter deserialize(std::vector<char>::const_iterator&& it); + }; + + /** + * This struct represents a message with counter and ranks, which is + * composed by a message with counter, a bool flag and a std::vector<int>. + * The bool flag is used to store if a message is emitted by all the ranks. + * The std::vector<int> is used to store the affected ranks + * (note: when we switch to C++17, should we consider variants?). + * The struct also provides methods for serialization and deserialization. + */ + struct MsgWithCounterAndRanks + { + MsgWithCounter msg_with_counter /*! A message with counter*/; + bool all_ranks /*! Flag to store if message is emitted by all ranks*/; + std::vector<int> ranks /*! Affected ranks*/; + + /** + * \brief This function returns a byte representation of the struct + * + * @return a byte vector + */ + std::vector<char> serialize() const; + + /** + * \brief This function generates a MsgWithCounterAndRanks struct from a byte vector + * + * @param[in] it iterator of a byte array + * @return a MsgWithCounterAndRanks struct + */ + static MsgWithCounterAndRanks deserialize(std::vector<char>::const_iterator& it); + + /** + * \brief Same as static Msg MsgWithCounterAndRanks(std::vector<char>::const_iterator& it) + * but accepting an rvalue as an argument + * + * @param[in] it iterator of a byte array + * @return a MsgWithCounterAndRanks struct + */ + static MsgWithCounterAndRanks deserialize(std::vector<char>::const_iterator&& it); + }; + + /** + * \brief This implements the < operator for Msg. + * Warning messages are first ordered by priority (warning: high < medium < low + * to give precedence to higher priorities), then by topic (alphabetically), + * and finally by text (alphabetically). + * + * @param[in] l a Msg + * @param[in] r a Msg + * @return true if l<r, false otherwise + */ + constexpr bool operator<(const Msg& l, const Msg& r) + { + return + (l.priority > r.priority) || + ((l.priority == r.priority) && (l.topic < r.topic)) || + ((l.priority == r.priority) && (l.topic == r.topic) && (l.text < r.text)); + } + + /** + * This class is responsible for storing messages and merging messages + * collected by different processes. + */ + class Logger + { + public: + + /** + * \brief The constructor. + */ + Logger(); + + /** + * \brief This function records a message + * + * @param[in] msg a Msg struct + */ + void record_msg(Msg msg); + + /** + * \brief This function returns a vector containing the recorded messages + * + * @return a vector of the recorded messages + */ + std::vector<Msg> get_msgs() const; + + /** + * \brief This function returns a vector containing the recorded messages + * with the corresponding counters + * + * @return a vector of the recorded messages with counters + */ + std::vector<MsgWithCounter> get_msgs_with_counter() const; + + /** + * \brief This collective function generates a vector containing the messages + * with counters and emitting ranks by gathering data from + * all the ranks + * + * @return a vector of messages with counters and ranks if I/O rank, an empty vector otherwise + */ + std::vector<MsgWithCounterAndRanks> + collective_gather_msgs_with_counter_and_ranks() const; + + private: + + /** + * \brief This function implements the trivial special case of + * collective_gather_msgs_with_counter_and_ranks when there is only one rank. + * + * @return a vector of messages with counters and ranks + */ + std::vector<MsgWithCounterAndRanks> + one_rank_gather_msgs_with_counter_and_ranks() const; + +#ifdef AMREX_USE_MPI + /** + * \brief This collective function finds the rank having the + * most messages and how many messages this rank has. The + * rank having the most messages is designated as "gather rank". + * + * @param[in] how_many_msgs the number of messages that the current rank has + * @return a pair containing the ID of the "gather rank" and its number of messages + */ + std::pair<int, int> find_gather_rank_and_its_msgs( + int how_many_msgs) const; + + /** + * \brief This function uses data gathered on the "gather rank" to generate + * a vector of messages with global counters and emitting rank lists + * + * @param[in] my_msg_map messages and counters of the current rank (as a map) + * @param[in] all_data a byte array containing all the data gathered on the gather rank + * @param[in] displacements a vector of displacements to access data corresponding to a given rank in all_data + * @param[in] gather_rank the ID of the "gather rank" + * @return if gather_rank==m_rank a vector of messages with global counters and emitting rank lists, dummy data otherwise + */ + std::vector<MsgWithCounterAndRanks> + compute_msgs_with_counter_and_ranks( + const std::map<Msg,std::int64_t>& my_msg_map, + const std::vector<char>& all_data, + const std::vector<int>& displacements, + const int gather_rank + ) const; + + + /** + * \brief If the gather_rank is not the I/O rank, this function sends msgs_with_counter_and_ranks + * to the I/O rank. This function uses point-to-point communications. + * + * @param[in] msgs_with_counter_and_ranks a vector of messages with counters and ranks + * @param[in] gather_rank the ID of the "gather rank" + */ + void + swap_with_io_rank( + std::vector<MsgWithCounterAndRanks>& msgs_with_counter_and_ranks, + int gather_rank) const; + +#endif + + int m_rank = 0 /*! MPI rank of the current process*/; + int m_num_procs = 0 /*! Number of MPI ranks*/; + int m_io_rank = 0 /*! Rank of the I/O process*/; + bool m_am_i_io = false /*! Flag to store if the process is responsible for I/O*/; + + std::map<Msg, std::int64_t> m_messages /*! This stores a map to associate warning messages with the corresponding counters*/; + }; +} +} + +#endif //WARPX_MSG_LOGGER_H_ diff --git a/Source/Utils/MsgLogger/MsgLogger.cpp b/Source/Utils/MsgLogger/MsgLogger.cpp new file mode 100644 index 000000000..c1c1f9324 --- /dev/null +++ b/Source/Utils/MsgLogger/MsgLogger.cpp @@ -0,0 +1,653 @@ +/* Copyright 2021 Luca Fedeli + * + * This file is part of WarpX. + * + * License: BSD-3-Clause-LBNL + */ + +#include "MsgLogger.H" + +#include "MsgLoggerSerialization.H" + +#ifdef AMREX_USE_MPI +# include <AMReX_ParallelDescriptor.H> +#endif +#include <AMReX_Print.H> + +#include <iostream> +#include <sstream> +#include <numeric> + +using namespace Utils::MsgLogger; + +#ifdef AMREX_USE_MPI +// Helper functions used only in this source file +namespace +{ + /** + * \brief This collective function returns the messages of the "gather rank" + * as a byte array. + * + * @param[in] my_msgs the messages of the current rank + * @param[in] gather_rank the ID of the "gather rank" + * @param[in] my_rank the ID of the current rank + * @return the messages of the "gather rank" as a byte array + */ + std::vector<char> + get_serialized_gather_rank_msgs( + const std::vector<Msg>& my_msgs, + const int gather_rank, + const int my_rank); + + /** + * \brief This function generates data to send back to the "gather rank" + * + * @param[in] serialized_gather_rank_msgs the serialized messages of the gather rank + * @param[in] gather_rank_how_many_msgs number of messages of the "gather rank" + * @param[in] my_msg_map messages and counters of the current rank (as a map) + * @param[in] is_gather_rank true if the rank is the "gather rank", false otherwise + * @return a byte array to send back to the "gather rank" (or a dummy vector in case is_gather_rank is true) + */ + std::vector<char> + compute_package_for_gather_rank( + const std::vector<char>& serialized_gather_rank_msgs, + const std::int64_t gather_rank_how_many_msgs, + const std::map<Msg, std::int64_t>& my_msg_map, + const bool is_gather_rank + ); + + /** + * \brief This collective function gathers data generated with compute_package_for_gather_rank + * to the gather rank. + * If my_rank != gather_rank the function returns dummy data. Otherwise the function returns + * a pair containing: + * 1) a byte array containing info on messages seen by other ranks + * 2) a vector of displacements to access data corresponding to a given rank + * + * @param[in] package_for_gather_rank a byte array generated by compute_package_for_gather_rank + * @param[in] gather_rank the ID of the "gather rank" + * @param[in] my_rank the ID of the current rank + * @return (see function description) + */ + std::pair<std::vector<char>, std::vector<int>> + gather_all_data( + const std::vector<char>& package_for_gather_rank, + const int gather_rank, const int my_rank); + + /** + * \brief This function converts a vector of Msg struct into a byte array + * + * @param[in] msgs the vector of Msg struct + * @return a byte array + */ + std::vector<char> serialize_msgs( + const std::vector<Msg>& msgs); + + /** + * \brief This function converts a byte array into a vector of Msg struct + * + * @param[in] serialized the byte array + * @return a vector of Msg struct + */ + std::vector<Msg> deserialize_msgs( + const std::vector<char>& serialized); +} +#endif + +std::string Utils::MsgLogger::PriorityToString(const Priority& priority) +{ + if(priority == Priority::high) + return "high"; + else if (priority == Priority::medium) + return "medium"; + else + return "low"; +} + +Priority Utils::MsgLogger::StringToPriority(const std::string& priority_string) +{ + if(priority_string == "high") + return Priority::high; + else if (priority_string == "medium") + return Priority::medium; + else if (priority_string == "low") + return Priority::low; + else + amrex::Abort( + "Priority string '" + priority_string + "' not recognized"); + + //this silences a "non-void function does not return a value in all control paths" warning + return Priority::low; +} + +std::vector<char> Msg::serialize() const +{ + std::vector<char> serialized_msg; + + put_in(this->topic, serialized_msg); + put_in(this->text, serialized_msg); + const int int_priority = static_cast<int>(this->priority); + put_in(int_priority, serialized_msg); + + return serialized_msg; +} + +Msg Msg::deserialize (std::vector<char>::const_iterator& it) +{ + Msg msg; + + msg.topic = get_out<std::string> (it); + msg.text = get_out<std::string> (it); + msg.priority = static_cast<Priority> (get_out<int> (it)); + + return msg; +} + +Msg Msg::deserialize (std::vector<char>::const_iterator&& it) +{ + return Msg::deserialize(it); +} + +std::vector<char> MsgWithCounter::serialize() const +{ + std::vector<char> serialized_msg_with_counter; + + put_in_vec(msg.serialize(), serialized_msg_with_counter); + put_in(this->counter, serialized_msg_with_counter); + + return serialized_msg_with_counter; +} + +MsgWithCounter MsgWithCounter::deserialize (std::vector<char>::const_iterator& it) +{ + MsgWithCounter msg_with_counter; + + const auto vec = get_out_vec<char>(it); + auto iit = vec.begin(); + msg_with_counter.msg = Msg::deserialize(iit); + msg_with_counter.counter = get_out<std::int64_t> (it); + + return msg_with_counter; +} + +MsgWithCounter MsgWithCounter::deserialize (std::vector<char>::const_iterator&& it) +{ + return MsgWithCounter::deserialize(it); +} + +std::vector<char> MsgWithCounterAndRanks::serialize() const +{ + std::vector<char> serialized_msg_with_counter_and_ranks; + + put_in_vec(this->msg_with_counter.serialize(), serialized_msg_with_counter_and_ranks); + put_in(this->all_ranks, serialized_msg_with_counter_and_ranks); + put_in_vec(this->ranks, serialized_msg_with_counter_and_ranks); + + return serialized_msg_with_counter_and_ranks; +} + +MsgWithCounterAndRanks +MsgWithCounterAndRanks::deserialize (std::vector<char>::const_iterator& it) +{ + MsgWithCounterAndRanks msg_with_counter_and_ranks; + + const auto vec = get_out_vec<char>(it); + auto iit = vec.begin(); + msg_with_counter_and_ranks.msg_with_counter = MsgWithCounter::deserialize(iit); + msg_with_counter_and_ranks.all_ranks = get_out<bool>(it); + msg_with_counter_and_ranks.ranks = get_out_vec<int>(it); + + return msg_with_counter_and_ranks; +} + +MsgWithCounterAndRanks +MsgWithCounterAndRanks::deserialize (std::vector<char>::const_iterator&& it) +{ + return MsgWithCounterAndRanks::deserialize(it); +} + +Logger::Logger(){ + m_rank = amrex::ParallelDescriptor::MyProc(); + m_num_procs = amrex::ParallelDescriptor::NProcs(); + m_io_rank = amrex::ParallelDescriptor::IOProcessorNumber(); + m_am_i_io = (m_rank == m_io_rank); +} + +void Logger::record_msg(Msg msg) +{ + m_messages[msg]++; +} + +std::vector<Msg> Logger::get_msgs() const +{ + auto res = std::vector<Msg>{}; + + for (const auto& msg_w_counter : m_messages) + res.emplace_back(msg_w_counter.first); + + return res; +} + +std::vector<MsgWithCounter> Logger::get_msgs_with_counter() const +{ + auto res = std::vector<MsgWithCounter>{}; + + for (const auto& msg : m_messages) + res.emplace_back(MsgWithCounter{msg.first, msg.second}); + + return res; +} + +std::vector<MsgWithCounterAndRanks> +Logger::collective_gather_msgs_with_counter_and_ranks() const +{ + +#ifdef AMREX_USE_MPI + + // Trivial case of only one rank + if (m_num_procs == 1) + return one_rank_gather_msgs_with_counter_and_ranks(); + + // Find out who is the "gather rank" and how many messages it has + const auto my_msgs = get_msgs(); + const auto how_many_msgs = my_msgs.size(); + int gather_rank = 0; + std::int64_t gather_rank_how_many_msgs = 0; + std::tie(gather_rank, gather_rank_how_many_msgs) = + find_gather_rank_and_its_msgs(how_many_msgs); + + // If the "gather rank" has zero messages there are no messages at all + if(gather_rank_how_many_msgs == 0) + return std::vector<MsgWithCounterAndRanks>{}; + + // All the ranks receive the msgs of the "gather rank" as a byte array + const auto serialized_gather_rank_msgs = + ::get_serialized_gather_rank_msgs(my_msgs, gather_rank, m_rank); + + // Each rank assembles a message to send back to the "gather rank" + const bool is_gather_rank = (gather_rank == m_rank); + const auto package_for_gather_rank = + ::compute_package_for_gather_rank( + serialized_gather_rank_msgs, + gather_rank_how_many_msgs, + m_messages, is_gather_rank); + + // Send back all the data to the "gather rank" + auto all_data = std::vector<char>{}; + auto displacements = std::vector<int>{}; + std::tie(all_data, displacements) = + ::gather_all_data( + package_for_gather_rank, + gather_rank, m_rank); + + // Use the gathered data to generate (on the "gather rank") a vector of all the + // messages seen by all the ranks with the corresponding counters and + // emitting rank lists. + auto msgs_with_counter_and_ranks = + compute_msgs_with_counter_and_ranks( + m_messages, + all_data, + displacements, + gather_rank); + + // If the current rank is not the I/O rank, send msgs_with_counter_and_ranks + // to the I/O rank + swap_with_io_rank( + msgs_with_counter_and_ranks, + gather_rank); + + return msgs_with_counter_and_ranks; +#else + return one_rank_gather_msgs_with_counter_and_ranks(); +#endif +} + +std::vector<MsgWithCounterAndRanks> +Logger::one_rank_gather_msgs_with_counter_and_ranks() const +{ + std::vector<MsgWithCounterAndRanks> res; + for (const auto& el : m_messages) + { + res.emplace_back( + MsgWithCounterAndRanks{ + MsgWithCounter{el.first, el.second}, + true, + std::vector<int>{m_rank}}); + } + return res; +} + +#ifdef AMREX_USE_MPI + +std::pair<int,int> Logger::find_gather_rank_and_its_msgs(int how_many_msgs) const +{ + int max_items = 0; + int max_rank = 0; + + const auto num_msg = + amrex::ParallelDescriptor::Gather(how_many_msgs, m_io_rank); + + if (m_am_i_io){ + const auto it_max = std::max_element(num_msg.begin(), num_msg.end()); + max_items = *it_max; + + //In case of an "ex aequo" the I/O rank should be the gather rank + max_rank = (max_items == how_many_msgs) ? + m_io_rank : it_max - num_msg.begin(); + } + + auto package = std::array<int,2>{max_rank, max_items}; + amrex::ParallelDescriptor::Bcast(package.data(), 2, m_io_rank); + + return std::make_pair(package[0], package[1]); +} + +std::vector<MsgWithCounterAndRanks> +Logger::compute_msgs_with_counter_and_ranks( + const std::map<Msg,std::int64_t>& my_msg_map, + const std::vector<char>& all_data, + const std::vector<int>& displacements, + const int gather_rank) const +{ + if(m_rank != gather_rank) return std::vector<MsgWithCounterAndRanks>{}; + + std::vector<MsgWithCounterAndRanks> msgs_with_counter_and_ranks; + + // Put messages of the gather rank in msgs_with_counter_and_ranks + for (const auto& el : my_msg_map) + { + msgs_with_counter_and_ranks.emplace_back( + MsgWithCounterAndRanks{ + MsgWithCounter{el.first, el.second}, + false, + std::vector<int>{m_rank}}); + } + + // We need a temporary map + std::map<Msg, MsgWithCounterAndRanks> tmap; + +#ifdef AMREX_USE_OMP + #pragma omp parallel for +#endif + for(int rr = 0; rr < m_num_procs; ++rr){ //for each rank + if(rr == gather_rank) // (skip gather_rank) + continue; + + // get counters generated by rank rr + auto it = all_data.begin() + displacements[rr]; + const auto counters_rr = get_out_vec<std::int64_t>(it); + + //for each counter from rank rr + std::int64_t c = 0; + for (const auto& counter : counters_rr){ +#ifdef AMREX_USE_OMP + #pragma omp atomic +#endif + msgs_with_counter_and_ranks[c].msg_with_counter.counter += + counter; //update corresponding global counter + + //and add rank to rank list if it has emitted the message + if (counter > 0){ +#ifdef AMREX_USE_OMP + #pragma omp critical +#endif + { + msgs_with_counter_and_ranks[c].ranks.push_back(rr); + } + } + c++; + } + + // for each additional message coming from rank rr + const auto how_many_additional_msgs_with_counter = get_out<int>(it); + for(int i = 0; i < how_many_additional_msgs_with_counter; ++i){ + + //deserialize the message + const auto serialized_msg_with_counter = get_out_vec<char>(it); + auto msg_with_counter = + MsgWithCounter::deserialize(serialized_msg_with_counter.begin()); + + //and eventually add it to the temporary map +#ifdef AMREX_USE_OMP + #pragma omp critical +#endif + { + if (tmap.find(msg_with_counter.msg) == tmap.end()){ + const auto msg_with_counter_and_ranks = + MsgWithCounterAndRanks{ + msg_with_counter, + false, + std::vector<int>{rr} + }; + tmap[msg_with_counter.msg] = msg_with_counter_and_ranks; + } + else{ + tmap[msg_with_counter.msg].msg_with_counter.counter += + msg_with_counter.counter; + tmap[msg_with_counter.msg].ranks.push_back(rr); + } + } + } + } + + // Check if messages emitted by "gather rank" are actually emitted by all ranks + const auto ssize = static_cast<int>(msgs_with_counter_and_ranks.size()); + for (int i = 0; i < ssize; ++i){ + const auto how_many = + static_cast<int>(msgs_with_counter_and_ranks[i].ranks.size()); + if(how_many == m_num_procs){ + msgs_with_counter_and_ranks[i].all_ranks = true; + // trick to force free memory + std::vector<int>{}.swap(msgs_with_counter_and_ranks[i].ranks); + } + } + + // Add elements from the temporary map + for(const auto& el : tmap){ + msgs_with_counter_and_ranks.push_back(el.second); + } + + // Sort affected ranks lists + for(auto& el : msgs_with_counter_and_ranks){ + std::sort(el.ranks.begin(), el.ranks.end()); + } + + return msgs_with_counter_and_ranks; +} + +void Logger::swap_with_io_rank( + std::vector<MsgWithCounterAndRanks>& msgs_with_counter_and_ranks, + int gather_rank) const +{ + if (gather_rank != m_io_rank){ + if(m_rank == gather_rank){ + auto package = std::vector<char>{}; + for (const auto& el: msgs_with_counter_and_ranks) + put_in_vec<char>(el.serialize(), package); + + auto package_size = static_cast<int>(package.size()); + amrex::ParallelDescriptor::Send(&package_size, 1, m_io_rank, 0); + amrex::ParallelDescriptor::Send(package, m_io_rank, 1); + int list_size = static_cast<int>(msgs_with_counter_and_ranks.size()); + amrex::ParallelDescriptor::Send(&list_size, 1, m_io_rank, 2); + } + else if (m_rank == m_io_rank){ + int vec_size = 0; + amrex::ParallelDescriptor::Recv(&vec_size, 1, gather_rank, 0); + std::vector<char> package(vec_size); + amrex::ParallelDescriptor::Recv(package, gather_rank, 1); + int list_size = 0; + amrex::ParallelDescriptor::Recv(&list_size, 1, gather_rank, 2); + auto it = package.cbegin(); + for (int i = 0; i < list_size; ++i){ + const auto vec = get_out_vec<char>(it); + msgs_with_counter_and_ranks.emplace_back( + MsgWithCounterAndRanks::deserialize(vec.begin()) + ); + } + } + } +} + +namespace +{ +std::vector<char> +get_serialized_gather_rank_msgs( + const std::vector<Msg>& my_msgs, + const int gather_rank, + const int my_rank) +{ + const bool is_gather_rank = (my_rank == gather_rank); + + auto serialized_gather_rank_msgs = std::vector<char>{}; + int size_serialized_gather_rank_msgs = 0; + + if (is_gather_rank){ + serialized_gather_rank_msgs = ::serialize_msgs(my_msgs); + size_serialized_gather_rank_msgs = static_cast<int>( + serialized_gather_rank_msgs.size()); + } + + amrex::ParallelDescriptor::Bcast( + &size_serialized_gather_rank_msgs, 1, gather_rank); + + if (!is_gather_rank) + serialized_gather_rank_msgs.resize( + size_serialized_gather_rank_msgs); + + amrex::ParallelDescriptor::Bcast( + serialized_gather_rank_msgs.data(), + size_serialized_gather_rank_msgs, gather_rank); + + return serialized_gather_rank_msgs; +} + +std::vector<char> +compute_package_for_gather_rank( + const std::vector<char>& serialized_gather_rank_msgs, + const std::int64_t gather_rank_how_many_msgs, + const std::map<Msg, std::int64_t>& my_msg_map, + const bool is_gather_rank) +{ + if(!is_gather_rank){ + auto package = std::vector<char>{}; + + //generates a copy of the message map + auto msgs_to_send = std::map<Msg, std::int64_t>{my_msg_map}; + + // For each message of the "gather rank" store how many times + // the message has been emitted by the current ranks. + const auto gather_rank_msgs = + ::deserialize_msgs(serialized_gather_rank_msgs); + std::vector<std::int64_t> gather_rank_msg_counters(gather_rank_how_many_msgs); + std::int64_t counter = 0; + for (const auto& msg : gather_rank_msgs){ + const auto pp = msgs_to_send.find(msg); + if (pp != msgs_to_send.end()){ + gather_rank_msg_counters[counter] += pp->second; + // Remove messages already seen by "gather rank" from + // the messages to send back + msgs_to_send.erase(msg); + } + counter++; + } + put_in_vec(gather_rank_msg_counters, package); + + // Add the additional messages seen by the current rank to the package + put_in(static_cast<int>(msgs_to_send.size()), package); + for (const auto& el : msgs_to_send) + put_in_vec<char>( + MsgWithCounter{el.first, el.second}.serialize(), package); + + return package; + } + + return std::vector<char>{}; +} + +std::pair<std::vector<char>, std::vector<int>> +gather_all_data( + const std::vector<char>& package_for_gather_rank, + const int gather_rank, const int my_rank) +{ + auto package_lengths = std::vector<int>{}; + auto all_data = std::vector<char>{}; + auto displacements = std::vector<int>{}; + + if(gather_rank != my_rank){ + amrex::ParallelDescriptor::Gather( + static_cast<int>(package_for_gather_rank.size()), gather_rank); + amrex::ParallelDescriptor::Gatherv( + package_for_gather_rank.data(), + package_for_gather_rank.size(), + all_data.data(), + package_lengths, + displacements, + gather_rank); + } + else{ + const int zero_size = 0; + package_lengths = + amrex::ParallelDescriptor::Gather(zero_size, gather_rank); + + // Compute displacements + // Given (n1, n2, n3, n4, ..., n_n) we need (0, n1, n1+n2, n1+n2+n3, ...), + // but partial_sum gives us (n1,n1+n2, n1+n2+n3, n1+n2+n3+n4, ...). + // Rotating this last vector by one is just shifting: (n1+n2+n3+n4+...,n1, n1+n2, n1+n2+n3, ...). + // Then we just need to replace the first element with zero: (0,n1, n1+n2, n1+n2+n3, ...). + displacements.resize(package_lengths.size()); + std::partial_sum(package_lengths.begin(), package_lengths.end(), + displacements.begin()); + const auto total_sum = displacements.back(); + std::rotate(displacements.rbegin(), + displacements.rbegin()+1, + displacements.rend()); + displacements[0] = 0; + + all_data.resize(total_sum); + + amrex::ParallelDescriptor::Gatherv( + static_cast<char*>(nullptr), + 0, + all_data.data(), + package_lengths, + displacements, + gather_rank); + } + return std::make_pair(all_data, displacements); +} + +std::vector<char> serialize_msgs( + const std::vector<Msg>& msgs) +{ + auto serialized = std::vector<char>{}; + + const auto how_many = static_cast<int> (msgs.size()); + put_in (how_many, serialized); + + for (auto msg : msgs){ + put_in_vec(msg.serialize(), serialized); + } + return serialized; +} + +std::vector<Msg> deserialize_msgs( + const std::vector<char>& serialized) +{ + auto it = serialized.begin(); + + const auto how_many = get_out<int>(it); + auto msgs = std::vector<Msg>{}; + msgs.reserve(how_many); + + for (int i = 0; i < how_many; ++i){ + const auto vv = get_out_vec<char>(it); + msgs.emplace_back(Msg::deserialize(vv.begin())); + } + + return msgs; +} +} + +#endif + diff --git a/Source/Utils/MsgLogger/MsgLoggerSerialization.H b/Source/Utils/MsgLogger/MsgLoggerSerialization.H new file mode 100644 index 000000000..fba8bb0d1 --- /dev/null +++ b/Source/Utils/MsgLogger/MsgLoggerSerialization.H @@ -0,0 +1,189 @@ +/* Copyright 2021 Luca Fedeli + * + * This file is part of WarpX. + * + * License: BSD-3-Clause-LBNL + */ + +#ifndef WARPX_MSG_LOGGER_SERIALIZATION_H_ +#define WARPX_MSG_LOGGER_SERIALIZATION_H_ + +#include <algorithm> +#include <array> +#include <cstring> +#include <string> +#include <type_traits> +#include <vector> + +namespace Utils{ +namespace MsgLogger{ + + /** + * This function transforms a variable of type T into a vector of chars holding its + * byte representation and it appends this vector at the end of an + * existing vector of chars. T must be either a trivially copyable type or an std::string + * (see specialization) + * + * @tparam T the variable type + * @param[in] val a variable of type T to be serialized + * @param[in, out] vec a reference to the vector to which the byte representation of val is appended + */ + template <typename T> + void put_in(const T& val, std::vector<char>& vec) + { + static_assert(std::is_trivially_copyable<T>(), + "Cannot serialize non-trivally copyable types, except std::string."); + + const auto* ptr_val = reinterpret_cast<const char*>(&val); + vec.insert(vec.end(), ptr_val, ptr_val+sizeof(T)); + } + + /** + * This function transforms a string into a vector of chars holding its + * byte representation and it appends this vector at the end of an + * existing vector of chars (specialization of put_in<T>). + * + * @param[in] val a std::string to be serialized + * @param[in, out] vec a reference to the vector to which the byte representation of val is appended + */ + template <> + inline void put_in<std::string> (const std::string& val, std::vector<char>& vec) + { + const char* c_str = val.c_str(); + const auto length = static_cast<int>(val.size()); + + put_in(length, vec); + vec.insert(vec.end(), c_str, c_str+length); + } + + /** + * This function transforms an std::vector<T> into a vector of chars holding its + * byte representation and it appends this vector at the end of an + * existing vector of chars. T must be either a trivially copyable type or an std::string. + * A specialization exists in case val is a vector of chars. + * + * @tparam T the variable type + * @param[in] val a variable of type T to be serialized + * @param[in, out] vec a reference to the vector to which the byte representation of val is appended + */ + template <typename T> + inline void put_in_vec (const std::vector<T>& val, std::vector<char>& vec) + { + static_assert(std::is_trivially_copyable<T>() || std::is_same<T,std::string>(), + "Cannot serialize vectors of non-trivally copyable types" + ", except vectors of std::string."); + + put_in(static_cast<int>(val.size()), vec); + for (const auto& el : val) + put_in(el, vec); + } + + /** + * This function transforms an std::vector<char> into a vector of chars holding its + * byte representation and it appends this vector at the end of an + * existing vector of chars (specialization of put_in_vec<T>). + * + * @tparam T the variable type + * @param[in] val a variable of type T to be serialized + * @param[in, out] vec a reference to the vector to which the byte representation of val is appended + */ + template <> + inline void put_in_vec <char> (const std::vector<char>& val, std::vector<char>& vec) + { + put_in(static_cast<int>(val.size()), vec); + vec.insert(vec.end(), val.begin(), val.end()); + } + + /** + * This function extracts a variable of type T from a byte vector, at the position + * given by a std::vector<char> iterator. The iterator is then advanced according to + * the number of bytes read from the byte vector. T must be either a trivially copyable type + * or an std::string (see specialization below). + * + * @tparam T the variable type (must be trivially copyable) + * @param[in, out] it the iterator to a byte vector + * @return the variable extracted from the byte array + */ + template<typename T> + T get_out(std::vector<char>::const_iterator& it) + { + static_assert(std::is_trivially_copyable<T>(), + "Cannot extract non-trivally copyable types from char vectors," + " with the exception of std::string."); + + auto temp = std::array<char, sizeof(T)>{}; + std::copy(it, it + sizeof(T), temp.begin()); + it += sizeof(T); + T res; + std::memcpy(&res, temp.data(), sizeof(T)); + + return res; + } + + /** + * This function extracts an std::string from a byte vector, at the position + * given by a std::vector<char> iterator. The iterator is then advanced according to + * the number of bytes read from the byte vector. This is a specialization of + * get_out<T> + * + * @param[in, out] it the iterator to a byte vector + * @return the std::string extracted from the byte array + */ + template<> + inline std::string get_out<std::string> (std::vector<char>::const_iterator& it) + { + const auto length = get_out<int> (it); + const auto str = std::string{it, it+length}; + it += length; + + return str; + } + + /** + * This function extracts an std::vector<T> from a byte vector, at the position + * given by a std::vector<char> iterator. The iterator is then advanced according to + * the number of bytes read from the byte vector. T must be either a trivially copyable type + * or an std::string. + * + * @tparam T the variable type (must be trivially copyable) + * @param[in, out] it the iterator to a byte vector + * @return the variable extracted from the byte array + */ + template<typename T> + inline std::vector<T> get_out_vec (std::vector<char>::const_iterator& it) + { + static_assert(std::is_trivially_copyable<T>() || std::is_same<T,std::string>(), + "Cannot extract non-trivally copyable types from char vectors," + " with the exception of std::string."); + + const auto length = get_out<int> (it); + std::vector<T> res(length); + for (int i = 0; i < length; ++i) + res[i] = get_out<T>(it); + + return res; + } + + /** + * This function extracts an std::vector<char> from a byte vector, at the position + * given by a std::vector<char> iterator. The iterator is then advanced according to + * the number of bytes read from the byte vector. This is a specialization of get_out_vec<T>. + * + * @param[in, out] it the iterator to a byte vector + * @return the variable extracted from the byte array + */ + template<> + inline std::vector<char> get_out_vec<char> (std::vector<char>::const_iterator& it) + { + const auto length = get_out<int> (it); + std::vector<char> res(length); + std::copy(it, it+length, res.begin()); + it += length; + + return res; + } + +} +} + +#endif //WARPX_MSG_LOGGER_SERIALIZATION_H_ diff --git a/Source/Utils/MsgLogger/MsgLogger_fwd.H b/Source/Utils/MsgLogger/MsgLogger_fwd.H new file mode 100644 index 000000000..626348670 --- /dev/null +++ b/Source/Utils/MsgLogger/MsgLogger_fwd.H @@ -0,0 +1,24 @@ +/* Copyright 2021 Luca Fedeli + * + * This file is part of WarpX. + * + * License: BSD-3-Clause-LBNL + */ + +#ifndef WARPX_MSG_LOGGER_FWD_H +#define WARPX_MSG_LOGGER_FWD_H + +namespace Utils{ +namespace MsgLogger{ + + enum class Priority; + + struct Msg; + struct MsgWithCounter; + struct MsgWithCounterAndRanks; + + class Logger; +} +} + +#endif //WARPX_MSG_LOGGER_FWD_H |