use core::result::Result; use std::io::{Error, ErrorKind}; use std::net::SocketAddr; use std::pin::Pin; use std::task::{ready, Context, Poll}; use super::{AddrMaybeCached, SocketOpts, TcpTransport, TlsTransport, Transport}; use crate::config::TransportConfig; use anyhow::anyhow; use async_trait::async_trait; use bytes::Bytes; use futures_core::stream::Stream; use futures_sink::Sink; use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio_native_tls::TlsStream; use tokio_tungstenite::tungstenite::protocol::WebSocketConfig; use tokio_tungstenite::{accept_async_with_config, client_async_with_config}; use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream}; use tokio_util::io::StreamReader; use url::Url; #[derive(Debug)] enum TransportStream { Insecure(TcpStream), Secure(TlsStream), } impl TransportStream { fn get_tcpstream(&self) -> &TcpStream { match self { TransportStream::Insecure(s) => s, TransportStream::Secure(s) => s.get_ref().get_ref().get_ref(), } } } impl AsyncRead for TransportStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { match self.get_mut() { TransportStream::Insecure(s) => Pin::new(s).poll_read(cx, buf), TransportStream::Secure(s) => Pin::new(s).poll_read(cx, buf), } } } impl AsyncWrite for TransportStream { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match self.get_mut() { TransportStream::Insecure(s) => Pin::new(s).poll_write(cx, buf), TransportStream::Secure(s) => Pin::new(s).poll_write(cx, buf), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { TransportStream::Insecure(s) => Pin::new(s).poll_flush(cx), TransportStream::Secure(s) => Pin::new(s).poll_flush(cx), } } fn poll_shutdown( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { match self.get_mut() { TransportStream::Insecure(s) => Pin::new(s).poll_shutdown(cx), TransportStream::Secure(s) => Pin::new(s).poll_shutdown(cx), } } } #[derive(Debug)] struct StreamWrapper { inner: WebSocketStream, } impl Stream for StreamWrapper { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match Pin::new(&mut self.get_mut().inner).poll_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => Poll::Ready(None), Poll::Ready(Some(Err(err))) => { Poll::Ready(Some(Err(Error::new(ErrorKind::Other, err)))) } Poll::Ready(Some(Ok(res))) => { if let Message::Binary(b) = res { Poll::Ready(Some(Ok(Bytes::from(b)))) } else { Poll::Ready(Some(Err(Error::new( ErrorKind::InvalidData, "unexpected frame", )))) } } } } fn size_hint(&self) -> (usize, Option) { self.inner.size_hint() } } #[derive(Debug)] pub struct WebsocketTunnel { inner: StreamReader, } impl AsyncRead for WebsocketTunnel { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { Pin::new(&mut self.get_mut().inner).poll_read(cx, buf) } } impl AsyncBufRead for WebsocketTunnel { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.get_mut().inner).poll_fill_buf(cx) } fn consume(self: Pin<&mut Self>, amt: usize) { Pin::new(&mut self.get_mut().inner).consume(amt) } } impl AsyncWrite for WebsocketTunnel { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let sw = self.get_mut().inner.get_mut(); ready!(Pin::new(&mut sw.inner) .poll_ready(cx) .map_err(|err| Error::new(ErrorKind::Other, err)))?; match Pin::new(&mut sw.inner).start_send(Message::Binary(buf.to_vec())) { Ok(()) => Poll::Ready(Ok(buf.len())), Err(e) => Poll::Ready(Err(Error::new(ErrorKind::Other, e))), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.get_mut().inner.get_mut().inner) .poll_flush(cx) .map_err(|err| Error::new(ErrorKind::Other, err)) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.get_mut().inner.get_mut().inner) .poll_close(cx) .map_err(|err| Error::new(ErrorKind::Other, err)) } } #[derive(Debug)] enum SubTransport { Secure(TlsTransport), Insecure(TcpTransport), } #[derive(Debug)] pub struct WebsocketTransport { sub: SubTransport, conf: WebSocketConfig, } #[async_trait] impl Transport for WebsocketTransport { type Acceptor = TcpListener; type RawStream = TcpStream; type Stream = WebsocketTunnel; fn new(config: &TransportConfig) -> anyhow::Result { let wsconfig = config .websocket .as_ref() .ok_or_else(|| anyhow!("Missing websocket config"))?; let conf = WebSocketConfig { write_buffer_size: 0, ..WebSocketConfig::default() }; let sub = match wsconfig.tls { true => SubTransport::Secure(TlsTransport::new(config)?), false => SubTransport::Insecure(TcpTransport::new(config)?), }; Ok(WebsocketTransport { sub, conf }) } fn hint(conn: &Self::Stream, opt: SocketOpts) { opt.apply(conn.inner.get_ref().inner.get_ref().get_tcpstream()) } async fn bind( &self, addr: A, ) -> anyhow::Result { TcpListener::bind(addr).await.map_err(Into::into) } async fn accept(&self, a: &Self::Acceptor) -> anyhow::Result<(Self::RawStream, SocketAddr)> { let (s, addr) = match &self.sub { SubTransport::Insecure(t) => t.accept(a).await?, SubTransport::Secure(t) => t.accept(a).await?, }; Ok((s, addr)) } async fn handshake(&self, conn: Self::RawStream) -> anyhow::Result { let tsream = match &self.sub { SubTransport::Insecure(t) => TransportStream::Insecure(t.handshake(conn).await?), SubTransport::Secure(t) => TransportStream::Secure(t.handshake(conn).await?), }; let wsstream = accept_async_with_config(tsream, Some(self.conf)).await?; let tun = WebsocketTunnel { inner: StreamReader::new(StreamWrapper { inner: wsstream }), }; Ok(tun) } async fn connect(&self, addr: &AddrMaybeCached) -> anyhow::Result { let u = format!("ws://{}", &addr.addr.as_str()); let url = Url::parse(&u).unwrap(); let tstream = match &self.sub { SubTransport::Insecure(t) => TransportStream::Insecure(t.connect(addr).await?), SubTransport::Secure(t) => TransportStream::Secure(t.connect(addr).await?), }; let (wsstream, _) = client_async_with_config(url, tstream, Some(self.conf)) .await .expect("failed to connect"); let tun = WebsocketTunnel { inner: StreamReader::new(StreamWrapper { inner: wsstream }), }; Ok(tun) } }