瀏覽代碼

Started working on supporting remote messages in btrun.

Matthew Carr 1 年之前
父節點
當前提交
dca9e637b1
共有 5 個文件被更改,包括 134 次插入75 次删除
  1. 1 0
      Cargo.lock
  2. 3 3
      crates/btlib/src/crypto.rs
  3. 5 0
      crates/btlib/src/log.rs
  4. 1 0
      crates/btrun/Cargo.toml
  5. 124 72
      crates/btrun/src/lib.rs

+ 1 - 0
Cargo.lock

@@ -458,6 +458,7 @@ dependencies = [
  "btmsg",
  "btserde",
  "bytes",
+ "env_logger",
  "futures",
  "serde",
  "strum",

+ 3 - 3
crates/btlib/src/crypto.rs

@@ -2199,15 +2199,15 @@ pub trait Verifier {
 
 impl<V: ?Sized + Verifier> Verifier for &V {
     fn init_verify<'a>(&'a self) -> Result<Box<dyn 'a + VerifyOp>> {
-        self.deref().init_verify()
+        (*self).init_verify()
     }
 
     fn verify(&self, parts: &mut dyn Iterator<Item = &[u8]>, signature: &[u8]) -> Result<()> {
-        self.deref().verify(parts, signature)
+        (*self).verify(parts, signature)
     }
 
     fn kind(&self) -> Sign {
-        self.deref().kind()
+        (*self).kind()
     }
 }
 

+ 5 - 0
crates/btlib/src/log.rs

@@ -28,3 +28,8 @@ impl BuilderExt for env_logger::Builder {
         })
     }
 }
+
+/// Initializes [env_logger] using the default environment and `btformat`.
+pub fn init() {
+    env_logger::Builder::from_default_env().btformat().init();
+}

+ 1 - 0
crates/btrun/Cargo.toml

@@ -19,3 +19,4 @@ strum = { version = "^0.24.0", features = ["derive"] }
 
 [dev-dependencies]
 btlib-tests = { path = "../btlib-tests" }
+env_logger = { version = "0.9.0" }

+ 124 - 72
crates/btrun/src/lib.rs

@@ -7,7 +7,7 @@ use btmsg::{MsgCallback, Receiver, Replier, Transmitter};
 use btserde::field_helpers::smart_ptr;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use tokio::{
-    sync::{mpsc, oneshot, RwLock},
+    sync::{mpsc, oneshot, Mutex, RwLock},
     task::JoinHandle,
 };
 use uuid::Uuid;
@@ -87,12 +87,18 @@ impl<Rx: Receiver> Runtime<Rx> {
     }
 
     /// Activates a new actor using the given activator function and returns a handle to it.
-    pub async fn activate<Msg, F, Fut, G>(&self, activator: F, deserializer: G) -> ActorName
+    pub async fn activate<Msg, F, Fut, G, H>(
+        &self,
+        activator: F,
+        deserializer: G,
+        serializer: H,
+    ) -> ActorName
     where
         Msg: 'static + CallMsg,
         Fut: 'static + Send + Future<Output = ()>,
         F: FnOnce(mpsc::Receiver<Envelope<Msg>>, Uuid) -> Fut,
         G: 'static + Send + Sync + Fn(&[u8]) -> Result<Msg>,
