diff options
author | 2024-02-18 17:17:17 +0800 | |
---|---|---|
committer | 2024-02-18 17:17:17 +0800 | |
commit | 4ac53a5a39e74d5eb12bee22d0fd4783acaae670 (patch) | |
tree | ca442e698e6870b7f6c5053b54fff82e2a95626e /src | |
parent | 7251759bdaf4b7d170575bdd6d2062bbd9f338bb (diff) | |
download | rathole-4ac53a5a39e74d5eb12bee22d0fd4783acaae670.tar.gz rathole-4ac53a5a39e74d5eb12bee22d0fd4783acaae670.tar.zst rathole-4ac53a5a39e74d5eb12bee22d0fd4783acaae670.zip |
feat: optional rustls support (#330)
* initial implementation of rustls support
* Refactor create_self_signed_cert.sh script
* resolve lint errors
* Fix handling of Option in tls.rs
* Update cargo-hack check command and feature dependencies
* fix missing point
* Add conditional check to skip test if client or server is not enabled
* clean up things
* fix for windows CI
* try fixing Windows CI
* Update src/main.rs
* Update src/transport/websocket.rs
* add missing messages
* split the tls mod
Co-authored-by: Ning Sun <n@sunng.info>
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 19 | ||||
-rw-r--r-- | src/helper.rs | 8 | ||||
-rw-r--r-- | src/lib.rs | 9 | ||||
-rw-r--r-- | src/server.rs | 16 | ||||
-rw-r--r-- | src/transport/mod.rs | 23 | ||||
-rw-r--r-- | src/transport/native_tls.rs (renamed from src/transport/tls.rs) | 13 | ||||
-rw-r--r-- | src/transport/rustls.rs | 156 | ||||
-rw-r--r-- | src/transport/websocket.rs | 14 |
8 files changed, 221 insertions, 37 deletions
diff --git a/src/client.rs b/src/client.rs index 95a5a74..2564869 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,8 +8,9 @@ use crate::protocol::{ }; use crate::transport::{AddrMaybeCached, SocketOpts, TcpTransport, Transport}; use anyhow::{anyhow, bail, Context, Result}; +use backoff::backoff::Backoff; +use backoff::future::retry_notify; use backoff::ExponentialBackoff; -use backoff::{backoff::Backoff, future::retry_notify}; use bytes::{Bytes, BytesMut}; use std::collections::HashMap; use std::net::SocketAddr; @@ -22,9 +23,9 @@ use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span}; #[cfg(feature = "noise")] use crate::transport::NoiseTransport; -#[cfg(feature = "tls")] +#[cfg(any(feature = "native-tls", feature = "rustls"))] use crate::transport::TlsTransport; -#[cfg(feature = "websocket")] +#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))] use crate::transport::WebsocketTransport; use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT}; @@ -47,13 +48,13 @@ pub async fn run_client( client.run(shutdown_rx, update_rx).await } TransportType::Tls => { - #[cfg(feature = "tls")] + #[cfg(any(feature = "native-tls", feature = "rustls"))] { let mut client = Client::<TlsTransport>::from(config).await?; client.run(shutdown_rx, update_rx).await } - #[cfg(not(feature = "tls"))] - crate::helper::feature_not_compile("tls") + #[cfg(not(any(feature = "native-tls", feature = "rustls")))] + crate::helper::feature_neither_compile("native-tls", "rustls") } TransportType::Noise => { #[cfg(feature = "noise")] @@ -65,13 +66,13 @@ pub async fn run_client( crate::helper::feature_not_compile("noise") } TransportType::Websocket => { - #[cfg(feature = "websocket")] + #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))] { let mut client = Client::<WebsocketTransport>::from(config).await?; client.run(shutdown_rx, update_rx).await } - #[cfg(not(feature = "websocket"))] - crate::helper::feature_not_compile("websocket") + #[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))] + crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls") } } } diff --git a/src/helper.rs b/src/helper.rs index 7e1e5d3..a292969 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -43,6 +43,14 @@ pub fn feature_not_compile(feature: &str) -> ! { ) } +#[allow(dead_code)] +pub fn feature_neither_compile(feature1: &str, feature2: &str) -> ! { + panic!( + "Neither of the feature '{}' or '{}' is compiled in this binary. Please re-compile rathole", + feature1, feature2 + ) +} + pub async fn to_socket_addr<A: ToSocketAddrs>(addr: A) -> Result<SocketAddr> { lookup_host(addr) .await? @@ -83,7 +83,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<() if let Some((i, _)) = last_instance { info!("General configuration change detected. Restarting..."); shutdown_tx.send(true)?; - i.await?; + i.await??; } debug!("{:?}", config); @@ -119,8 +119,8 @@ async fn run_instance( args: Cli, shutdown_rx: broadcast::Receiver<bool>, service_update: mpsc::Receiver<ConfigChange>, -) { - let ret: Result<()> = match determine_run_mode(&config, &args) { +) -> Result<()> { + match determine_run_mode(&config, &args) { RunMode::Undetermine => panic!("Cannot determine running as a server or a client"), RunMode::Client => { #[cfg(not(feature = "client"))] @@ -134,8 +134,7 @@ async fn run_instance( #[cfg(feature = "server")] run_server(config, shutdown_rx, service_update).await } - }; - ret.unwrap(); + } } #[derive(PartialEq, Eq, Debug)] diff --git a/src/server.rs b/src/server.rs index 83ae976..a4c4948 100644 --- a/src/server.rs +++ b/src/server.rs @@ -25,9 +25,9 @@ use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span} #[cfg(feature = "noise")] use crate::transport::NoiseTransport; -#[cfg(feature = "tls")] +#[cfg(any(feature = "native-tls", feature = "rustls"))] use crate::transport::TlsTransport; -#[cfg(feature = "websocket")] +#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))] use crate::transport::WebsocketTransport; type ServiceDigest = protocol::Digest; // SHA256 of a service name @@ -57,13 +57,13 @@ pub async fn run_server( server.run(shutdown_rx, update_rx).await?; } TransportType::Tls => { - #[cfg(feature = "tls")] + #[cfg(any(feature = "native-tls", feature = "rustls"))] { let mut server = Server::<TlsTransport>::from(config).await?; server.run(shutdown_rx, update_rx).await?; } - #[cfg(not(feature = "tls"))] - crate::helper::feature_not_compile("tls") + #[cfg(not(any(feature = "native-tls", feature = "rustls")))] + crate::helper::feature_neither_compile("native-tls", "rustls") } TransportType::Noise => { #[cfg(feature = "noise")] @@ -75,13 +75,13 @@ pub async fn run_server( crate::helper::feature_not_compile("noise") } TransportType::Websocket => { - #[cfg(feature = "websocket")] + #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))] { let mut server = Server::<WebsocketTransport>::from(config).await?; server.run(shutdown_rx, update_rx).await?; } - #[cfg(not(feature = "websocket"))] - crate::helper::feature_not_compile("websocket") + #[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))] + crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls") } } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 38682a6..26d357f 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -69,19 +69,30 @@ pub trait Transport: Debug + Send + Sync { mod tcp; pub use tcp::TcpTransport; -#[cfg(feature = "tls")] -mod tls; -#[cfg(feature = "tls")] -pub use tls::TlsTransport; + +#[cfg(all(feature = "native-tls", feature = "rustls"))] +compile_error!("Only one of `native-tls` and `rustls` can be enabled"); + +#[cfg(feature = "native-tls")] +mod native_tls; +#[cfg(feature = "native-tls")] +use native_tls as tls; +#[cfg(feature = "rustls")] +mod rustls; +#[cfg(feature = "rustls")] +use rustls as tls; + +#[cfg(any(feature = "native-tls", feature = "rustls"))] +pub(crate) use tls::TlsTransport; #[cfg(feature = "noise")] mod noise; #[cfg(feature = "noise")] pub use noise::NoiseTransport; -#[cfg(feature = "websocket")] +#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))] mod websocket; -#[cfg(feature = "websocket")] +#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))] pub use websocket::WebsocketTransport; #[derive(Debug, Clone, Copy)] diff --git a/src/transport/tls.rs b/src/transport/native_tls.rs index 918af04..40afd50 100644 --- a/src/transport/tls.rs +++ b/src/transport/native_tls.rs @@ -1,14 +1,14 @@ -use std::net::SocketAddr; - -use super::{AddrMaybeCached, SocketOpts, TcpTransport, Transport}; use crate::config::{TlsConfig, TransportConfig}; use crate::helper::host_port_pair; +use crate::transport::{AddrMaybeCached, SocketOpts, TcpTransport, Transport}; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use std::fs; +use std::net::SocketAddr; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio_native_tls::native_tls::{self, Certificate, Identity}; -use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream}; +pub(crate) use tokio_native_tls::TlsStream; +use tokio_native_tls::{TlsAcceptor, TlsConnector}; #[derive(Debug)] pub struct TlsTransport { @@ -109,3 +109,8 @@ impl Transport for TlsTransport { .await?) } } + +#[cfg(feature = "websocket-native-tls")] +pub(crate) fn get_tcpstream(s: &TlsStream<TcpStream>) -> &TcpStream { + s.get_ref().get_ref().get_ref() +} diff --git a/src/transport/rustls.rs b/src/transport/rustls.rs new file mode 100644 index 0000000..3ca4704 --- /dev/null +++ b/src/transport/rustls.rs @@ -0,0 +1,156 @@ +use crate::config::{TlsConfig, TransportConfig}; +use crate::helper::host_port_pair; +use crate::transport::{AddrMaybeCached, SocketOpts, TcpTransport, Transport}; +use std::fmt::Debug; +use std::fs; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName}; + +use anyhow::{anyhow, Context, Result}; +use async_trait::async_trait; +use p12::PFX; +use tokio_rustls::rustls::{ClientConfig, RootCertStore, ServerConfig}; +pub(crate) use tokio_rustls::TlsStream; +use tokio_rustls::{TlsAcceptor, TlsConnector}; + +pub struct TlsTransport { + tcp: TcpTransport, + config: TlsConfig, + connector: Option<TlsConnector>, + tls_acceptor: Option<TlsAcceptor>, +} + +// workaround for TlsConnector and TlsAcceptor not implementing Debug +impl Debug for TlsTransport { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TlsTransport") + .field("tcp", &self.tcp) + .field("config", &self.config) + .finish() + } +} + +fn load_server_config(config: &TlsConfig) -> Result<Option<ServerConfig>> { + if let Some(pkcs12_path) = config.pkcs12.as_ref() { + let buf = fs::read(pkcs12_path)?; + let pfx = PFX::parse(buf.as_slice())?; + let pass = config.pkcs12_password.as_ref().unwrap(); + + let certs = pfx.cert_bags(pass)?; + let keys = pfx.key_bags(pass)?; + + let chain: Vec<CertificateDer> = certs.into_iter().map(CertificateDer::from).collect(); + let key = PrivatePkcs8KeyDer::from(keys.into_iter().next().unwrap()); + + Ok(Some( + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(chain, key.into())?, + )) + } else { + Ok(None) + } +} + +fn load_client_config(config: &TlsConfig) -> Result<Option<ClientConfig>> { + let cert = if let Some(path) = config.trusted_root.as_ref() { + rustls_pemfile::certs(&mut std::io::BufReader::new(fs::File::open(path).unwrap())) + .map(|cert| cert.unwrap()) + .next() + .with_context(|| "Failed to read certificate")? + } else { + // read from native + match rustls_native_certs::load_native_certs() { + Ok(certs) => certs.into_iter().next().unwrap(), + Err(e) => { + eprintln!("Failed to load native certs: {}", e); + return Ok(None); + } + } + }; + + let mut root_certs = RootCertStore::empty(); + root_certs.add(cert).unwrap(); + + Ok(Some( + ClientConfig::builder() + .with_root_certificates(root_certs) + .with_no_client_auth(), + )) +} + +#[async_trait] +impl Transport for TlsTransport { + type Acceptor = TcpListener; + type RawStream = TcpStream; + type Stream = TlsStream<TcpStream>; + + fn new(config: &TransportConfig) -> Result<Self> { + let tcp = TcpTransport::new(config)?; + let config = config + .tls + .as_ref() + .ok_or_else(|| anyhow!("Missing tls config"))?; + + let connector = load_client_config(config) + .unwrap() + .map(|c| Arc::new(c).into()); + let tls_acceptor = load_server_config(config) + .unwrap() + .map(|c| Arc::new(c).into()); + + Ok(TlsTransport { + tcp, + config: config.clone(), + connector, + tls_acceptor, + }) + } + + fn hint(conn: &Self::Stream, opt: SocketOpts) { + opt.apply(conn.get_ref().0); + } + + async fn bind<A: ToSocketAddrs + Send + Sync>(&self, addr: A) -> Result<Self::Acceptor> { + let l = TcpListener::bind(addr) + .await + .with_context(|| "Failed to create tcp listener")?; + Ok(l) + } + + async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)> { + self.tcp + .accept(a) + .await + .with_context(|| "Failed to accept TCP connection") + } + + async fn handshake(&self, conn: Self::RawStream) -> Result<Self::Stream> { + let conn = self.tls_acceptor.as_ref().unwrap().accept(conn).await?; + Ok(tokio_rustls::TlsStream::Server(conn)) + } + + async fn connect(&self, addr: &AddrMaybeCached) -> Result<Self::Stream> { + let conn = self.tcp.connect(addr).await?; + + let connector = self.connector.as_ref().unwrap(); + + let host_name = self + .config + .hostname + .as_deref() + .unwrap_or(host_port_pair(&addr.addr)?.0); + + Ok(tokio_rustls::TlsStream::Client( + connector + .connect(ServerName::try_from(host_name)?.to_owned(), conn) + .await?, + )) + } +} + +pub(crate) fn get_tcpstream(s: &TlsStream<TcpStream>) -> &TcpStream { + &s.get_ref().0 +} diff --git a/src/transport/websocket.rs b/src/transport/websocket.rs index ec6177d..228eff7 100644 --- a/src/transport/websocket.rs +++ b/src/transport/websocket.rs @@ -13,10 +13,14 @@ use futures_core::stream::Stream; use futures_sink::Sink; use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; -use tokio_native_tls::TlsStream; -use tokio_tungstenite::tungstenite::protocol::WebSocketConfig; -use tokio_tungstenite::{accept_async_with_config, client_async_with_config}; -use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream}; + +#[cfg(any(feature = "native-tls", feature = "rustls"))] +use super::tls::get_tcpstream; +#[cfg(any(feature = "native-tls", feature = "rustls"))] +use super::tls::TlsStream; + +use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig}; +use tokio_tungstenite::{accept_async_with_config, client_async_with_config, WebSocketStream}; use tokio_util::io::StreamReader; use url::Url; @@ -30,7 +34,7 @@ impl TransportStream { fn get_tcpstream(&self) -> &TcpStream { match self { TransportStream::Insecure(s) => s, - TransportStream::Secure(s) => s.get_ref().get_ref().get_ref(), + TransportStream::Secure(s) => get_tcpstream(s), } } } |