pub const HASH_WIDTH_IN_BYTES: usize = 32; use anyhow::{bail, Context, Result}; use bytes::{Bytes, BytesMut}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use std::net::SocketAddr; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::trace; type ProtocolVersion = u8; const _PROTO_V0: u8 = 0u8; const PROTO_V1: u8 = 1u8; pub const CURRENT_PROTO_VERSION: ProtocolVersion = PROTO_V1; pub type Digest = [u8; HASH_WIDTH_IN_BYTES]; #[derive(Deserialize, Serialize, Debug)] pub enum Hello { ControlChannelHello(ProtocolVersion, Digest), // sha256sum(service name) or a nonce DataChannelHello(ProtocolVersion, Digest), // token provided by CreateDataChannel } #[derive(Deserialize, Serialize, Debug)] pub struct Auth(pub Digest); #[derive(Deserialize, Serialize, Debug)] pub enum Ack { Ok, ServiceNotExist, AuthFailed, } impl std::fmt::Display for Ack { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "{}", match self { Ack::Ok => "Ok", Ack::ServiceNotExist => "Service not exist", Ack::AuthFailed => "Incorrect token", } ) } } #[derive(Deserialize, Serialize, Debug)] pub enum ControlChannelCmd { CreateDataChannel, HeartBeat, } #[derive(Deserialize, Serialize, Debug)] pub enum DataChannelCmd { StartForwardTcp, StartForwardUdp, } type UdpPacketLen = u16; // `u16` should be enough for any practical UDP traffic on the Internet #[derive(Deserialize, Serialize, Debug)] struct UdpHeader { from: SocketAddr, len: UdpPacketLen, } #[derive(Debug)] pub struct UdpTraffic { pub from: SocketAddr, pub data: Bytes, } impl UdpTraffic { pub async fn write(&self, writer: &mut T) -> Result<()> { let hdr = UdpHeader { from: self.from, len: self.data.len() as UdpPacketLen, }; let v = bincode::serialize(&hdr).unwrap(); trace!("Write {:?} of length {}", hdr, v.len()); writer.write_u8(v.len() as u8).await?; writer.write_all(&v).await?; writer.write_all(&self.data).await?; Ok(()) } #[allow(dead_code)] pub async fn write_slice( writer: &mut T, from: SocketAddr, data: &[u8], ) -> Result<()> { let hdr = UdpHeader { from, len: data.len() as UdpPacketLen, }; let v = bincode::serialize(&hdr).unwrap(); trace!("Write {:?} of length {}", hdr, v.len()); writer.write_u8(v.len() as u8).await?; writer.write_all(&v).await?; writer.write_all(data).await?; Ok(()) } pub async fn read(reader: &mut T, hdr_len: u8) -> Result { let mut buf = vec![0; hdr_len as usize]; reader .read_exact(&mut buf) .await .with_context(|| "Failed to read udp header")?; let hdr: UdpHeader = bincode::deserialize(&buf).with_context(|| "Failed to deserialize UdpHeader")?; trace!("hdr {:?}", hdr); let mut data = BytesMut::new(); data.resize(hdr.len as usize, 0); reader.read_exact(&mut data).await?; Ok(UdpTraffic { from: hdr.from, data: data.freeze(), }) } } pub fn digest(data: &[u8]) -> Digest { use sha2::{Digest, Sha256}; let d = Sha256::new().chain_update(data).finalize(); d.into() } struct PacketLength { hello: usize, ack: usize, auth: usize, c_cmd: usize, d_cmd: usize, } impl PacketLength { pub fn new() -> PacketLength { let username = "default"; let d = digest(username.as_bytes()); let hello = bincode::serialized_size(&Hello::ControlChannelHello(CURRENT_PROTO_VERSION, d)) .unwrap() as usize; let c_cmd = bincode::serialized_size(&ControlChannelCmd::CreateDataChannel).unwrap() as usize; let d_cmd = bincode::serialized_size(&DataChannelCmd::StartForwardTcp).unwrap() as usize; let ack = Ack::Ok; let ack = bincode::serialized_size(&ack).unwrap() as usize; let auth = bincode::serialized_size(&Auth(d)).unwrap() as usize; PacketLength { hello, ack, auth, c_cmd, d_cmd, } } } lazy_static! { static ref PACKET_LEN: PacketLength = PacketLength::new(); } pub async fn read_hello(conn: &mut T) -> Result { let mut buf = vec![0u8; PACKET_LEN.hello]; conn.read_exact(&mut buf) .await .with_context(|| "Failed to read hello")?; let hello = bincode::deserialize(&buf).with_context(|| "Failed to deserialize hello")?; match hello { Hello::ControlChannelHello(v, _) => { if v != CURRENT_PROTO_VERSION { bail!( "Protocol version mismatched. Expected {}, got {}. Please update `rathole`.", CURRENT_PROTO_VERSION, v ); } } Hello::DataChannelHello(v, _) => { if v != CURRENT_PROTO_VERSION { bail!( "Protocol version mismatched. Expected {}, got {}. Please update `rathole`.", CURRENT_PROTO_VERSION, v ); } } } Ok(hello) } pub async fn read_auth(conn: &mut T) -> Result { let mut buf = vec![0u8; PACKET_LEN.auth]; conn.read_exact(&mut buf) .await .with_context(|| "Failed to read auth")?; bincode::deserialize(&buf).with_context(|| "Failed to deserialize auth") } pub async fn read_ack(conn: &mut T) -> Result { let mut bytes = vec![0u8; PACKET_LEN.ack]; conn.read_exact(&mut bytes) .await .with_context(|| "Failed to read ack")?; bincode::deserialize(&bytes).with_context(|| "Failed to deserialize ack") } pub async fn read_control_cmd( conn: &mut T, ) -> Result { let mut bytes = vec![0u8; PACKET_LEN.c_cmd]; conn.read_exact(&mut bytes) .await .with_context(|| "Failed to read cmd")?; bincode::deserialize(&bytes).with_context(|| "Failed to deserialize control cmd") } pub async fn read_data_cmd( conn: &mut T, ) -> Result { let mut bytes = vec![0u8; PACKET_LEN.d_cmd]; conn.read_exact(&mut bytes) .await .with_context(|| "Failed to read cmd")?; bincode::deserialize(&bytes).with_context(|| "Failed to deserialize data cmd") }