diff options
author | 2023-06-29 15:02:38 +0100 | |
---|---|---|
committer | 2023-06-29 16:20:09 +0100 | |
commit | ca8d6e981a63a7ef911bb19ccf779be9743f0454 (patch) | |
tree | 794ee17162b52ced4d81c2b96ab7e93692c8ae66 | |
parent | d1662fb0dafebb536968ba445a94f7ba80d68e55 (diff) | |
download | quiche-cid-arc.tar.gz quiche-cid-arc.tar.zst quiche-cid-arc.zip |
simplify ConnectionIdcid-arc
Currently a ConnectionId can either be a slice or a Vec. Because of this
declarations might need to also declare an explicit lifetime, which make
using the struct slightly annoying.
Instead just use a Vec all the time, and to avoid wasting memory when
cloning (since the CIDs are read-only anyway), wrap the internal Vec in
an Arc.
-rw-r--r-- | apps/src/bin/quiche-server.rs | 8 | ||||
-rw-r--r-- | apps/src/common.rs | 4 | ||||
-rw-r--r-- | quiche/examples/http3-server.rs | 8 | ||||
-rw-r--r-- | quiche/examples/server.rs | 8 | ||||
-rw-r--r-- | quiche/src/cid.rs | 23 | ||||
-rw-r--r-- | quiche/src/lib.rs | 33 | ||||
-rw-r--r-- | quiche/src/packet.rs | 84 |
7 files changed, 68 insertions, 100 deletions
diff --git a/apps/src/bin/quiche-server.rs b/apps/src/bin/quiche-server.rs index f5d56910..3d261c1c 100644 --- a/apps/src/bin/quiche-server.rs +++ b/apps/src/bin/quiche-server.rs @@ -617,7 +617,7 @@ fn main() { ); for id in c.conn.source_ids() { - let id_owned = id.clone().into_owned(); + let id_owned = id.clone(); clients_ids.remove(&id_owned); } } @@ -658,9 +658,9 @@ fn mint_token(hdr: &quiche::Header, src: &net::SocketAddr) -> Vec<u8> { /// /// Note that this function is only an example and doesn't do any cryptographic /// authenticate of the token. *It should not be used in production system*. -fn validate_token<'a>( - src: &net::SocketAddr, token: &'a [u8], -) -> Option<quiche::ConnectionId<'a>> { +fn validate_token( + src: &net::SocketAddr, token: &[u8], +) -> Option<quiche::ConnectionId> { if token.len() < 6 { return None; } diff --git a/apps/src/common.rs b/apps/src/common.rs index e58e4a01..6f4fc83c 100644 --- a/apps/src/common.rs +++ b/apps/src/common.rs @@ -100,7 +100,7 @@ pub struct Client { pub max_send_burst: usize, } -pub type ClientIdMap = HashMap<ConnectionId<'static>, ClientId>; +pub type ClientIdMap = HashMap<ConnectionId, ClientId>; pub type ClientMap = HashMap<ClientId, Client>; /// Makes a buffered writer for a resource with a target URL. @@ -255,7 +255,7 @@ pub fn hdrs_to_strings(hdrs: &[quiche::h3::Header]) -> Vec<(String, String)> { /// Generate a new pair of Source Connection ID and reset token. pub fn generate_cid_and_reset_token<T: SecureRandom>( rng: &T, -) -> (quiche::ConnectionId<'static>, u128) { +) -> (quiche::ConnectionId, u128) { let mut scid = [0; quiche::MAX_CONN_ID_LEN]; rng.fill(&mut scid).unwrap(); let scid = scid.to_vec().into(); diff --git a/quiche/examples/http3-server.rs b/quiche/examples/http3-server.rs index 32650cdd..bbb58739 100644 --- a/quiche/examples/http3-server.rs +++ b/quiche/examples/http3-server.rs @@ -53,7 +53,7 @@ struct Client { partial_responses: HashMap<u64, PartialResponse>, } -type ClientMap = HashMap<quiche::ConnectionId<'static>, Client>; +type ClientMap = HashMap<quiche::ConnectionId, Client>; fn main() { let mut buf = [0; 65535]; @@ -476,9 +476,9 @@ fn mint_token(hdr: &quiche::Header, src: &net::SocketAddr) -> Vec<u8> { /// /// Note that this function is only an example and doesn't do any cryptographic /// authenticate of the token. *It should not be used in production system*. -fn validate_token<'a>( - src: &net::SocketAddr, token: &'a [u8], -) -> Option<quiche::ConnectionId<'a>> { +fn validate_token( + src: &net::SocketAddr, token: &[u8], +) -> Option<quiche::ConnectionId> { if token.len() < 6 { return None; } diff --git a/quiche/examples/server.rs b/quiche/examples/server.rs index 496b51ca..bc322e48 100644 --- a/quiche/examples/server.rs +++ b/quiche/examples/server.rs @@ -47,7 +47,7 @@ struct Client { partial_responses: HashMap<u64, PartialResponse>, } -type ClientMap = HashMap<quiche::ConnectionId<'static>, Client>; +type ClientMap = HashMap<quiche::ConnectionId, Client>; fn main() { let mut buf = [0; 65535]; @@ -418,9 +418,9 @@ fn mint_token(hdr: &quiche::Header, src: &net::SocketAddr) -> Vec<u8> { /// /// Note that this function is only an example and doesn't do any cryptographic /// authenticate of the token. *It should not be used in production system*. -fn validate_token<'a>( - src: &net::SocketAddr, token: &'a [u8], -) -> Option<quiche::ConnectionId<'a>> { +fn validate_token( + src: &net::SocketAddr, token: &[u8], +) -> Option<quiche::ConnectionId> { if token.len() < 6 { return None; } diff --git a/quiche/src/cid.rs b/quiche/src/cid.rs index fdedfe3a..a8edeb40 100644 --- a/quiche/src/cid.rs +++ b/quiche/src/cid.rs @@ -36,7 +36,7 @@ use std::collections::VecDeque; #[derive(Debug, Default)] pub struct ConnectionIdEntry { /// The Connection ID. - pub cid: ConnectionId<'static>, + pub cid: ConnectionId, /// Its associated sequence number. pub seq: u64, @@ -193,11 +193,11 @@ impl BoundedNonEmptyConnectionIdVecDeque { /// An iterator over QUIC Connection IDs. pub struct ConnectionIdIter { - cids: VecDeque<ConnectionId<'static>>, + cids: VecDeque<ConnectionId>, } impl Iterator for ConnectionIdIter { - type Item = ConnectionId<'static>; + type Item = ConnectionId; #[inline] fn next(&mut self) -> Option<Self::Item> { @@ -228,7 +228,7 @@ pub struct ConnectionIdentifiers { /// Retired Source Connection IDs that should be notified to the /// application. - retired_scids: VecDeque<ConnectionId<'static>>, + retired_scids: VecDeque<ConnectionId>, /// Largest "Retire Prior To" we received from the peer. largest_peer_retire_prior_to: u64, @@ -271,8 +271,7 @@ impl ConnectionIdentifiers { // Record the zero-length SCID status. let zero_length_scid = initial_scid.is_empty(); - let initial_scid = - ConnectionId::from_ref(initial_scid.as_ref()).into_owned(); + let initial_scid = ConnectionId::from_ref(initial_scid.as_ref()); // We need to track up to (2 * source_conn_id_limit - 1) source // Connection IDs when the host wants to force their renewal. @@ -366,8 +365,8 @@ impl ConnectionIdentifiers { /// [`InvalidState`]: enum.Error.html#InvalidState /// [`IdLimit`]: enum.Error.html#IdLimit pub fn new_scid( - &mut self, cid: ConnectionId<'static>, reset_token: Option<u128>, - advertise: bool, path_id: Option<usize>, retire_if_needed: bool, + &mut self, cid: ConnectionId, reset_token: Option<u128>, advertise: bool, + path_id: Option<usize>, retire_if_needed: bool, ) -> Result<u64> { if self.zero_length_scid { return Err(Error::InvalidState); @@ -415,7 +414,7 @@ impl ConnectionIdentifiers { /// Sets the initial destination identifier. pub fn set_initial_dcid( - &mut self, cid: ConnectionId<'static>, reset_token: Option<u128>, + &mut self, cid: ConnectionId, reset_token: Option<u128>, path_id: Option<usize>, ) { // Record the zero-length DCID status. @@ -438,7 +437,7 @@ impl ConnectionIdentifiers { /// sequence number of retired DCIDs that were linked to their respective /// Path ID. pub fn new_dcid( - &mut self, cid: ConnectionId<'static>, seq: u64, reset_token: u128, + &mut self, cid: ConnectionId, seq: u64, reset_token: u128, retire_prior_to: u64, ) -> Result<Vec<(u64, usize)>> { if self.zero_length_dcid { @@ -488,7 +487,7 @@ impl ConnectionIdentifiers { } let new_entry = ConnectionIdEntry { - cid: cid.clone(), + cid, seq, reset_token: Some(reset_token), path_id: None, @@ -803,7 +802,7 @@ impl ConnectionIdentifiers { self.retired_scids.len() } - pub fn pop_retired_scid(&mut self) -> Option<ConnectionId<'static>> { + pub fn pop_retired_scid(&mut self) -> Option<ConnectionId> { self.retired_scids.pop_front() } } diff --git a/quiche/src/lib.rs b/quiche/src/lib.rs index 1c8f1a79..43adaa81 100644 --- a/quiche/src/lib.rs +++ b/quiche/src/lib.rs @@ -1300,11 +1300,11 @@ pub struct Connection { /// Peer's original destination connection ID. Used by the client to /// validate the server's transport parameter. - odcid: Option<ConnectionId<'static>>, + odcid: Option<ConnectionId>, /// Peer's retry source connection ID. Used by the client during stateless /// retry to validate the server's transport parameter. - rscid: Option<ConnectionId<'static>>, + rscid: Option<ConnectionId>, /// Received address verification token. token: Option<Vec<u8>>, @@ -1521,7 +1521,7 @@ pub fn negotiate_version( /// # fn mint_token(hdr: &quiche::Header, src: &std::net::SocketAddr) -> Vec<u8> { /// # vec![] /// # } -/// # fn validate_token<'a>(src: &std::net::SocketAddr, token: &'a [u8]) -> Option<quiche::ConnectionId<'a>> { +/// # fn validate_token<'a>(src: &std::net::SocketAddr, token: &'a [u8]) -> Option<quiche::ConnectionId> { /// # None /// # } /// let (len, peer) = socket.recv_from(&mut buf).unwrap(); @@ -2350,7 +2350,7 @@ impl Connection { self.did_retry = true; // Remember peer's new connection ID. - self.odcid = Some(self.destination_id().into_owned()); + self.odcid = Some(self.destination_id()); self.set_initial_dcid( hdr.scid.clone(), @@ -2358,7 +2358,7 @@ impl Connection { self.paths.get_active_path_id()?, )?; - self.rscid = Some(self.destination_id().into_owned()); + self.rscid = Some(self.destination_id()); // Derive Initial secrets using the new connection ID. let (aead_open, aead_seal) = crypto::derive_initial_key_material( @@ -2648,7 +2648,7 @@ impl Connection { if !self.is_server && !self.got_peer_conn_id { if self.odcid.is_none() { - self.odcid = Some(self.destination_id().into_owned()); + self.odcid = Some(self.destination_id()); } // Replace the randomly generated destination connection ID with @@ -5829,7 +5829,7 @@ impl Connection { /// more retired connection IDs. /// /// [`ConnectionId`]: struct.ConnectionId.html - pub fn retired_scid_next(&mut self) -> Option<ConnectionId<'static>> { + pub fn retired_scid_next(&mut self) -> Option<ConnectionId> { self.ids.pop_retired_scid() } @@ -7111,8 +7111,7 @@ impl Connection { } fn set_initial_dcid( - &mut self, cid: ConnectionId<'static>, reset_token: Option<u128>, - path_id: usize, + &mut self, cid: ConnectionId, reset_token: Option<u128>, path_id: usize, ) -> Result<()> { self.ids.set_initial_dcid(cid, reset_token, Some(path_id)); self.paths.get_mut(path_id)?.active_dcid_seq = Some(0); @@ -7552,7 +7551,7 @@ impl std::fmt::Debug for Stats { #[derive(Clone, Debug, PartialEq)] struct TransportParams { - pub original_destination_connection_id: Option<ConnectionId<'static>>, + pub original_destination_connection_id: Option<ConnectionId>, pub max_idle_timeout: u64, pub stateless_reset_token: Option<u128>, pub max_udp_payload_size: u64, @@ -7567,8 +7566,8 @@ struct TransportParams { pub disable_active_migration: bool, // pub preferred_address: ..., pub active_conn_id_limit: u64, - pub initial_source_connection_id: Option<ConnectionId<'static>>, - pub retry_source_connection_id: Option<ConnectionId<'static>>, + pub initial_source_connection_id: Option<ConnectionId>, + pub retry_source_connection_id: Option<ConnectionId>, pub max_datagram_frame_size: Option<u64>, } @@ -8420,12 +8419,10 @@ pub mod testing { Ok(frames) } - pub fn create_cid_and_reset_token( - cid_len: usize, - ) -> (ConnectionId<'static>, u128) { + pub fn create_cid_and_reset_token(cid_len: usize) -> (ConnectionId, u128) { let mut cid = vec![0; cid_len]; rand::rand_bytes(&mut cid[..]); - let cid = ConnectionId::from_ref(&cid).into_owned(); + let cid = ConnectionId::from_ref(&cid); let mut reset_token = [0; 16]; rand::rand_bytes(&mut reset_token); @@ -14256,7 +14253,7 @@ mod tests { assert_eq!(pipe.server.path_event_next(), None); assert_eq!(pipe.client.source_cids_left(), 1); - let scid = pipe.client.source_id().into_owned(); + let scid = pipe.client.source_id(); let (scid_1, reset_token_1) = testing::create_cid_and_reset_token(16); assert_eq!( @@ -14351,7 +14348,7 @@ mod tests { let mut pipe = testing::Pipe::with_config(&mut config).unwrap(); assert_eq!(pipe.handshake(), Ok(())); - let scid = pipe.client.source_id().into_owned(); + let scid = pipe.client.source_id(); let (scid_1, reset_token_1) = testing::create_cid_and_reset_token(16); assert_eq!( diff --git a/quiche/src/packet.rs b/quiche/src/packet.rs index 9c7b9c43..1da96b5d 100644 --- a/quiche/src/packet.rs +++ b/quiche/src/packet.rs @@ -28,6 +28,7 @@ use std::fmt::Display; use std::ops::Index; use std::ops::IndexMut; use std::ops::RangeInclusive; +use std::sync::Arc; use std::time; use ring::aead; @@ -174,103 +175,76 @@ impl Type { } /// A QUIC connection ID. -pub struct ConnectionId<'a>(ConnectionIdInner<'a>); +pub struct ConnectionId(Arc<Vec<u8>>); -enum ConnectionIdInner<'a> { - Vec(Vec<u8>), - Ref(&'a [u8]), -} - -impl<'a> ConnectionId<'a> { +impl ConnectionId { /// Creates a new connection ID from the given vector. #[inline] - pub const fn from_vec(cid: Vec<u8>) -> Self { - Self(ConnectionIdInner::Vec(cid)) + pub fn from_vec(cid: Vec<u8>) -> Self { + Self(Arc::new(cid)) } /// Creates a new connection ID from the given slice. #[inline] - pub const fn from_ref(cid: &'a [u8]) -> Self { - Self(ConnectionIdInner::Ref(cid)) - } - - /// Returns a new owning connection ID from the given existing one. - #[inline] - pub fn into_owned(self) -> ConnectionId<'static> { - ConnectionId::from_vec(self.into()) + pub fn from_ref(cid: &[u8]) -> Self { + Self::from_vec(cid.to_vec()) } } -impl<'a> Default for ConnectionId<'a> { +impl Default for ConnectionId { #[inline] fn default() -> Self { Self::from_vec(Vec::new()) } } -impl<'a> From<Vec<u8>> for ConnectionId<'a> { +impl From<Vec<u8>> for ConnectionId { #[inline] fn from(v: Vec<u8>) -> Self { Self::from_vec(v) } } -impl<'a> From<ConnectionId<'a>> for Vec<u8> { - #[inline] - fn from(id: ConnectionId<'a>) -> Self { - match id.0 { - ConnectionIdInner::Vec(cid) => cid, - ConnectionIdInner::Ref(cid) => cid.to_vec(), - } - } -} - -impl<'a> PartialEq for ConnectionId<'a> { +impl PartialEq for ConnectionId { #[inline] fn eq(&self, other: &Self) -> bool { self.as_ref() == other.as_ref() } } -impl<'a> Eq for ConnectionId<'a> {} +impl Eq for ConnectionId {} -impl<'a> AsRef<[u8]> for ConnectionId<'a> { +impl AsRef<[u8]> for ConnectionId { #[inline] fn as_ref(&self) -> &[u8] { - match &self.0 { - ConnectionIdInner::Vec(v) => v.as_ref(), - ConnectionIdInner::Ref(v) => v, - } + self.0.as_ref() } } -impl<'a> std::hash::Hash for ConnectionId<'a> { +impl std::hash::Hash for ConnectionId { #[inline] fn hash<H: std::hash::Hasher>(&self, state: &mut H) { self.as_ref().hash(state); } } -impl<'a> std::ops::Deref for ConnectionId<'a> { +impl std::ops::Deref for ConnectionId { type Target = [u8]; #[inline] fn deref(&self) -> &[u8] { - match &self.0 { - ConnectionIdInner::Vec(v) => v.as_ref(), - ConnectionIdInner::Ref(v) => v, - } + self.0.as_ref() } } -impl<'a> Clone for ConnectionId<'a> { +impl Clone for ConnectionId { #[inline] fn clone(&self) -> Self { - Self::from_vec(self.as_ref().to_vec()) + Self(Arc::clone(&self.0)) } } -impl<'a> std::fmt::Debug for ConnectionId<'a> { +impl std::fmt::Debug for ConnectionId { #[inline] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { for c in self.as_ref() { @@ -283,7 +257,7 @@ impl<'a> std::fmt::Debug for ConnectionId<'a> { /// A QUIC packet's header. #[derive(Clone, PartialEq, Eq)] -pub struct Header<'a> { +pub struct Header { /// The type of the packet. pub ty: Type, @@ -291,10 +265,10 @@ pub struct Header<'a> { pub version: u32, /// The destination connection ID of the packet. - pub dcid: ConnectionId<'a>, + pub dcid: ConnectionId, /// The source connection ID of the packet. - pub scid: ConnectionId<'a>, + pub scid: ConnectionId, /// The packet number. It's only meaningful after the header protection is /// removed. @@ -317,7 +291,7 @@ pub struct Header<'a> { pub(crate) key_phase: bool, } -impl<'a> Header<'a> { +impl Header { /// Parses a QUIC packet header from the given buffer. /// /// The `dcid_len` parameter is the length of the destination connection ID, @@ -336,16 +310,14 @@ impl<'a> Header<'a> { /// # Ok::<(), quiche::Error>(()) /// ``` #[inline] - pub fn from_slice<'b>( - buf: &'b mut [u8], dcid_len: usize, - ) -> Result<Header<'a>> { + pub fn from_slice(buf: &mut [u8], dcid_len: usize) -> Result<Header> { let mut b = octets::OctetsMut::with_slice(buf); Header::from_bytes(&mut b, dcid_len) } - pub(crate) fn from_bytes<'b>( - b: &'b mut octets::OctetsMut, dcid_len: usize, - ) -> Result<Header<'a>> { + pub(crate) fn from_bytes( + b: &mut octets::OctetsMut, dcid_len: usize, + ) -> Result<Header> { let first = b.get_u8()?; if !Header::is_long(first) { @@ -522,7 +494,7 @@ impl<'a> Header<'a> { } } -impl<'a> std::fmt::Debug for Header<'a> { +impl std::fmt::Debug for Header { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{:?}", self.ty)?; |