aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Yujia Qiao <rapiz3142@gmail.com> 2022-09-15 19:40:15 +0800
committerGravatar GitHub <noreply@github.com> 2022-09-15 11:40:15 +0000
commitea01c42da70442d71ed9b98210ef39e2779859e7 (patch)
tree415b54e93fd6e1a91752eb3c420402bb27c092e7
parent187f4f033509cb616582d0645ff5cf1ea77bac76 (diff)
downloadrathole-ea01c42da70442d71ed9b98210ef39e2779859e7.tar.gz
rathole-ea01c42da70442d71ed9b98210ef39e2779859e7.tar.zst
rathole-ea01c42da70442d71ed9b98210ef39e2779859e7.zip
refactor: ConfigChange (#191)
-rw-r--r--src/client.rs52
-rw-r--r--src/config_watcher.rs172
-rw-r--r--src/lib.rs11
-rw-r--r--src/server.rs48
4 files changed, 146 insertions, 137 deletions
diff --git a/src/client.rs b/src/client.rs
index 1149bec..45d3d85 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,5 +1,5 @@
use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
-use crate::config_watcher::ServiceChange;
+use crate::config_watcher::{ClientServiceChange, ConfigChange};
use crate::helper::udp_connect;
use crate::protocol::Hello::{self, *};
use crate::protocol::{
@@ -31,7 +31,7 @@ use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE
pub async fn run_client(
config: Config,
shutdown_rx: broadcast::Receiver<bool>,
- service_rx: mpsc::Receiver<ServiceChange>,
+ update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
let config = config.client.ok_or_else(|| {
anyhow!(
@@ -42,13 +42,13 @@ pub async fn run_client(
match config.transport.transport_type {
TransportType::Tcp => {
let mut client = Client::<TcpTransport>::from(config).await?;
- client.run(shutdown_rx, service_rx).await
+ client.run(shutdown_rx, update_rx).await
}
TransportType::Tls => {
#[cfg(feature = "tls")]
{
let mut client = Client::<TlsTransport>::from(config).await?;
- client.run(shutdown_rx, service_rx).await
+ client.run(shutdown_rx, update_rx).await
}
#[cfg(not(feature = "tls"))]
crate::helper::feature_not_compile("tls")
@@ -57,7 +57,7 @@ pub async fn run_client(
#[cfg(feature = "noise")]
{
let mut client = Client::<NoiseTransport>::from(config).await?;
- client.run(shutdown_rx, service_rx).await
+ client.run(shutdown_rx, update_rx).await
}
#[cfg(not(feature = "noise"))]
crate::helper::feature_not_compile("noise")
@@ -91,7 +91,7 @@ impl<T: 'static + Transport> Client<T> {
async fn run(
&mut self,
mut shutdown_rx: broadcast::Receiver<bool>,
- mut service_rx: mpsc::Receiver<ServiceChange>,
+ mut update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
for (name, config) in &self.config.services {
// Create a control channel for each service defined
@@ -116,24 +116,9 @@ impl<T: 'static + Transport> Client<T> {
}
break;
},
- e = service_rx.recv() => {
+ e = update_rx.recv() => {
if let Some(e) = e {
- match e {
- ServiceChange::ClientAdd(s)=> {
- let name = s.name.clone();
- let handle = ControlChannelHandle::new(
- s,
- self.config.remote_addr.clone(),
- self.transport.clone(),
- self.config.heartbeat_timeout
- );
- let _ = self.service_handles.insert(name, handle);
- },
- ServiceChange::ClientDelete(s)=> {
- let _ = self.service_handles.remove(&s);
- },
- _ => ()
- }
+ self.handle_hot_reload(e).await;
}
}
}
@@ -146,6 +131,27 @@ impl<T: 'static + Transport> Client<T> {
Ok(())
}
+
+ async fn handle_hot_reload(&mut self, e: ConfigChange) {
+ match e {
+ ConfigChange::ClientChange(client_change) => match client_change {
+ ClientServiceChange::Add(cfg) => {
+ let name = cfg.name.clone();
+ let handle = ControlChannelHandle::new(
+ cfg,
+ self.config.remote_addr.clone(),
+ self.transport.clone(),
+ self.config.heartbeat_timeout,
+ );
+ let _ = self.service_handles.insert(name, handle);
+ }
+ ClientServiceChange::Delete(s) => {
+ let _ = self.service_handles.remove(&s);
+ }
+ },
+ ignored => warn!("Ignored {:?} since running as a client", ignored),
+ }
+ }
}
struct RunDataChannelArgs<T: Transport> {
diff --git a/src/config_watcher.rs b/src/config_watcher.rs
index 25423cf..993fdcc 100644
--- a/src/config_watcher.rs
+++ b/src/config_watcher.rs
@@ -14,36 +14,30 @@ use tracing::{error, info, instrument};
#[cfg(feature = "notify")]
use notify::{EventKind, RecursiveMode, Watcher};
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Clone)]
pub enum ConfigChange {
General(Box<Config>), // Trigger a full restart
- ServiceChange(ServiceChange),
+ ServerChange(ServerServiceChange),
+ ClientChange(ClientServiceChange),
}
-#[derive(Debug, PartialEq, Eq)]
-pub enum ServiceChange {
- ClientAdd(ClientServiceConfig),
- ClientDelete(String),
- ServerAdd(ServerServiceConfig),
- ServerDelete(String),
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum ClientServiceChange {
+ Add(ClientServiceConfig),
+ Delete(String),
}
-impl From<ClientServiceConfig> for ServiceChange {
- fn from(c: ClientServiceConfig) -> Self {
- ServiceChange::ClientAdd(c)
- }
-}
-
-impl From<ServerServiceConfig> for ServiceChange {
- fn from(c: ServerServiceConfig) -> Self {
- ServiceChange::ServerAdd(c)
- }
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum ServerServiceChange {
+ Add(ServerServiceConfig),
+ Delete(String),
}
trait InstanceConfig: Clone {
- type ServiceConfig: Into<ServiceChange> + PartialEq + Clone;
+ type ServiceConfig: PartialEq + Eq + Clone;
fn equal_without_service(&self, rhs: &Self) -> bool;
- fn to_service_change_delete(s: String) -> ServiceChange;
+ fn service_delete_change(s: String) -> ConfigChange;
+ fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange;
fn get_services(&self) -> &HashMap<String, Self::ServiceConfig>;
}
@@ -62,8 +56,11 @@ impl InstanceConfig for ServerConfig {
left == right
}
- fn to_service_change_delete(s: String) -> ServiceChange {
- ServiceChange::ServerDelete(s)
+ fn service_delete_change(s: String) -> ConfigChange {
+ ConfigChange::ServerChange(ServerServiceChange::Delete(s))
+ }
+ fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange {
+ ConfigChange::ServerChange(ServerServiceChange::Add(cfg))
}
fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
&self.services
@@ -85,8 +82,11 @@ impl InstanceConfig for ClientConfig {
left == right
}
- fn to_service_change_delete(s: String) -> ServiceChange {
- ServiceChange::ClientDelete(s)
+ fn service_delete_change(s: String) -> ConfigChange {
+ ConfigChange::ClientChange(ClientServiceChange::Delete(s))
+ }
+ fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange {
+ ConfigChange::ClientChange(ClientServiceChange::Add(cfg))
}
fn get_services(&self) -> &HashMap<String, Self::ServiceConfig> {
&self.services
@@ -180,8 +180,9 @@ async fn config_watcher(
}
};
- for event in calculate_events(&old, &new) {
- event_tx.send(event)?;
+ let events = calculate_events(&old, &new).into_iter().flatten();
+ for event in events {
+ event_tx.send(event)?;
}
old = new;
@@ -198,42 +199,40 @@ async fn config_watcher(
Ok(())
}
-fn calculate_events(old: &Config, new: &Config) -> Vec<ConfigChange> {
+fn calculate_events(old: &Config, new: &Config) -> Option<Vec<ConfigChange>> {
if old == new {
- return vec![];
+ return None;
+ }
+
+ if (old.server.is_some() != new.server.is_some())
+ || (old.client.is_some() != new.client.is_some())
+ {
+ return Some(vec![ConfigChange::General(Box::new(new.clone()))]);
}
let mut ret = vec![];
if old.server != new.server {
- if old.server.is_some() != new.server.is_some() {
- return vec![ConfigChange::General(Box::new(new.clone()))];
- } else {
- match calculate_instance_config_events(
- old.server.as_ref().unwrap(),
- new.server.as_ref().unwrap(),
- ) {
- Some(mut v) => ret.append(&mut v),
- None => return vec![ConfigChange::General(Box::new(new.clone()))],
- }
+ match calculate_instance_config_events(
+ old.server.as_ref().unwrap(),
+ new.server.as_ref().unwrap(),
+ ) {
+ Some(mut v) => ret.append(&mut v),
+ None => return Some(vec![ConfigChange::General(Box::new(new.clone()))]),
}
}
if old.client != new.client {
- if old.client.is_some() != new.client.is_some() {
- return vec![ConfigChange::General(Box::new(new.clone()))];
- } else {
- match calculate_instance_config_events(
- old.client.as_ref().unwrap(),
- new.client.as_ref().unwrap(),
- ) {
- Some(mut v) => ret.append(&mut v),
- None => return vec![ConfigChange::General(Box::new(new.clone()))],
- }
+ match calculate_instance_config_events(
+ old.client.as_ref().unwrap(),
+ new.client.as_ref().unwrap(),
+ ) {
+ Some(mut v) => ret.append(&mut v),
+ None => return Some(vec![ConfigChange::General(Box::new(new.clone()))]),
}
}
- ret
+ Some(ret)
}
// None indicates a General change needed
@@ -248,31 +247,17 @@ fn calculate_instance_config_events<T: InstanceConfig>(
let old = old.get_services();
let new = new.get_services();
- let mut v = vec![];
- v.append(&mut calculate_service_delete_events::<T>(old, new));
- v.append(&mut calculate_service_add_events(old, new));
-
- Some(v.into_iter().map(ConfigChange::ServiceChange).collect())
-}
-
-fn calculate_service_delete_events<T: InstanceConfig>(
- old: &HashMap<String, T::ServiceConfig>,
- new: &HashMap<String, T::ServiceConfig>,
-) -> Vec<ServiceChange> {
- old.keys()
+ let deletions = old
+ .keys()
.filter(|&name| new.get(name).is_none())
- .map(|x| T::to_service_change_delete(x.to_owned()))
- .collect()
-}
+ .map(|x| T::service_delete_change(x.to_owned()));
-fn calculate_service_add_events<T: PartialEq + Clone + Into<ServiceChange>>(
- old: &HashMap<String, T>,
- new: &HashMap<String, T>,
-) -> Vec<ServiceChange> {
- new.iter()
+ let addition = new
+ .iter()
.filter(|(name, c)| old.get(*name) != Some(*c))
- .map(|(_, c)| c.clone().into())
- .collect()
+ .map(|(_, c)| T::service_add_change(c.clone()));
+
+ Some(deletions.chain(addition).collect())
}
#[cfg(test)]
@@ -378,23 +363,23 @@ mod test {
let mut expected = [
vec![ConfigChange::General(Box::new(tests[0].new.clone()))],
vec![ConfigChange::General(Box::new(tests[1].new.clone()))],
- vec![ConfigChange::ServiceChange(ServiceChange::ServerAdd(
+ vec![ConfigChange::ServerChange(ServerServiceChange::Add(
Default::default(),
))],
- vec![ConfigChange::ServiceChange(ServiceChange::ServerDelete(
+ vec![ConfigChange::ServerChange(ServerServiceChange::Delete(
String::from("foo"),
))],
vec![
- ConfigChange::ServiceChange(ServiceChange::ServerDelete(String::from("foo1"))),
- ConfigChange::ServiceChange(ServiceChange::ServerAdd(
+ ConfigChange::ServerChange(ServerServiceChange::Delete(String::from("foo1"))),
+ ConfigChange::ServerChange(ServerServiceChange::Add(
tests[4].new.server.as_ref().unwrap().services["bar1"].clone(),
)),
- ConfigChange::ServiceChange(ServiceChange::ClientDelete(String::from("foo1"))),
- ConfigChange::ServiceChange(ServiceChange::ClientDelete(String::from("foo2"))),
- ConfigChange::ServiceChange(ServiceChange::ClientAdd(
+ ConfigChange::ClientChange(ClientServiceChange::Delete(String::from("foo1"))),
+ ConfigChange::ClientChange(ClientServiceChange::Delete(String::from("foo2"))),
+ ConfigChange::ClientChange(ClientServiceChange::Add(
tests[4].new.client.as_ref().unwrap().services["bar1"].clone(),
)),
- ConfigChange::ServiceChange(ServiceChange::ClientAdd(
+ ConfigChange::ClientChange(ClientServiceChange::Add(
tests[4].new.client.as_ref().unwrap().services["bar2"].clone(),
)),
],
@@ -403,16 +388,18 @@ mod test {
assert_eq!(tests.len(), expected.len());
for i in 0..tests.len() {
- let mut actual = calculate_events(&tests[i].old, &tests[i].new);
+ let mut actual = calculate_events(&tests[i].old, &tests[i].new).unwrap();
let get_key = |x: &ConfigChange| -> String {
match x {
ConfigChange::General(_) => String::from("g"),
- ConfigChange::ServiceChange(sc) => match sc {
- ServiceChange::ClientAdd(c) => "c_add_".to_owned() + &c.name,
- ServiceChange::ClientDelete(s) => "c_del_".to_owned() + s,
- ServiceChange::ServerAdd(c) => "s_add_".to_owned() + &c.name,
- ServiceChange::ServerDelete(s) => "s_del_".to_owned() + s,
+ ConfigChange::ServerChange(sc) => match sc {
+ ServerServiceChange::Add(c) => "s_add_".to_owned() + &c.name,
+ ServerServiceChange::Delete(s) => "s_del_".to_owned() + s,
+ },
+ ConfigChange::ClientChange(sc) => match sc {
+ ClientServiceChange::Add(c) => "c_add_".to_owned() + &c.name,
+ ClientServiceChange::Delete(s) => "c_del_".to_owned() + s,
},
}
};
@@ -422,5 +409,20 @@ mod test {
assert_eq!(actual, expected[i]);
}
+
+ // No changes
+ assert_eq!(
+ calculate_events(
+ &Config {
+ server: Default::default(),
+ client: None,
+ },
+ &Config {
+ server: Default::default(),
+ client: None,
+ },
+ ),
+ None
+ );
}
}
diff --git a/src/lib.rs b/src/lib.rs
index c31da23..7fb2fa6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,7 +10,6 @@ mod transport;
pub use cli::Cli;
use cli::KeypairType;
pub use config::Config;
-use config_watcher::ServiceChange;
pub use constants::UDP_BUFFER_SIZE;
use anyhow::Result;
@@ -76,7 +75,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
let (shutdown_tx, _) = broadcast::channel(1);
// (The join handle of the last instance, The service update channel sender)
- let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChange>)> = None;
+ let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ConfigChange>)> = None;
while let Some(e) = cfg_watcher.event_rx.recv().await {
match e {
@@ -101,10 +100,10 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
service_update_tx,
));
}
- ConfigChange::ServiceChange(service_event) => {
- info!("Service change detcted. {:?}", service_event);
+ ev => {
+ info!("Service change detected. {:?}", ev);
if let Some((_, service_update_tx)) = &last_instance {
- let _ = service_update_tx.send(service_event).await;
+ let _ = service_update_tx.send(ev).await;
}
}
}
@@ -119,7 +118,7 @@ async fn run_instance(
config: Config,
args: Cli,
shutdown_rx: broadcast::Receiver<bool>,
- service_update: mpsc::Receiver<ServiceChange>,
+ service_update: mpsc::Receiver<ConfigChange>,
) {
let ret: Result<()> = match determine_run_mode(&config, &args) {
RunMode::Undetermine => panic!("Cannot determine running as a server or a client"),
diff --git a/src/server.rs b/src/server.rs
index 4c08640..6ad91ee 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,5 +1,5 @@
use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
-use crate::config_watcher::ServiceChange;
+use crate::config_watcher::{ConfigChange, ServerServiceChange};
use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
use crate::helper::retry_notify_with_deadline;
use crate::multi_map::MultiMap;
@@ -40,7 +40,7 @@ const HANDSHAKE_TIMEOUT: u64 = 5; // Timeout for transport handshake
pub async fn run_server(
config: Config,
shutdown_rx: broadcast::Receiver<bool>,
- service_rx: mpsc::Receiver<ServiceChange>,
+ update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
let config = match config.server {
Some(config) => config,
@@ -52,13 +52,13 @@ pub async fn run_server(
match config.transport.transport_type {
TransportType::Tcp => {
let mut server = Server::<TcpTransport>::from(config).await?;
- server.run(shutdown_rx, service_rx).await?;
+ server.run(shutdown_rx, update_rx).await?;
}
TransportType::Tls => {
#[cfg(feature = "tls")]
{
let mut server = Server::<TlsTransport>::from(config).await?;
- server.run(shutdown_rx, service_rx).await?;
+ server.run(shutdown_rx, update_rx).await?;
}
#[cfg(not(feature = "tls"))]
crate::helper::feature_not_compile("tls")
@@ -67,7 +67,7 @@ pub async fn run_server(
#[cfg(feature = "noise")]
{
let mut server = Server::<NoiseTransport>::from(config).await?;
- server.run(shutdown_rx, service_rx).await?;
+ server.run(shutdown_rx, update_rx).await?;
}
#[cfg(not(feature = "noise"))]
crate::helper::feature_not_compile("noise")
@@ -124,7 +124,7 @@ impl<T: 'static + Transport> Server<T> {
pub async fn run(
&mut self,
mut shutdown_rx: broadcast::Receiver<bool>,
- mut service_rx: mpsc::Receiver<ServiceChange>,
+ mut update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
// Listen at `server.bind_addr`
let l = self
@@ -198,7 +198,7 @@ impl<T: 'static + Transport> Server<T> {
info!("Shuting down gracefully...");
break;
},
- e = service_rx.recv() => {
+ e = update_rx.recv() => {
if let Some(e) = e {
self.handle_hot_reload(e).await;
}
@@ -211,24 +211,26 @@ impl<T: 'static + Transport> Server<T> {
Ok(())
}
- async fn handle_hot_reload(&mut self, e: ServiceChange) {
+ async fn handle_hot_reload(&mut self, e: ConfigChange) {
match e {
- ServiceChange::ServerAdd(s) => {
- let hash = protocol::digest(s.name.as_bytes());
- let mut wg = self.services.write().await;
- let _ = wg.insert(hash, s);
-
- let mut wg = self.control_channels.write().await;
- let _ = wg.remove1(&hash);
- }
- ServiceChange::ServerDelete(s) => {
- let hash = protocol::digest(s.as_bytes());
- let _ = self.services.write().await.remove(&hash);
+ ConfigChange::ServerChange(server_change) => match server_change {
+ ServerServiceChange::Add(cfg) => {
+ let hash = protocol::digest(cfg.name.as_bytes());
+ let mut wg = self.services.write().await;
+ let _ = wg.insert(hash, cfg);
+
+ let mut wg = self.control_channels.write().await;
+ let _ = wg.remove1(&hash);
+ }
+ ServerServiceChange::Delete(s) => {
+ let hash = protocol::digest(s.as_bytes());
+ let _ = self.services.write().await.remove(&hash);
- let mut wg = self.control_channels.write().await;
- let _ = wg.remove1(&hash);
- }
- _ => (),
+ let mut wg = self.control_channels.write().await;
+ let _ = wg.remove1(&hash);
+ }
+ },
+ ignored => warn!("Ignored {:?} since running as a server", ignored),
}
}
}