aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Yujia Qiao <rapiz3142@gmail.com> 2021-12-26 22:59:12 +0800
committerGravatar Yujia Qiao <contact@rapiz.me> 2021-12-27 15:50:13 +0800
commitc8cb60708d51f17d893defc587dc2165c109db2c (patch)
tree801bf9749e3b2a269086271962ef3d567c14f357
parentc8e679fa6539fa90b33f172544a0805fd41dc6c3 (diff)
downloadrathole-c8cb60708d51f17d893defc587dc2165c109db2c.tar.gz
rathole-c8cb60708d51f17d893defc587dc2165c109db2c.tar.zst
rathole-c8cb60708d51f17d893defc587dc2165c109db2c.zip
test: refactor and add tests for hot-reload
-rw-r--r--src/client.rs10
-rw-r--r--src/config.rs10
-rw-r--r--src/config_watcher.rs332
-rw-r--r--src/lib.rs16
-rw-r--r--src/server.rs28
5 files changed, 277 insertions, 119 deletions
diff --git a/src/client.rs b/src/client.rs
index 7591ab1..9f1bdcb 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,5 +1,5 @@
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
-use crate::config_watcher::ServiceChangeEvent;
+use crate::config_watcher::ServiceChange;
use crate::helper::udp_connect;
use crate::protocol::Hello::{self, *};
use crate::protocol::{
@@ -30,7 +30,7 @@ use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
pub async fn run_client(
config: &Config,
shutdown_rx: broadcast::Receiver<bool>,
- service_rx: mpsc::Receiver<ServiceChangeEvent>,
+ service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
let config = match &config.client {
Some(v) => v,
@@ -93,7 +93,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
async fn run(
&mut self,
mut shutdown_rx: broadcast::Receiver<bool>,
- mut service_rx: mpsc::Receiver<ServiceChangeEvent>,
+ mut service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
for (name, config) in &self.config.services {
// Create a control channel for each service defined
@@ -120,7 +120,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
e = service_rx.recv() => {
if let Some(e) = e {
match e {
- ServiceChangeEvent::ClientAdd(s)=> {
+ ServiceChange::ClientAdd(s)=> {
let name = s.name.clone();
let handle = ControlChannelHandle::new(
s,
@@ -129,7 +129,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
);
let _ = self.service_handles.insert(name, handle);
},
- ServiceChangeEvent::ClientDelete(s)=> {
+ ServiceChange::ClientDelete(s)=> {
let _ = self.service_handles.remove(&s);
},
_ => ()
diff --git a/src/config.rs b/src/config.rs
index af09176..9fead6c 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -38,11 +38,17 @@ pub enum ServiceType {
Udp,
}
+impl Default for ServiceType {
+ fn default() -> Self {
+ ServiceType::Tcp
+ }
+}
+
fn default_service_type() -> ServiceType {
- ServiceType::Tcp
+ Default::default()
}
-#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ServerServiceConfig {
#[serde(rename = "type", default = "default_service_type")]
pub service_type: ServiceType,
diff --git a/src/config_watcher.rs b/src/config_watcher.rs
index 79e84f4..c7c2ff5 100644
--- a/src/config_watcher.rs
+++ b/src/config_watcher.rs
@@ -1,5 +1,5 @@
use crate::{
- config::{ClientServiceConfig, ServerServiceConfig},
+ config::{ClientConfig, ClientServiceConfig, ServerConfig, ServerServiceConfig},
Config,
};
use anyhow::{Context, Result};
@@ -13,22 +13,87 @@ use tracing::{error, info, instrument};
#[cfg(feature = "notify")]
use notify::{event::ModifyKind, EventKind, RecursiveMode, Watcher};
-#[derive(Debug)]
-pub enum ConfigChangeEvent {
+#[derive(Debug, PartialEq)]
+pub enum ConfigChange {
General(Box<Config>), // Trigger a full restart
- ServiceChange(ServiceChangeEvent),
+ ServiceChange(ServiceChange),
}
-#[derive(Debug)]
-pub enum ServiceChangeEvent {
+#[derive(Debug, PartialEq)]
+pub enum ServiceChange {
ClientAdd(ClientServiceConfig),
ClientDelete(String),
ServerAdd(ServerServiceConfig),
ServerDelete(String),
}
+impl From<ClientServiceConfig> for ServiceChange {
+ fn from(c: ClientServiceConfig) -> Self {
+ ServiceChange::ClientAdd(c)
+ }
+}
+
+impl From<ServerServiceConfig> for ServiceChange {
+ fn from(c: ServerServiceConfig) -> Self {
+ ServiceChange::ServerAdd(c)
+ }
+}
+
+trait InstanceConfig: Clone {
+ type ServiceConfig: Into<ServiceChange> + PartialEq + Clone;
+ fn equal_without_service(&self, rhs: &Self) -> bool;
+ fn to_service_change_delete(s: String) -> ServiceChange;
+ fn get_services(&self) -> &HashMap<String, Self::ServiceConfig>;
+}
+
+impl InstanceConfig for ServerConfig {
+ type ServiceConfig = ServerServiceConfig;
+ fn equal_without_service(&self, rhs: &Self) -> bool {
+ let left = ServerConfig {
+ services: Default::default(),
+ ..self.clone()
+ };
+
+ let right = ServerConfig {
+ services: Default::default(),
+ ..rhs.clone()
+ };
+
+ left == right
+ }
+ fn to_service_change_delete(s: String) -> ServiceChange {
+ ServiceChange::ServerDelete(s)
+ }
+ fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
+ &self.services
+ }
+}
+
+impl InstanceConfig for ClientConfig {
+ type ServiceConfig = ClientServiceConfig;
+ fn equal_without_service(&self, rhs: &Self) -> bool {
+ let left = ClientConfig {
+ services: Default::default(),
+ ..self.clone()
+ };
+
+ let right = ClientConfig {
+ services: Default::default(),
+ ..rhs.clone()
+ };
+
+ left == right
+ }
+ fn to_service_change_delete(s: String) -> ServiceChange {
+ ServiceChange::ClientDelete(s)
+ }
+ fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
+ &self.services
+ }
+}
+
pub struct ConfigWatcherHandle {
- pub event_rx: mpsc::Receiver<ConfigChangeEvent>,
+ pub event_rx: mpsc::Receiver<ConfigChange>,
}
impl ConfigWatcherHandle {
@@ -39,7 +104,7 @@ impl ConfigWatcherHandle {
// Initial start
event_tx
- .send(ConfigChangeEvent::General(Box::new(origin_cfg.clone())))
+ .send(ConfigChange::General(Box::new(origin_cfg.clone())))
.await
.unwrap();
@@ -59,30 +124,33 @@ impl ConfigWatcherHandle {
async fn config_watcher(
_path: PathBuf,
mut shutdown_rx: broadcast::Receiver<bool>,
- _cfg_event_tx: mpsc::Sender<ConfigChangeEvent>,
+ _event_tx: mpsc::Sender<ConfigChange>,
_old: Config,
) -> Result<()> {
- // Do nothing except wating for ctrl-c
+ // Do nothing except waiting for ctrl-c
let _ = shutdown_rx.recv().await;
Ok(())
}
#[cfg(feature = "notify")]
-#[instrument(skip(shutdown_rx, cfg_event_tx, old))]
+#[instrument(skip(shutdown_rx, event_tx, old))]
async fn config_watcher(
path: PathBuf,
mut shutdown_rx: broadcast::Receiver<bool>,
- cfg_event_tx: mpsc::Sender<ConfigChangeEvent>,
+ event_tx: mpsc::Sender<ConfigChange>,
mut old: Config,
) -> Result<()> {
let (fevent_tx, mut fevent_rx) = mpsc::channel(16);
- let mut watcher = notify::recommended_watcher(move |res| match res {
- Ok(event) => {
- let _ = fevent_tx.blocking_send(event);
- }
- Err(e) => error!("watch error: {:?}", e),
- })?;
+ let mut watcher =
+ notify::recommended_watcher(move |res: Result<notify::Event, _>| match res {
+ Ok(e) => {
+ if let EventKind::Modify(ModifyKind::Data(_)) = e.kind {
+ let _ = fevent_tx.blocking_send(true);
+ }
+ }
+ Err(e) => error!("watch error: {:?}", e),
+ })?;
watcher.watch(&path, RecursiveMode::NonRecursive)?;
info!("Start watching the config");
@@ -91,12 +159,7 @@ async fn config_watcher(
tokio::select! {
e = fevent_rx.recv() => {
match e {
- Some(e) => {
- if let EventKind::Modify(kind) = e.kind {
- match kind {
- ModifyKind::Data(_) => (),
- _ => continue
- }
+ Some(_) => {
info!("Rescan the configuration");
let new = match Config::from_file(&path).await.with_context(|| "The changed configuration is invalid. Ignored") {
Ok(v) => v,
@@ -107,12 +170,11 @@ async fn config_watcher(
}
};
- for event in calculate_event(&old, &new) {
- cfg_event_tx.send(event).await?;
+ for event in calculate_events(&old, &new) {
+ event_tx.send(event).await?;
}
old = new;
- }
},
None => break
}
@@ -126,74 +188,170 @@ async fn config_watcher(
Ok(())
}
-fn calculate_event(old: &Config, new: &Config) -> Vec<ConfigChangeEvent> {
- let mut ret = Vec::new();
-
- if old != new {
- if old.server.is_some() && new.server.is_some() {
- let mut e: Vec<ConfigChangeEvent> = calculate_service_delete_event(
- &old.server.as_ref().unwrap().services,
- &new.server.as_ref().unwrap().services,
- )
- .into_iter()
- .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerDelete(x)))
- .collect();
- ret.append(&mut e);
-
- let mut e: Vec<ConfigChangeEvent> = calculate_service_add_event(
- &old.server.as_ref().unwrap().services,
- &new.server.as_ref().unwrap().services,
- )
- .into_iter()
- .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ServerAdd(x)))
- .collect();
-
- ret.append(&mut e);
- } else if old.client.is_some() && new.client.is_some() {
- let mut e: Vec<ConfigChangeEvent> = calculate_service_delete_event(
- &old.client.as_ref().unwrap().services,
- &new.client.as_ref().unwrap().services,
- )
- .into_iter()
- .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientDelete(x)))
- .collect();
- ret.append(&mut e);
-
- let mut e: Vec<ConfigChangeEvent> = calculate_service_add_event(
- &old.client.as_ref().unwrap().services,
- &new.client.as_ref().unwrap().services,
- )
- .into_iter()
- .map(|x| ConfigChangeEvent::ServiceChange(ServiceChangeEvent::ClientAdd(x)))
- .collect();
-
- ret.append(&mut e);
+fn calculate_events(old: &Config, new: &Config) -> Vec<ConfigChange> {
+ if old == new {
+ vec![]
+ } else if old.server != new.server {
+ if old.server.is_some() != new.server.is_some() {
+ vec![ConfigChange::General(Box::new(new.clone()))]
+ } else {
+ match calculate_instance_config_events(
+ old.server.as_ref().unwrap(),
+ new.server.as_ref().unwrap(),
+ ) {
+ Some(v) => v,
+ None => vec![ConfigChange::General(Box::new(new.clone()))],
+ }
+ }
+ } else if old.client != new.client {
+ if old.client.is_some() != new.client.is_some() {
+ vec![ConfigChange::General(Box::new(new.clone()))]
} else {
- ret.push(ConfigChangeEvent::General(Box::new(new.clone())));
+ match calculate_instance_config_events(
+ old.client.as_ref().unwrap(),
+ new.client.as_ref().unwrap(),
+ ) {
+ Some(v) => v,
+ None => vec![ConfigChange::General(Box::new(new.clone()))],
+ }
}
+ } else {
+ vec![]
+ }
+}
+
+// None indicates a General change needed
+fn calculate_instance_config_events<T: InstanceConfig>(
+ old: &T,
+ new: &T,
+) -> Option<Vec<ConfigChange>> {
+ if !old.equal_without_service(new) {
+ return None;
}
- ret
+ let old = old.get_services();
+ let new = new.get_services();
+
+ let mut v = vec![];
+ v.append(&mut calculate_service_delete_events::<T>(old, new));
+ v.append(&mut calculate_service_add_events(old, new));
+
+ Some(v.into_iter().map(ConfigChange::ServiceChange).collect())
}
-fn calculate_service_delete_event<T: PartialEq>(
- old_services: &HashMap<String, T>,
- new_services: &HashMap<String, T>,
-) -> Vec<String> {
- old_services
- .keys()
- .filter(|&name| old_services.get(name) != new_services.get(name))
- .map(|x| x.to_owned())
+fn calculate_service_delete_events<T: InstanceConfig>(
+ old: &HashMap<String, T::ServiceConfig>,
+ new: &HashMap<String, T::ServiceConfig>,
+) -> Vec<ServiceChange> {
+ old.keys()
+ .filter(|&name| new.get(name).is_none())
+ .map(|x| T::to_service_change_delete(x.to_owned()))
.collect()
}
-fn calculate_service_add_event<T: PartialEq + Clone>(
- old_services: &HashMap<String, T>,
- new_services: &HashMap<String, T>,
-) -> Vec<T> {
- new_services
- .iter()
- .filter(|(name, _)| old_services.get(*name) != new_services.get(*name))
- .map(|(_, c)| c.clone())
+fn calculate_service_add_events<T: PartialEq + Clone + Into<ServiceChange>>(
+ old: &HashMap<String, T>,
+ new: &HashMap<String, T>,
+) -> Vec<ServiceChange> {
+ new.iter()
+ .filter(|(name, c)| old.get(*name) != Some(*c))
+ .map(|(_, c)| c.clone().into())
.collect()
}
+
+#[cfg(test)]
+mod test {
+ use crate::config::ServerConfig;
+
+ use super::*;
+
+ // macro to create map or set literal
+ macro_rules! collection {
+ // map-like
+ ($($k:expr => $v:expr),* $(,)?) => {{
+ use std::iter::{Iterator, IntoIterator};
+ Iterator::collect(IntoIterator::into_iter([$(($k, $v),)*]))
+ }};
+ }
+
+ #[test]
+ fn test_calculate_events() {
+ struct Test {
+ old: Config,
+ new: Config,
+ }
+
+ let tests = [
+ Test {
+ old: Config {
+ server: Some(Default::default()),
+ client: None,
+ },
+ new: Config {
+ server: Some(Default::default()),
+ client: Some(Default::default()),
+ },
+ },
+ Test {
+ old: Config {
+ server: Some(ServerConfig {
+ bind_addr: String::from("127.0.0.1:2334"),
+ ..Default::default()
+ }),
+ client: None,
+ },
+ new: Config {
+ server: Some(ServerConfig {
+ bind_addr: String::from("127.0.0.1:2333"),
+ services: collection!(String::from("foo") => Default::default()),
+ ..Default::default()
+ }),
+ client: None,
+ },
+ },
+ Test {
+ old: Config {
+ server: Some(Default::default()),
+ client: None,
+ },
+ new: Config {
+ server: Some(ServerConfig {
+ services: collection!(String::from("foo") => Default::default()),
+ ..Default::default()
+ }),
+ client: None,
+ },
+ },
+ Test {
+ old: Config {
+ server: Some(ServerConfig {
+ services: collection!(String::from("foo") => Default::default()),
+ ..Default::default()
+ }),
+ client: None,
+ },
+ new: Config {
+ server: Some(Default::default()),
+ client: None,
+ },
+ },
+ ];
+ let expected = [
+ vec![ConfigChange::General(Box::new(tests[0].new.clone()))],
+ vec![ConfigChange::General(Box::new(tests[1].new.clone()))],
+ vec![ConfigChange::ServiceChange(ServiceChange::ServerAdd(
+ Default::default(),
+ ))],
+ vec![ConfigChange::ServiceChange(ServiceChange::ServerDelete(
+ String::from("foo"),
+ ))],
+ ];
+
+ assert_eq!(tests.len(), expected.len());
+
+ for i in 0..tests.len() {
+ let actual = calculate_events(&tests[i].old, &tests[i].new);
+ assert_eq!(actual, expected[i]);
+ }
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 2dacfa9..2097c16 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,7 +10,7 @@ mod transport;
pub use cli::Cli;
use cli::KeypairType;
pub use config::Config;
-use config_watcher::ServiceChangeEvent;
+use config_watcher::ServiceChange;
pub use constants::UDP_BUFFER_SIZE;
use anyhow::Result;
@@ -27,7 +27,7 @@ mod server;
#[cfg(feature = "server")]
use server::run_server;
-use crate::config_watcher::{ConfigChangeEvent, ConfigWatcherHandle};
+use crate::config_watcher::{ConfigChange, ConfigWatcherHandle};
const DEFAULT_CURVE: KeypairType = KeypairType::X25519;
@@ -76,12 +76,11 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
let (shutdown_tx, _) = broadcast::channel(1);
// (The join handle of the last instance, The service update channel sender)
- let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChangeEvent>)> =
- None;
+ let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChange>)> = None;
while let Some(e) = cfg_watcher.event_rx.recv().await {
match e {
- ConfigChangeEvent::General(config) => {
+ ConfigChange::General(config) => {
if let Some((i, _)) = last_instance {
info!("General configuration change detected. Restarting...");
shutdown_tx.send(true)?;
@@ -102,7 +101,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
service_update_tx,
));
}
- ConfigChangeEvent::ServiceChange(service_event) => {
+ ConfigChange::ServiceChange(service_event) => {
info!("Service change detcted. {:?}", service_event);
if let Some((_, service_update_tx)) = &last_instance {
let _ = service_update_tx.send(service_event).await;
@@ -110,6 +109,9 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
}
}
}
+
+ let _ = shutdown_tx.send(true);
+
Ok(())
}
@@ -117,7 +119,7 @@ async fn run_instance(
config: Config,
args: Cli,
shutdown_rx: broadcast::Receiver<bool>,
- service_update: mpsc::Receiver<ServiceChangeEvent>,
+ service_update: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
match determine_run_mode(&config, &args) {
RunMode::Undetermine => panic!("Cannot determine running as a server or a client"),
diff --git a/src/server.rs b/src/server.rs
index f8bb2b2..07b3e6a 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,5 +1,5 @@
use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
-use crate::config_watcher::ServiceChangeEvent;
+use crate::config_watcher::ServiceChange;
use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
use crate::multi_map::MultiMap;
use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
@@ -39,7 +39,7 @@ const CHAN_SIZE: usize = 2048; // The capacity of various chans
pub async fn run_server(
config: &Config,
shutdown_rx: broadcast::Receiver<bool>,
- service_rx: mpsc::Receiver<ServiceChangeEvent>,
+ service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
let config = match &config.server {
Some(config) => config,
@@ -122,7 +122,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
pub async fn run(
&mut self,
mut shutdown_rx: broadcast::Receiver<bool>,
- mut service_rx: mpsc::Receiver<ServiceChangeEvent>,
+ mut service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
// Listen at `server.bind_addr`
let l = self
@@ -193,9 +193,9 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
Ok(())
}
- async fn handle_hot_reload(&mut self, e: ServiceChangeEvent) {
+ async fn handle_hot_reload(&mut self, e: ServiceChange) {
match e {
- ServiceChangeEvent::ServerAdd(s) => {
+ ServiceChange::ServerAdd(s) => {
let hash = protocol::digest(s.name.as_bytes());
let mut wg = self.services.write().await;
let _ = wg.insert(hash, s);
@@ -203,7 +203,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
let mut wg = self.control_channels.write().await;
let _ = wg.remove1(&hash);
}
- ServiceChangeEvent::ServerDelete(s) => {
+ ServiceChange::ServerDelete(s) => {
let hash = protocol::digest(s.as_bytes());
let _ = self.services.write().await.remove(&hash);
@@ -340,11 +340,8 @@ async fn do_data_channel_handshake<T: 'static + Transport>(
}
pub struct ControlChannelHandle<T: Transport> {
- // Shutdown the control channel.
- // Not used for now, but can be used for hot reloading
- #[allow(dead_code)]
- shutdown_tx: broadcast::Sender<bool>,
- //data_ch_req_tx: mpsc::Sender<bool>,
+ // Shutdown the control channel by dropping it
+ _shutdown_tx: broadcast::Sender<bool>,
data_ch_tx: mpsc::Sender<T::Stream>,
}
@@ -359,7 +356,7 @@ where
// Save the name string for logging
let name = service.name.clone();
- // Create a shutdown channel. The sender is not used for now, but for future use
+ // Create a shutdown channel
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
// Store data channels
@@ -417,15 +414,10 @@ where
});
ControlChannelHandle {
- shutdown_tx,
+ _shutdown_tx: shutdown_tx,
data_ch_tx,
}
}
-
- #[allow(dead_code)]
- fn shutdown(self) {
- let _ = self.shutdown_tx.send(true);
- }
}
// Control channel, using T as the transport layer. P is TcpStream or UdpTraffic