+        H: 'static + Send + Sync + Fn(&Msg::Reply, &mut Vec<u8>) -> Result<()>,
     {
         let mut guard = self.handles.write().await;
         let act_id = {
@@ -104,29 +110,40 @@ impl<Rx: Receiver> Runtime<Rx> {
         };
         let (tx, rx) = mpsc::channel::<Envelope<Msg>>(MAILBOX_LIMIT);
         // The deliverer closure is responsible for deserializing messages received over the wire
-        // and delivering them to the actor's mailbox. It's also responsible for sending replies to
-        // call messages.
+        // and delivering them to the actor's mailbox and sending replies to call messages.
         let deliverer = {
+            let buffer = Arc::new(Mutex::new(Vec::<u8>::new()));
+            let serializer = Arc::new(serializer);
             let tx = tx.clone();
             move |envelope: WireEnvelope| {
-                let (msg, replier) = envelope.into_parts();
-                let result = deserializer(msg);
-                let tx_clone = tx.clone();
+                let (wire_msg, replier) = envelope.into_parts();
+                let result = deserializer(wire_msg.payload);
+                let buffer = buffer.clone();
+                let serializer = serializer.clone();
+                let tx = tx.clone();
                 let fut: FutureResult = Box::pin(async move {
                     let msg = result?;
                     if let Some(mut replier) = replier {
                         let (envelope, rx) = Envelope::call(msg);
-                        tx_clone.send(envelope).await.map_err(|_| {
+                        tx.send(envelope).await.map_err(|_| {
                             bterr!("failed to deliver message. Recipient may have halted.")
                         })?;
-                        // TODO: `reply` does not have the right type.
-                        // It needs to be WireEnvelope::Reply.
                         match rx.await {
-                            Ok(reply) => replier.reply(reply).await,
+                            Ok(reply) => {
+                                let mut guard = buffer.lock().await;
+                                guard.clear();
+                                serializer(&reply, &mut guard)?;
+                                let wire_reply = WireMsg {
+                                    to: wire_msg.from,
+                                    from: act_id,
+                                    payload: &guard,
+                                };
+                                replier.reply(wire_reply).await
+                            }
                             Err(err) => replier.reply_err(err.to_string(), None).await,
                         }
                     } else {
-                        tx_clone.send(Envelope::Send { msg }).await.map_err(|_| {
+                        tx.send(Envelope::Send { msg }).await.map_err(|_| {
                             bterr!("failed to deliver message. Recipient may have halted.")
                         })
                     }
@@ -157,6 +174,39 @@ impl<Rx: Receiver> Runtime<Rx> {
     }
 }
 
+/// This struct implements the server callback for network messages.
+#[derive(Clone)]
+struct RuntimeCallback {
+    handles: Arc<RwLock<HashMap<Uuid, ActorHandle>>>,
+}
+
+impl RuntimeCallback {
+    fn new(handles: Arc<RwLock<HashMap<Uuid, ActorHandle>>>) -> Self {
+        Self { handles }
+    }
+}
+
+impl MsgCallback for RuntimeCallback {
+    type Arg<'de> = WireMsg<'de>;
+    type CallFut<'de> = impl 'de + Future<Output = Result<()>>;
+    fn call<'de>(&'de self, arg: btmsg::MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
+        async move {
+            let (_, body, replier) = arg.into_parts();
+            let guard = self.handles.read().await;
+            if let Some(handle) = guard.get(&body.to) {
+                let envelope = if let Some(replier) = replier {
+                    WireEnvelope::Call { msg: body, replier }
+                } else {
+                    WireEnvelope::Send { msg: body }
+                };
+                (handle.deliverer)(envelope).await
+            } else {
+                Err(bterr!("invalid actor ID: {}", body.to))
+            }
+        }
+    }
+}
+
 /// A unique identifier for a particular service.
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
 pub struct ServiceId(#[serde(with = "smart_ptr")] Arc<String>);
@@ -224,17 +274,41 @@ impl<'a, T: CallMsg> btmsg::CallMsg<'a> for Adapter<T> {
 
 impl<'a, T: SendMsg> btmsg::SendMsg<'a> for Adapter<T> {}
 
+impl<T: CallMsg> From<T> for Adapter<T> {
+    fn from(value: T) -> Self {
+        Self(value)
+    }
+}
+
 /// The maximum number of messages which can be kept in an actor's mailbox.
 const MAILBOX_LIMIT: usize = 32;
 
-/// The type of messages sent remotely.
+/// The type of messages sent over the wire between runtime.
+#[derive(Serialize, Deserialize)]
+struct WireMsg<'a> {
+    to: Uuid,
+    from: Uuid,
+    payload: &'a [u8],
+}
+
+impl<'a> btmsg::CallMsg<'a> for WireMsg<'a> {
+    type Reply<'r> = WireReply<'r>;
+}
+
+#[derive(Serialize, Deserialize)]
+enum WireReply<'a> {
+    Ok(&'a [u8]),
+    Err(&'a str),
+}
+
+/// A wrapper around [WireMsg] which indicates whether a call or send was executed.
 enum WireEnvelope<'de> {
-    Send { msg: &'de [u8] },
-    Call { msg: &'de [u8], replier: Replier },
+    Send { msg: WireMsg<'de> },
+    Call { msg: WireMsg<'de>, replier: Replier },
 }
 
 impl<'de> WireEnvelope<'de> {
-    fn into_parts(self) -> (&'de [u8], Option<Replier>) {
+    fn into_parts(self) -> (WireMsg<'de>, Option<Replier>) {
         match self {
             Self::Send { msg } => (msg, None),
             Self::Call { msg, replier } => (msg, Some(replier)),
@@ -317,68 +391,18 @@ impl ActorHandle {
     }
 }
 
-#[derive(Serialize, Deserialize)]
-enum WireReply<'a> {
-    Ok(&'a [u8]),
-    Err(&'a str),
-}
-
-#[derive(Serialize, Deserialize)]
-struct WireMsg<'a> {
-    act_id: Uuid,
-    payload: &'a [u8],
-}
-
-impl<'a> btmsg::CallMsg<'a> for WireMsg<'a> {
-    type Reply<'r> = WireReply<'r>;
-}
-
-/// This struct implements the server callback for network messages.
-#[derive(Clone)]
-struct RuntimeCallback {
-    handles: Arc<RwLock<HashMap<Uuid, ActorHandle>>>,
-}
-
-impl RuntimeCallback {
-    fn new(handles: Arc<RwLock<HashMap<Uuid, ActorHandle>>>) -> Self {
-        Self { handles }
-    }
-}
-
-impl MsgCallback for RuntimeCallback {
-    type Arg<'de> = WireMsg<'de>;
-    type CallFut<'de> = impl 'de + Future<Output = Result<()>>;
-    fn call<'de>(&'de self, mut arg: btmsg::MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
-        async move {
-            let replier = arg.take_replier();
-            let body = arg.body();
-            let guard = self.handles.read().await;
-            if let Some(handle) = guard.get(&body.act_id) {
-                let envelope = if let Some(replier) = replier {
-                    WireEnvelope::Call {
-                        msg: body.payload,
-                        replier,
-                    }
-                } else {
-                    WireEnvelope::Send { msg: body.payload }
-                };
-                (handle.deliverer)(envelope).await
-            } else {
-                Err(bterr!("invalid actor ID: {}", body.act_id))
-            }
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use std::net::IpAddr;
 
     use super::*;
 
-    use btlib::crypto::CredStore;
+    use btlib::{
+        crypto::{CredStore, CredsPriv},
+    };
     use btlib_tests::TEST_STORE;
-    use btserde::from_slice;
+    use btmsg::BlockAddr;
+    use btserde::{from_slice, write_to};
 
     #[derive(Serialize, Deserialize)]
     struct EchoMsg(String);
@@ -401,14 +425,42 @@ mod tests {
         from_slice(slice).map_err(|err| err.into())
     }
 
+    fn echo_serializer(msg: &EchoMsg, buf: &mut Vec<u8>) -> Result<()> {
+        write_to(msg, buf).map_err(|err| err.into())
+    }
+
     #[tokio::test]
     async fn local_call() {
         const EXPECTED: &str = "hello";
         let ip_addr = IpAddr::from([127, 0, 0, 1]);
         let creds = TEST_STORE.node_creds().unwrap();
         let runtime = new_runtime(ip_addr, creds).unwrap();
-        let name = runtime.activate(echo, echo_deserializer).await;
+        let name = runtime
+            .activate(echo, echo_deserializer, echo_serializer)
+            .await;
         let reply = runtime.call(&name, EchoMsg(EXPECTED.into())).await.unwrap();
         assert_eq!(EXPECTED, reply.0)
     }
+
+    //#[tokio::test]
+    #[allow(dead_code)]
+    async fn remote_call() {
+        btlib::log::init();
+
+        const EXPECTED: &str = "hello";
+        let ip_addr = IpAddr::from([127, 0, 0, 1]);
+        let creds = TEST_STORE.node_creds().unwrap();
+        let runtime = new_runtime(ip_addr, creds.clone()).unwrap();
+        runtime
+            .activate(echo, echo_deserializer, echo_serializer)
+            .await;
+        let bind_path = Arc::new(creds.bind_path().unwrap());
+        let block_addr = Arc::new(BlockAddr::new(ip_addr, bind_path));
+        let transmitter = btmsg::transmitter(block_addr, creds).await.unwrap();
+        let reply = transmitter
+            .call_through(Adapter(EchoMsg(EXPECTED.into())))
+            .await
+            .unwrap();
+        assert_eq!(EXPECTED, reply.0);
+    }
 }