aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/client.rs19
-rw-r--r--src/helper.rs8
-rw-r--r--src/lib.rs9
-rw-r--r--src/server.rs16
-rw-r--r--src/transport/mod.rs23
-rw-r--r--src/transport/native_tls.rs (renamed from src/transport/tls.rs)13
-rw-r--r--src/transport/rustls.rs156
-rw-r--r--src/transport/websocket.rs14
8 files changed, 221 insertions, 37 deletions
diff --git a/src/client.rs b/src/client.rs
index 95a5a74..2564869 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -8,8 +8,9 @@ use crate::protocol::{
};
use crate::transport::{AddrMaybeCached, SocketOpts, TcpTransport, Transport};
use anyhow::{anyhow, bail, Context, Result};
+use backoff::backoff::Backoff;
+use backoff::future::retry_notify;
use backoff::ExponentialBackoff;
-use backoff::{backoff::Backoff, future::retry_notify};
use bytes::{Bytes, BytesMut};
use std::collections::HashMap;
use std::net::SocketAddr;
@@ -22,9 +23,9 @@ use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};
#[cfg(feature = "noise")]
use crate::transport::NoiseTransport;
-#[cfg(feature = "tls")]
+#[cfg(any(feature = "native-tls", feature = "rustls"))]
use crate::transport::TlsTransport;
-#[cfg(feature = "websocket")]
+#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
use crate::transport::WebsocketTransport;
use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
@@ -47,13 +48,13 @@ pub async fn run_client(
client.run(shutdown_rx, update_rx).await
}
TransportType::Tls => {
- #[cfg(feature = "tls")]
+ #[cfg(any(feature = "native-tls", feature = "rustls"))]
{
let mut client = Client::<TlsTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
}
- #[cfg(not(feature = "tls"))]
- crate::helper::feature_not_compile("tls")
+ #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
+ crate::helper::feature_neither_compile("native-tls", "rustls")
}
TransportType::Noise => {
#[cfg(feature = "noise")]
@@ -65,13 +66,13 @@ pub async fn run_client(
crate::helper::feature_not_compile("noise")
}
TransportType::Websocket => {
- #[cfg(feature = "websocket")]
+ #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
{
let mut client = Client::<WebsocketTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
}
- #[cfg(not(feature = "websocket"))]
- crate::helper::feature_not_compile("websocket")
+ #[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
+ crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
}
}
}
diff --git a/src/helper.rs b/src/helper.rs
index 7e1e5d3..a292969 100644
--- a/src/helper.rs
+++ b/src/helper.rs
@@ -43,6 +43,14 @@ pub fn feature_not_compile(feature: &str) -> ! {
)
}
+#[allow(dead_code)]
+pub fn feature_neither_compile(feature1: &str, feature2: &str) -> ! {
+ panic!(
+ "Neither of the feature '{}' or '{}' is compiled in this binary. Please re-compile rathole",
+ feature1, feature2
+ )
+}
+
pub async fn to_socket_addr<A: ToSocketAddrs>(addr: A) -> Result<SocketAddr> {
lookup_host(addr)
.await?
diff --git a/src/lib.rs b/src/lib.rs
index 7fb2fa6..65beb7f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -83,7 +83,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
if let Some((i, _)) = last_instance {
info!("General configuration change detected. Restarting...");
shutdown_tx.send(true)?;
- i.await?;
+ i.await??;
}
debug!("{:?}", config);
@@ -119,8 +119,8 @@ async fn run_instance(
args: Cli,
shutdown_rx: broadcast::Receiver<bool>,
service_update: mpsc::Receiver<ConfigChange>,
-) {
- let ret: Result<()> = match determine_run_mode(&config, &args) {
+) -> Result<()> {
+ match determine_run_mode(&config, &args) {
RunMode::Undetermine => panic!("Cannot determine running as a server or a client"),
RunMode::Client => {
#[cfg(not(feature = "client"))]
@@ -134,8 +134,7 @@ async fn run_instance(
#[cfg(feature = "server")]
run_server(config, shutdown_rx, service_update).await
}
- };
- ret.unwrap();
+ }
}
#[derive(PartialEq, Eq, Debug)]
diff --git a/src/server.rs b/src/server.rs
index 83ae976..a4c4948 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -25,9 +25,9 @@ use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span}
#[cfg(feature = "noise")]
use crate::transport::NoiseTransport;
-#[cfg(feature = "tls")]
+#[cfg(any(feature = "native-tls", feature = "rustls"))]
use crate::transport::TlsTransport;
-#[cfg(feature = "websocket")]
+#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
use crate::transport::WebsocketTransport;
type ServiceDigest = protocol::Digest; // SHA256 of a service name
@@ -57,13 +57,13 @@ pub async fn run_server(
server.run(shutdown_rx, update_rx).await?;
}
TransportType::Tls => {
- #[cfg(feature = "tls")]
+ #[cfg(any(feature = "native-tls", feature = "rustls"))]
{
let mut server = Server::<TlsTransport>::from(config).await?;
server.run(shutdown_rx, update_rx).await?;
}
- #[cfg(not(feature = "tls"))]
- crate::helper::feature_not_compile("tls")
+ #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
+ crate::helper::feature_neither_compile("native-tls", "rustls")
}
TransportType::Noise => {
#[cfg(feature = "noise")]
@@ -75,13 +75,13 @@ pub async fn run_server(
crate::helper::feature_not_compile("noise")
}
TransportType::Websocket => {
- #[cfg(feature = "websocket")]
+ #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
{
let mut server = Server::<WebsocketTransport>::from(config).await?;
server.run(shutdown_rx, update_rx).await?;
}
- #[cfg(not(feature = "websocket"))]
- crate::helper::feature_not_compile("websocket")
+ #[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
+ crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
}
}
diff --git a/src/transport/mod.rs b/src/transport/mod.rs
index 38682a6..26d357f 100644
--- a/src/transport/mod.rs
+++ b/src/transport/mod.rs
@@ -69,19 +69,30 @@ pub trait Transport: Debug + Send + Sync {
mod tcp;
pub use tcp::TcpTransport;
-#[cfg(feature = "tls")]
-mod tls;
-#[cfg(feature = "tls")]
-pub use tls::TlsTransport;
+
+#[cfg(all(feature = "native-tls", feature = "rustls"))]
+compile_error!("Only one of `native-tls` and `rustls` can be enabled");
+
+#[cfg(feature = "native-tls")]
+mod native_tls;
+#[cfg(feature = "native-tls")]
+use native_tls as tls;
+#[cfg(feature = "rustls")]
+mod rustls;
+#[cfg(feature = "rustls")]
+use rustls as tls;
+
+#[cfg(any(feature = "native-tls", feature = "rustls"))]
+pub(crate) use tls::TlsTransport;
#[cfg(feature = "noise")]
mod noise;
#[cfg(feature = "noise")]
pub use noise::NoiseTransport;
-#[cfg(feature = "websocket")]
+#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
mod websocket;
-#[cfg(feature = "websocket")]
+#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
pub use websocket::WebsocketTransport;
#[derive(Debug, Clone, Copy)]
diff --git a/src/transport/tls.rs b/src/transport/native_tls.rs
index 918af04..40afd50 100644
--- a/src/transport/tls.rs
+++ b/src/transport/native_tls.rs
@@ -1,14 +1,14 @@
-use std::net::SocketAddr;
-
-use super::{AddrMaybeCached, SocketOpts, TcpTransport, Transport};
use crate::config::{TlsConfig, TransportConfig};
use crate::helper::host_port_pair;
+use crate::transport::{AddrMaybeCached, SocketOpts, TcpTransport, Transport};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use std::fs;
+use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_native_tls::native_tls::{self, Certificate, Identity};
-use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
+pub(crate) use tokio_native_tls::TlsStream;
+use tokio_native_tls::{TlsAcceptor, TlsConnector};
#[derive(Debug)]
pub struct TlsTransport {
@@ -109,3 +109,8 @@ impl Transport for TlsTransport {
.await?)
}
}
+
+#[cfg(feature = "websocket-native-tls")]
+pub(crate) fn get_tcpstream(s: &TlsStream<TcpStream>) -> &TcpStream {
+ s.get_ref().get_ref().get_ref()
+}
diff --git a/src/transport/rustls.rs b/src/transport/rustls.rs
new file mode 100644
index 0000000..3ca4704
--- /dev/null
+++ b/src/transport/rustls.rs
@@ -0,0 +1,156 @@
+use crate::config::{TlsConfig, TransportConfig};
+use crate::helper::host_port_pair;
+use crate::transport::{AddrMaybeCached, SocketOpts, TcpTransport, Transport};
+use std::fmt::Debug;
+use std::fs;
+use std::net::SocketAddr;
+use std::sync::Arc;
+use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
+use tokio_rustls::rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName};
+
+use anyhow::{anyhow, Context, Result};
+use async_trait::async_trait;
+use p12::PFX;
+use tokio_rustls::rustls::{ClientConfig, RootCertStore, ServerConfig};
+pub(crate) use tokio_rustls::TlsStream;
+use tokio_rustls::{TlsAcceptor, TlsConnector};
+
+pub struct TlsTransport {
+ tcp: TcpTransport,
+ config: TlsConfig,
+ connector: Option<TlsConnector>,
+ tls_acceptor: Option<TlsAcceptor>,
+}
+
+// workaround for TlsConnector and TlsAcceptor not implementing Debug
+impl Debug for TlsTransport {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("TlsTransport")
+ .field("tcp", &self.tcp)
+ .field("config", &self.config)
+ .finish()
+ }
+}
+
+fn load_server_config(config: &TlsConfig) -> Result<Option<ServerConfig>> {
+ if let Some(pkcs12_path) = config.pkcs12.as_ref() {
+ let buf = fs::read(pkcs12_path)?;
+ let pfx = PFX::parse(buf.as_slice())?;
+ let pass = config.pkcs12_password.as_ref().unwrap();
+
+ let certs = pfx.cert_bags(pass)?;
+ let keys = pfx.key_bags(pass)?;
+
+ let chain: Vec<CertificateDer> = certs.into_iter().map(CertificateDer::from).collect();
+ let key = PrivatePkcs8KeyDer::from(keys.into_iter().next().unwrap());
+
+ Ok(Some(
+ ServerConfig::builder()
+ .with_no_client_auth()
+ .with_single_cert(chain, key.into())?,
+ ))
+ } else {
+ Ok(None)
+ }
+}
+
+fn load_client_config(config: &TlsConfig) -> Result<Option<ClientConfig>> {
+ let cert = if let Some(path) = config.trusted_root.as_ref() {
+ rustls_pemfile::certs(&mut std::io::BufReader::new(fs::File::open(path).unwrap()))
+ .map(|cert| cert.unwrap())
+ .next()
+ .with_context(|| "Failed to read certificate")?
+ } else {
+ // read from native
+ match rustls_native_certs::load_native_certs() {
+ Ok(certs) => certs.into_iter().next().unwrap(),
+ Err(e) => {
+ eprintln!("Failed to load native certs: {}", e);
+ return Ok(None);
+ }
+ }
+ };
+
+ let mut root_certs = RootCertStore::empty();
+ root_certs.add(cert).unwrap();
+
+ Ok(Some(
+ ClientConfig::builder()
+ .with_root_certificates(root_certs)
+ .with_no_client_auth(),
+ ))
+}
+
+#[async_trait]
+impl Transport for TlsTransport {
+ type Acceptor = TcpListener;
+ type RawStream = TcpStream;
+ type Stream = TlsStream<TcpStream>;
+
+ fn new(config: &TransportConfig) -> Result<Self> {
+ let tcp = TcpTransport::new(config)?;
+ let config = config
+ .tls
+ .as_ref()
+ .ok_or_else(|| anyhow!("Missing tls config"))?;
+
+ let connector = load_client_config(config)
+ .unwrap()
+ .map(|c| Arc::new(c).into());
+ let tls_acceptor = load_server_config(config)
+ .unwrap()
+ .map(|c| Arc::new(c).into());
+
+ Ok(TlsTransport {
+ tcp,
+ config: config.clone(),
+ connector,
+ tls_acceptor,
+ })
+ }
+
+ fn hint(conn: &Self::Stream, opt: SocketOpts) {
+ opt.apply(conn.get_ref().0);
+ }
+
+ async fn bind<A: ToSocketAddrs + Send + Sync>(&self, addr: A) -> Result<Self::Acceptor> {
+ let l = TcpListener::bind(addr)
+ .await
+ .with_context(|| "Failed to create tcp listener")?;
+ Ok(l)
+ }
+
+ async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)> {
+ self.tcp
+ .accept(a)
+ .await
+ .with_context(|| "Failed to accept TCP connection")
+ }
+
+ async fn handshake(&self, conn: Self::RawStream) -> Result<Self::Stream> {
+ let conn = self.tls_acceptor.as_ref().unwrap().accept(conn).await?;
+ Ok(tokio_rustls::TlsStream::Server(conn))
+ }
+
+ async fn connect(&self, addr: &AddrMaybeCached) -> Result<Self::Stream> {
+ let conn = self.tcp.connect(addr).await?;
+
+ let connector = self.connector.as_ref().unwrap();
+
+ let host_name = self
+ .config
+ .hostname
+ .as_deref()
+ .unwrap_or(host_port_pair(&addr.addr)?.0);
+
+ Ok(tokio_rustls::TlsStream::Client(
+ connector
+ .connect(ServerName::try_from(host_name)?.to_owned(), conn)
+ .await?,
+ ))
+ }
+}
+
+pub(crate) fn get_tcpstream(s: &TlsStream<TcpStream>) -> &TcpStream {
+ &s.get_ref().0
+}
diff --git a/src/transport/websocket.rs b/src/transport/websocket.rs
index ec6177d..228eff7 100644
--- a/src/transport/websocket.rs
+++ b/src/transport/websocket.rs
@@ -13,10 +13,14 @@ use futures_core::stream::Stream;
use futures_sink::Sink;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
-use tokio_native_tls::TlsStream;
-use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
-use tokio_tungstenite::{accept_async_with_config, client_async_with_config};
-use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream};
+
+#[cfg(any(feature = "native-tls", feature = "rustls"))]
+use super::tls::get_tcpstream;
+#[cfg(any(feature = "native-tls", feature = "rustls"))]
+use super::tls::TlsStream;
+
+use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
+use tokio_tungstenite::{accept_async_with_config, client_async_with_config, WebSocketStream};
use tokio_util::io::StreamReader;
use url::Url;
@@ -30,7 +34,7 @@ impl TransportStream {
fn get_tcpstream(&self) -> &TcpStream {
match self {
TransportStream::Insecure(s) => s,
- TransportStream::Secure(s) => s.get_ref().get_ref().get_ref(),
+ TransportStream::Secure(s) => get_tcpstream(s),
}
}
}