Bladeren bron

Modified the btmsg API to allow a single UDP socket
to be used for the receiver and all the senders.

Matthew Carr 2 jaren geleden
bovenliggende
commit
3125364053
2 gewijzigde bestanden met toevoegingen van 138 en 109 verwijderingen
  1. 1 1
      crates/btlib/src/block_path.rs
  2. 137 108
      crates/btmsg/src/lib.rs

+ 1 - 1
crates/btlib/src/block_path.rs

@@ -8,7 +8,7 @@ mod private {
     use super::*;
 
     /// An identifier for a block in a tree.
-    #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Default)]
+    #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Default, Hash)]
     pub struct BlockPath {
         root: Principal,
         components: Vec<String>,

+ 137 - 108
crates/btmsg/src/lib.rs

@@ -3,17 +3,19 @@ use btlib::{
     bterr,
     crypto::{rand_array, ConcreteCreds, CredsPriv, CredsPub},
     error::BoxInIoErr,
-    Principal, Result,
+    BlockPath, Result,
 };
 use btserde::{read_from, write_to};
 use bytes::{BufMut, BytesMut};
 use core::{
     future::Future,
+    marker::Send,
     pin::Pin,
     task::{Context, Poll},
 };
 use futures::{
-    sink::{Close, Send, Sink},
+    future::{ready, Ready},
+    sink::{Close, Send as SendFut, Sink},
     stream::Stream,
     SinkExt, StreamExt,
 };
@@ -23,13 +25,13 @@ use quinn::{ClientConfig, Endpoint, SendStream, ServerConfig};
 use rustls::{
     Certificate, ConfigBuilder, ConfigSide, PrivateKey, WantsCipherSuites, WantsVerifier,
 };
-use serde::{Deserialize, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use std::{
     collections::hash_map::DefaultHasher,
     hash::{Hash, Hasher},
     io,
     marker::PhantomData,
-    net::{IpAddr, Ipv6Addr, Shutdown, SocketAddr},
+    net::{IpAddr, Shutdown, SocketAddr},
     path::PathBuf,
     sync::Arc,
 };
@@ -51,22 +53,10 @@ mod private {
 
     use super::*;
 
-    /// Returns a [Receiver] which can be used to receive messages addressed to the given path.
-    /// The `fs_path` argument specifies the filesystem directory under which the receiver's socket
-    /// will be stored.
-    pub fn local_receiver<T: for<'de> Deserialize<'de> + core::marker::Send + 'static>(
-        addr: BlockAddr,
-        creds: &ConcreteCreds,
-    ) -> Result<impl Receiver<T>> {
-        QuicReceiver::new(addr, creds)
-    }
-
-    /// Returns a [Sender] which can be used to send messages to the given Blocktree path.
-    /// The `fs_path` argument specifies the filesystem directory in which to locate the
-    /// socket of the recipient.
-    pub async fn local_sender(addr: BlockAddr) -> Result<impl Sender> {
-        let result = QuicSender::new(addr).await;
-        result
+    /// Returns a [Router] which can be used to make a [Receiver] for the given path and
+    ///  [Sender] instances for any path.
+    pub fn router(addr: Arc<BlockAddr>, creds: &ConcreteCreds) -> Result<impl Router> {
+        QuicRouter::new(addr, creds)
     }
 
     lazy_static! {
@@ -81,7 +71,7 @@ mod private {
     /// Appends the given Blocktree path to the path of the given directory.
     #[allow(dead_code)]
     fn socket_path(fs_path: &mut PathBuf, addr: &BlockAddr) {
-        fs_path.push(addr.num.value().to_string());
+        fs_path.push(addr.path.to_string());
     }
 
     fn common_config<Side: ConfigSide>(
@@ -139,30 +129,31 @@ mod private {
         }
     }
 
-    #[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Debug)]
+    #[derive(PartialEq, Eq, Hash, Clone, Debug)]
     pub struct BlockAddr {
-        /// The root principal of the blocktree this block is part of.
-        pub root: Principal,
-        /// The cluster ID this block is served by.
-        pub cluster: u64,
-        /// The number of this block.
-        pub num: BlockNum,
+        ip_addr: IpAddr,
+        path: Arc<BlockPath>,
     }
 
     impl BlockAddr {
-        pub fn new(root: Principal, cluster: u64, num: BlockNum) -> Self {
-            Self { root, cluster, num }
+        pub fn new(ip_addr: IpAddr, path: Arc<BlockPath>) -> Self {
+            Self { ip_addr, path }
         }
 
-        pub fn port(&self) -> u16 {
+        fn port(&self) -> u16 {
             let mut hasher = DefaultHasher::new();
-            self.hash(&mut hasher);
+            self.path.hash(&mut hasher);
             let hash = hasher.finish();
             // We compute a port in the dynamic range [49152, 65535] as defined by RFC 6335.
             const NUM_RES_PORTS: u16 = 49153;
             const PORTS_AVAIL: u64 = (u16::MAX - NUM_RES_PORTS) as u64;
             NUM_RES_PORTS + (hash % PORTS_AVAIL) as u16
         }
+
+        pub fn socket_addr(&self) -> SocketAddr {
+            let port = self.port();
+            SocketAddr::new(self.ip_addr, port)
+        }
     }
 
     /// Generates and returns a new message ID.
@@ -176,21 +167,17 @@ mod private {
 
     #[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
     pub struct Msg<T> {
-        pub to: BlockAddr,
-        pub from: BlockAddr,
         pub id: u128,
         pub body: T,
     }
 
     impl<T> Msg<T> {
-        pub fn new(to: BlockAddr, from: BlockAddr, id: u128, body: T) -> Self {
-            Self { to, from, id, body }
+        pub fn new(id: u128, body: T) -> Self {
+            Self { id, body }
         }
 
-        pub fn with_rand_id(to: BlockAddr, from: BlockAddr, body: T) -> Result<Self> {
+        pub fn with_rand_id(body: T) -> Result<Self> {
             Ok(Self {
-                to,
-                from,
                 id: rand_msg_id()?,
                 body,
             })
@@ -201,28 +188,24 @@ mod private {
     /// Once the "Permit impl Trait in type aliases" https://github.com/rust-lang/rust/issues/63063
     /// feature lands the future types in this trait should be rewritten to use it.
     pub trait Sender {
-        type SendFut<'a, T>: 'a + Future<Output = Result<()>> + core::marker::Send
+        type SendFut<'a, T>: 'a + Future<Output = Result<()>> + Send
         where
             Self: 'a,
-            T: 'a + Serialize + core::marker::Send;
+            T: 'a + Serialize + Send;
 
-        fn send<'a, T: 'a + Serialize + core::marker::Send>(
-            &'a mut self,
-            msg: Msg<T>,
-        ) -> Self::SendFut<'a, T>;
+        fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: Msg<T>) -> Self::SendFut<'a, T>;
 
-        type FinishFut: Future<Output = Result<()>> + core::marker::Send;
+        type FinishFut: Future<Output = Result<()>> + Send;
 
         fn finish(self) -> Self::FinishFut;
 
         fn addr(&self) -> &BlockAddr;
 
-        fn send_msg<'a, T: 'a + Serialize + core::marker::Send>(
+        fn send_with_rand_id<'a, T: 'a + Serialize + Send>(
             &'a mut self,
-            from: BlockAddr,
             body: T,
         ) -> Self::SendFut<'a, T> {
-            let msg = Msg::with_rand_id(self.addr().clone(), from, body).unwrap();
+            let msg = Msg::with_rand_id(body).unwrap();
             self.send(msg)
         }
     }
@@ -232,6 +215,25 @@ mod private {
         fn addr(&self) -> &BlockAddr;
     }
 
+    pub trait Router {
+        type Sender: Sender + Send;
+        type SenderFut<'a>: 'a + Future<Output = Result<Self::Sender>> + Send
+        where
+            Self: 'a;
+
+        fn sender(&self, addr: Arc<BlockAddr>) -> Self::SenderFut<'_>;
+
+        type Receiver<T: 'static + DeserializeOwned + Send>: Receiver<T> + Send + Unpin;
+        type ReceiverFut<'a, T>: 'a + Future<Output = Result<Self::Receiver<T>>> + Send
+        where
+            T: 'static + DeserializeOwned + Send + Unpin,
+            Self: 'a;
+
+        fn receiver<T: 'static + DeserializeOwned + Send + Unpin>(
+            &self,
+        ) -> Self::ReceiverFut<'_, T>;
+    }
+
     /// Encodes and decodes messages using [btserde].
     struct MsgEncoder;
 
@@ -267,7 +269,7 @@ mod private {
         }
     }
 
-    impl<T: for<'de> Deserialize<'de>> Decoder for MsgDecoder<T> {
+    impl<T: DeserializeOwned> Decoder for MsgDecoder<T> {
         type Item = Msg<T>;
         type Error = btlib::Error;
 
@@ -362,7 +364,7 @@ mod private {
         socket: FramedRead<DatagramAdapter, MsgDecoder<T>>,
     }
 
-    impl<T: for<'de> Deserialize<'de>> UnixReceiver<T> {
+    impl<T: DeserializeOwned> UnixReceiver<T> {
         #[allow(dead_code)]
         fn new(mut fs_path: PathBuf, addr: BlockAddr) -> Result<Self> {
             socket_path(&mut fs_path, &addr);
@@ -372,7 +374,7 @@ mod private {
         }
     }
 
-    impl<T: for<'de> Deserialize<'de>> Stream for UnixReceiver<T> {
+    impl<T: DeserializeOwned> Stream for UnixReceiver<T> {
         type Item = Result<Msg<T>>;
 
         fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
@@ -380,7 +382,7 @@ mod private {
         }
     }
 
-    impl<T: for<'de> Deserialize<'de>> Receiver<T> for UnixReceiver<T> {
+    impl<T: DeserializeOwned> Receiver<T> for UnixReceiver<T> {
         fn addr(&self) -> &BlockAddr {
             &self.addr
         }
@@ -435,17 +437,14 @@ mod private {
         }
 
         type SendFut<'a, T>
-            = Send<'a, FramedWrite<DatagramAdapter, MsgEncoder>, Msg<T>>
-                where T: 'a + Serialize + core::marker::Send;
+            = SendFut<'a, FramedWrite<DatagramAdapter, MsgEncoder>, Msg<T>>
+                where T: 'a + Serialize + Send;
 
-        fn send<'a, T: 'a + Serialize + core::marker::Send>(
-            &'a mut self,
-            msg: Msg<T>,
-        ) -> Self::SendFut<'a, T> {
+        fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: Msg<T>) -> Self::SendFut<'a, T> {
             self.socket.send(msg)
         }
 
-        type FinishFut = Pin<Box<dyn Future<Output = Result<()>> + core::marker::Send>>;
+        type FinishFut = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
 
         fn finish(mut self) -> Self::FinishFut {
             Box::pin(async move {
@@ -469,21 +468,55 @@ mod private {
         };
     }
 
+    struct QuicRouter {
+        recv_addr: Arc<BlockAddr>,
+        endpoint: Endpoint,
+    }
+
+    impl QuicRouter {
+        fn new(recv_addr: Arc<BlockAddr>, creds: &ConcreteCreds) -> Result<Self> {
+            let socket_addr = recv_addr.socket_addr();
+            let endpoint = Endpoint::server(server_config(creds)?, socket_addr)?;
+            Ok(Self {
+                endpoint,
+                recv_addr,
+            })
+        }
+    }
+
+    impl Router for QuicRouter {
+        type Receiver<T: 'static + DeserializeOwned + Send> = QuicReceiver<T>;
+        type ReceiverFut<'a, T: 'static + DeserializeOwned + Send + Unpin> =
+            Ready<Result<QuicReceiver<T>>>;
+
+        fn receiver<T: 'static + DeserializeOwned + Send + Unpin>(
+            &self,
+        ) -> Self::ReceiverFut<'_, T> {
+            ready(QuicReceiver::new(
+                self.endpoint.clone(),
+                self.recv_addr.clone(),
+            ))
+        }
+
+        type Sender = QuicSender;
+        type SenderFut<'a> = Pin<Box<dyn 'a + Future<Output = Result<QuicSender>> + Send>>;
+
+        fn sender(&self, addr: Arc<BlockAddr>) -> Self::SenderFut<'_> {
+            Box::pin(async { QuicSender::from_endpoint(self.endpoint.clone(), addr).await })
+        }
+    }
+
     struct QuicReceiver<T> {
-        addr: BlockAddr,
+        recv_addr: Arc<BlockAddr>,
         stop_tx: broadcast::Sender<()>,
         stream: ReceiverStream<Result<Msg<T>>>,
     }
 
-    impl<T: for<'de> Deserialize<'de> + core::marker::Send + 'static> QuicReceiver<T> {
+    impl<T: DeserializeOwned + Send + 'static> QuicReceiver<T> {
         /// The size of the buffer to store received messages in.
         const MSG_BUF_SZ: usize = 64;
 
-        fn new(addr: BlockAddr, creds: &ConcreteCreds) -> Result<Self> {
-            let config = server_config(creds)?;
-            let port = addr.port();
-            let socket_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
-            let endpoint = Endpoint::server(config, socket_addr)?;
+        fn new(endpoint: Endpoint, recv_addr: Arc<BlockAddr>) -> Result<Self> {
             let (stop_tx, mut stop_rx) = broadcast::channel(1);
             let (msg_tx, msg_rx) = mpsc::channel(Self::MSG_BUF_SZ);
             tokio::spawn(async move {
@@ -525,7 +558,7 @@ mod private {
                 }
             });
             Ok(Self {
-                addr,
+                recv_addr,
                 stop_tx,
                 stream: ReceiverStream::new(msg_rx),
             })
@@ -540,7 +573,7 @@ mod private {
         }
     }
 
-    impl<T: for<'de> Deserialize<'de> + core::marker::Send + 'static> Stream for QuicReceiver<T> {
+    impl<T: DeserializeOwned + Send + 'static> Stream for QuicReceiver<T> {
         type Item = Result<Msg<T>>;
 
         fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
@@ -548,9 +581,9 @@ mod private {
         }
     }
 
-    impl<T: for<'de> Deserialize<'de> + core::marker::Send + 'static> Receiver<T> for QuicReceiver<T> {
+    impl<T: DeserializeOwned + Send + 'static> Receiver<T> for QuicReceiver<T> {
         fn addr(&self) -> &BlockAddr {
-            &self.addr
+            &self.recv_addr
         }
     }
 
@@ -578,18 +611,14 @@ mod private {
     }
 
     struct QuicSender {
-        addr: BlockAddr,
+        addr: Arc<BlockAddr>,
         sink: FramedWrite<SendStream, MsgEncoder>,
     }
 
     impl QuicSender {
-        async fn new(addr: BlockAddr) -> Result<Self> {
-            let mut endpoint =
-                Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))?;
-            endpoint.set_default_client_config(client_config()?);
-            let port = addr.port();
-            let socket_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), port);
-            let connecting = endpoint.connect(socket_addr, "localhost")?;
+        async fn from_endpoint(endpoint: Endpoint, addr: Arc<BlockAddr>) -> Result<Self> {
+            let socket_addr = addr.socket_addr();
+            let connecting = endpoint.connect_with(client_config()?, socket_addr, "localhost")?;
             let connection = connecting.await?;
             let send_stream = connection.open_uni().await?;
             let sink = FramedWrite::new(send_stream, MsgEncoder::new());
@@ -602,17 +631,14 @@ mod private {
             &self.addr
         }
 
-        type SendFut<'a, T> = Send<'a, FramedWrite<SendStream, MsgEncoder>, Msg<T>>
-            where T: 'a + Serialize + core::marker::Send;
+        type SendFut<'a, T> = SendFut<'a, FramedWrite<SendStream, MsgEncoder>, Msg<T>>
+            where T: 'a + Serialize + Send;
 
-        fn send<'a, T: 'a + Serialize + core::marker::Send>(
-            &'a mut self,
-            msg: Msg<T>,
-        ) -> Self::SendFut<'a, T> {
+        fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: Msg<T>) -> Self::SendFut<'a, T> {
             self.sink.send(msg)
         }
 
-        type FinishFut = Pin<Box<dyn Future<Output = Result<()>> + core::marker::Send>>;
+        type FinishFut = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
 
         fn finish(mut self) -> Self::FinishFut {
             Box::pin(async move {
@@ -628,8 +654,8 @@ mod private {
     /// lifetimes. Once this issue is resolved this can be removed:
     /// https://github.com/rust-lang/rust/issues/102211
     pub fn assert_send<'a, T>(
-        fut: impl 'a + Future<Output = T> + core::marker::Send,
-    ) -> impl 'a + Future<Output = T> + core::marker::Send {
+        fut: impl 'a + Future<Output = T> + Send,
+    ) -> impl 'a + Future<Output = T> + Send {
         fut
     }
 }
@@ -638,9 +664,10 @@ mod private {
 mod tests {
     use super::*;
 
-    use btlib::{crypto::Creds, Epoch, Principaled};
+    use btlib::{crypto::Creds, Epoch, Principal, Principaled};
     use ctor::ctor;
     use std::{
+        net::Ipv6Addr,
         sync::atomic::{AtomicU64, Ordering},
         time::Duration,
     };
@@ -668,8 +695,11 @@ mod tests {
         static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
     }
 
-    fn block_addr(generation: u64, inode: u64) -> BlockAddr {
-        BlockAddr::new(ROOT_PRINCIPAL.clone(), generation, BlockNum::Inode(inode))
+    fn block_addr<'a, I: Iterator<Item = S>, S: ToString>(components: I) -> BlockAddr {
+        let components = components.map(|e| e.to_string()).collect();
+        let path = BlockPath::new(ROOT_CREDS.principal(), components);
+        let path = Arc::new(path);
+        BlockAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), path)
     }
 
     #[derive(Serialize, Deserialize)]
@@ -706,20 +736,21 @@ mod tests {
             Self { instance_num }
         }
 
-        async fn endpoint(&self, inode: u64) -> (BlockAddr, impl Sender, impl Receiver<BodyOwned>) {
-            let addr = block_addr(self.instance_num, inode);
-            let receiver = local_receiver(addr.clone(), &NODE_CREDS).unwrap();
-            let sender = local_sender(addr.clone()).await.unwrap();
-            (addr, sender, receiver)
+        async fn endpoint(&self, inode: u64) -> (impl Sender, impl Receiver<BodyOwned>) {
+            let addr = Arc::new(block_addr([self.instance_num, inode].iter()));
+            let router = router(addr.clone(), &NODE_CREDS).unwrap();
+            let receiver = router.receiver::<BodyOwned>().await.unwrap();
+            let sender = router.sender(addr).await.unwrap();
+            (sender, receiver)
         }
     }
 
     #[tokio::test]
     async fn message_received_is_message_sent() {
         let case = TestCase::new();
-        let (addr, mut sender, mut receiver) = case.endpoint(1).await;
+        let (mut sender, mut receiver) = case.endpoint(1).await;
 
-        sender.send_msg(addr.clone(), BodyRef::Ping).await.unwrap();
+        sender.send_with_rand_id(BodyRef::Ping).await.unwrap();
         let actual = receiver.next().await.unwrap().unwrap();
 
         let matched = if let BodyOwned::Ping = actual.body {
@@ -728,15 +759,13 @@ mod tests {
             false
         };
         assert!(matched);
-        assert_eq!(&addr, &actual.to);
-        assert_eq!(&addr, &actual.from);
     }
 
     #[tokio::test]
     async fn ping_pong() {
         let case = TestCase::new();
-        let (addr_one, mut sender_one, mut receiver_one) = case.endpoint(1).await;
-        let (addr_two, mut sender_two, mut receiver_two) = case.endpoint(2).await;
+        let (mut sender_one, mut receiver_one) = case.endpoint(1).await;
+        let (mut sender_two, mut receiver_two) = case.endpoint(2).await;
 
         tokio::spawn(async move {
             let msg = receiver_one.next().await.unwrap().unwrap();
@@ -745,12 +774,12 @@ mod tests {
             } else {
                 BodyRef::Fail(MsgError::Unknown)
             };
-            let fut = assert_send::<'_, Result<()>>(sender_two.send_msg(addr_one, reply_body));
+            let fut = assert_send::<'_, Result<()>>(sender_two.send_with_rand_id(reply_body));
             fut.await.unwrap();
             sender_two.finish().await.unwrap();
         });
 
-        sender_one.send_msg(addr_two, BodyRef::Ping).await.unwrap();
+        sender_one.send_with_rand_id(BodyRef::Ping).await.unwrap();
         let reply = receiver_two.next().await.unwrap().unwrap();
         let matched = if let BodyOwned::Success = reply.body {
             true
@@ -763,8 +792,8 @@ mod tests {
     #[tokio::test]
     async fn read_write() {
         let case = TestCase::new();
-        let (addr_one, mut sender_one, mut receiver_one) = case.endpoint(1).await;
-        let (addr_two, mut sender_two, mut receiver_two) = case.endpoint(2).await;
+        let (mut sender_one, mut receiver_one) = case.endpoint(1).await;
+        let (mut sender_two, mut receiver_two) = case.endpoint(2).await;
 
         let handle = tokio::spawn(async move {
             let data: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
@@ -780,14 +809,14 @@ mod tests {
             } else {
                 BodyRef::Fail(MsgError::Unknown)
             };
-            let msg = Msg::new(msg.from, addr_one, msg.id, reply_body);
+            let msg = Msg::new(msg.id, reply_body);
             let fut = assert_send::<'_, Result<()>>(sender_two.send(msg));
             fut.await.unwrap();
             sender_two.finish().await.unwrap();
         });
 
         sender_one
-            .send_msg(addr_two, BodyRef::Read { offset: 2, size: 2 })
+            .send_with_rand_id(BodyRef::Read { offset: 2, size: 2 })
             .await
             .unwrap();
         handle.await.unwrap();