Эх сурвалжийг харах

Resolved the test failure for remote actor calls.

Matthew Carr 2 жил өмнө
parent
commit
e520c7c2b0

+ 1 - 0
Cargo.lock

@@ -458,6 +458,7 @@ dependencies = [
  "btmsg",
  "btmsg",
  "btserde",
  "btserde",
  "bytes",
  "bytes",
+ "ctor",
  "env_logger",
  "env_logger",
  "futures",
  "futures",
  "serde",
  "serde",

+ 1 - 0
crates/btrun/Cargo.toml

@@ -20,3 +20,4 @@ strum = { version = "^0.24.0", features = ["derive"] }
 [dev-dependencies]
 [dev-dependencies]
 btlib-tests = { path = "../btlib-tests" }
 btlib-tests = { path = "../btlib-tests" }
 env_logger = { version = "0.9.0" }
 env_logger = { version = "0.9.0" }
+ctor = { version = "0.1.22" }

+ 134 - 87
crates/btrun/src/lib.rs

@@ -1,10 +1,19 @@
 #![feature(impl_trait_in_assoc_type)]
 #![feature(impl_trait_in_assoc_type)]
 
 
-use std::{any::Any, collections::HashMap, future::Future, net::IpAddr, pin::Pin, sync::Arc};
+use std::{
+    any::Any,
+    collections::HashMap,
+    future::{ready, Future, Ready},
+    marker::PhantomData,
+    net::IpAddr,
+    ops::DerefMut,
+    pin::Pin,
+    sync::Arc,
+};
 
 
-use btlib::{bterr, crypto::Creds, BlockPath, Result};
-use btmsg::{MsgCallback, Receiver, Replier, Transmitter};
-use btserde::field_helpers::smart_ptr;
+use btlib::{bterr, crypto::Creds, error::StringError, BlockPath, Result};
+use btmsg::{DeserCallback, MsgCallback, Receiver, Replier, Transmitter};
+use btserde::{field_helpers::smart_ptr, from_slice, to_vec, write_to};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use tokio::{
 use tokio::{
     sync::{mpsc, oneshot, Mutex, RwLock},
     sync::{mpsc, oneshot, Mutex, RwLock},
@@ -43,42 +52,74 @@ pub struct Runtime<Rx: Receiver> {
     peers: RwLock<HashMap<Arc<BlockPath>, Rx::Transmitter>>,
     peers: RwLock<HashMap<Arc<BlockPath>, Rx::Transmitter>>,
 }
 }
 
 
-macro_rules! deliver {
-    ($self:expr, $to:expr, $msg:expr, $method:ident) => {
-        if $to.path == $self.path {
-            let guard = $self.handles.read().await;
-            if let Some(handle) = guard.get(&$to.act_id) {
-                handle.$method($msg).await
+impl<Rx: Receiver> Runtime<Rx> {
+    pub fn path(&self) -> &Arc<BlockPath> {
+        &self.path
+    }
+
+    /// Sends a message to the actor identified by the given [ActorName].
+    pub async fn send<T: 'static + SendMsg>(
+        &self,
+        to: &ActorName,
+        from: Uuid,
+        msg: T,
+    ) -> Result<()> {
+        if to.path == self.path {
+            let guard = self.handles.read().await;
+            if let Some(handle) = guard.get(&to.act_id) {
+                handle.send(msg).await
             } else {
             } else {
                 Err(bterr!("invalid actor name"))
                 Err(bterr!("invalid actor name"))
             }
             }
         } else {
         } else {
-            let guard = $self.peers.read().await;
-            if let Some(peer) = guard.get(&$to.path) {
-                peer.$method(Adapter($msg)).await
+            let guard = self.peers.read().await;
+            if let Some(peer) = guard.get(&to.path) {
+                let buf = to_vec(&msg)?;
+                let wire_msg = WireMsg {
+                    to: to.act_id,
+                    from,
+                    payload: &buf,
+                };
+                peer.send(wire_msg).await
             } else {
             } else {
                 // TODO: Use the filesystem to discover the address of the recipient and connect to
                 // TODO: Use the filesystem to discover the address of the recipient and connect to
                 // it.
                 // it.
                 todo!()
                 todo!()
             }
             }
         }
         }
-    };
-}
-
-impl<Rx: Receiver> Runtime<Rx> {
-    pub fn path(&self) -> &Arc<BlockPath> {
-        &self.path
-    }
-
-    /// Sends a message to the actor identified by the given [ActorName].
-    pub async fn send<T: 'static + SendMsg>(&self, to: &ActorName, msg: T) -> Result<()> {
-        deliver!(self, to, msg, send)
     }
     }
 
 
     /// Sends a message to the actor identified by the given [ActorName] and returns a future which
     /// Sends a message to the actor identified by the given [ActorName] and returns a future which
-    /// is ready when the reply has been received.
-    pub async fn call<T: 'static + CallMsg>(&self, to: &ActorName, msg: T) -> Result<T::Reply> {
-        deliver!(self, to, msg, call_through)
+    /// is ready when a reply has been received.
+    pub async fn call<T: 'static + CallMsg>(
+        &self,
+        to: &ActorName,
+        from: Uuid,
+        msg: T,
+    ) -> Result<T::Reply> {
+        if to.path == self.path {
+            let guard = self.handles.read().await;
+            if let Some(handle) = guard.get(&to.act_id) {
+                handle.call_through(msg).await
+            } else {
+                Err(bterr!("invalid actor name"))
+            }
+        } else {
+            let guard = self.peers.read().await;
+            if let Some(peer) = guard.get(&to.path) {
+                let buf = to_vec(&msg)?;
+                let wire_msg = WireMsg {
+                    to: to.act_id,
+                    from,
+                    payload: &buf,
+                };
+                peer.call(wire_msg, ReplyCallback::<T>::new()).await?
+            } else {
+                // TODO: Use the filesystem to discover the address of the recipient and connect to
+                // it.
+                todo!()
+            }
+        }
     }
     }
 
 
     /// Resolves the given [ServiceName] to an [ActorName] which is part of it.
     /// Resolves the given [ServiceName] to an [ActorName] which is part of it.
@@ -87,18 +128,11 @@ impl<Rx: Receiver> Runtime<Rx> {
     }
     }
 
 
     /// Activates a new actor using the given activator function and returns a handle to it.
     /// Activates a new actor using the given activator function and returns a handle to it.
-    pub async fn activate<Msg, F, Fut, G, H>(
-        &self,
-        activator: F,
-        deserializer: G,
-        serializer: H,
-    ) -> ActorName
+    pub async fn activate<Msg, F, Fut>(&self, activator: F) -> ActorName
     where
     where
         Msg: 'static + CallMsg,
         Msg: 'static + CallMsg,
         Fut: 'static + Send + Future<Output = ()>,
         Fut: 'static + Send + Future<Output = ()>,
         F: FnOnce(mpsc::Receiver<Envelope<Msg>>, Uuid) -> Fut,
         F: FnOnce(mpsc::Receiver<Envelope<Msg>>, Uuid) -> Fut,
-        G: 'static + Send + Sync + Fn(&[u8]) -> Result<Msg>,
-        H: 'static + Send + Sync + Fn(&Msg::Reply, &mut Vec<u8>) -> Result<()>,
     {
     {
         let mut guard = self.handles.write().await;
         let mut guard = self.handles.write().await;
         let act_id = {
         let act_id = {
@@ -113,13 +147,11 @@ impl<Rx: Receiver> Runtime<Rx> {
         // and delivering them to the actor's mailbox and sending replies to call messages.
         // and delivering them to the actor's mailbox and sending replies to call messages.
         let deliverer = {
         let deliverer = {
             let buffer = Arc::new(Mutex::new(Vec::<u8>::new()));
             let buffer = Arc::new(Mutex::new(Vec::<u8>::new()));
-            let serializer = Arc::new(serializer);
             let tx = tx.clone();
             let tx = tx.clone();
             move |envelope: WireEnvelope| {
             move |envelope: WireEnvelope| {
                 let (wire_msg, replier) = envelope.into_parts();
                 let (wire_msg, replier) = envelope.into_parts();
-                let result = deserializer(wire_msg.payload);
+                let result = from_slice(wire_msg.payload);
                 let buffer = buffer.clone();
                 let buffer = buffer.clone();
-                let serializer = serializer.clone();
                 let tx = tx.clone();
                 let tx = tx.clone();
                 let fut: FutureResult = Box::pin(async move {
                 let fut: FutureResult = Box::pin(async move {
                     let msg = result?;
                     let msg = result?;
@@ -132,12 +164,8 @@ impl<Rx: Receiver> Runtime<Rx> {
                             Ok(reply) => {
                             Ok(reply) => {
                                 let mut guard = buffer.lock().await;
                                 let mut guard = buffer.lock().await;
                                 guard.clear();
                                 guard.clear();
-                                serializer(&reply, &mut guard)?;
-                                let wire_reply = WireMsg {
-                                    to: wire_msg.from,
-                                    from: act_id,
-                                    payload: &guard,
-                                };
+                                write_to(&reply, guard.deref_mut())?;
+                                let wire_reply = WireReply::Ok(&guard);
                                 replier.reply(wire_reply).await
                                 replier.reply(wire_reply).await
                             }
                             }
                             Err(err) => replier.reply_err(err.to_string(), None).await,
                             Err(err) => replier.reply_err(err.to_string(), None).await,
@@ -174,6 +202,32 @@ impl<Rx: Receiver> Runtime<Rx> {
     }
     }
 }
 }
 
 
+/// Deserializes replies sent over the wire.
+struct ReplyCallback<T> {
+    _phantom: PhantomData<T>,
+}
+
+impl<T: CallMsg> ReplyCallback<T> {
+    fn new() -> Self {
+        Self {
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<T: CallMsg> DeserCallback for ReplyCallback<T> {
+    type Arg<'de> = WireReply<'de> where T: 'de;
+    type Return = Result<T::Reply>;
+    type CallFut<'de> = Ready<Self::Return> where T: 'de, T::Reply: 'de;
+    fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
+        let result = match arg {
+            WireReply::Ok(slice) => from_slice(slice).map_err(|err| err.into()),
+            WireReply::Err(msg) => Err(StringError::new(msg.to_string()).into()),
+        };
+        ready(result)
+    }
+}
+
 /// This struct implements the server callback for network messages.
 /// This struct implements the server callback for network messages.
 #[derive(Clone)]
 #[derive(Clone)]
 struct RuntimeCallback {
 struct RuntimeCallback {
@@ -262,28 +316,10 @@ pub trait CallMsg: Serialize + DeserializeOwned + Send + Sync {
 /// Trait for messages which expect exactly zero replies.
 /// Trait for messages which expect exactly zero replies.
 pub trait SendMsg: CallMsg {}
 pub trait SendMsg: CallMsg {}
 
 
-/// An adapter which allows a [CallMsg] to be used sent remotely using a [Transmitter].
-#[derive(Serialize, Deserialize)]
-#[repr(transparent)]
-#[serde(transparent)]
-struct Adapter<T>(T);
-
-impl<'a, T: CallMsg> btmsg::CallMsg<'a> for Adapter<T> {
-    type Reply<'r> = T::Reply;
-}
-
-impl<'a, T: SendMsg> btmsg::SendMsg<'a> for Adapter<T> {}
-
-impl<T: CallMsg> From<T> for Adapter<T> {
-    fn from(value: T) -> Self {
-        Self(value)
-    }
-}
-
 /// The maximum number of messages which can be kept in an actor's mailbox.
 /// The maximum number of messages which can be kept in an actor's mailbox.
 const MAILBOX_LIMIT: usize = 32;
 const MAILBOX_LIMIT: usize = 32;
 
 
-/// The type of messages sent over the wire between runtime.
+/// The type of messages sent over the wire between runtimes.
 #[derive(Serialize, Deserialize)]
 #[derive(Serialize, Deserialize)]
 struct WireMsg<'a> {
 struct WireMsg<'a> {
     to: Uuid,
     to: Uuid,
@@ -295,6 +331,8 @@ impl<'a> btmsg::CallMsg<'a> for WireMsg<'a> {
     type Reply<'r> = WireReply<'r>;
     type Reply<'r> = WireReply<'r>;
 }
 }
 
 
+impl<'a> btmsg::SendMsg<'a> for WireMsg<'a> {}
+
 #[derive(Serialize, Deserialize)]
 #[derive(Serialize, Deserialize)]
 enum WireReply<'a> {
 enum WireReply<'a> {
     Ok(&'a [u8]),
     Ok(&'a [u8]),
@@ -393,16 +431,26 @@ impl ActorHandle {
 
 
 #[cfg(test)]
 #[cfg(test)]
 mod tests {
 mod tests {
-    use std::net::IpAddr;
-
     use super::*;
     use super::*;
 
 
     use btlib::{
     use btlib::{
         crypto::{CredStore, CredsPriv},
         crypto::{CredStore, CredsPriv},
+        log::BuilderExt,
     };
     };
     use btlib_tests::TEST_STORE;
     use btlib_tests::TEST_STORE;
     use btmsg::BlockAddr;
     use btmsg::BlockAddr;
-    use btserde::{from_slice, write_to};
+    use btserde::to_vec;
+    use ctor::ctor;
+    use std::net::IpAddr;
+
+    /// The log level to use when running tests.
+    const LOG_LEVEL: &str = "warn";
+
+    #[ctor]
+    fn ctor() {
+        std::env::set_var("RUST_LOG", LOG_LEVEL);
+        env_logger::Builder::from_default_env().btformat().init();
+    }
 
 
     #[derive(Serialize, Deserialize)]
     #[derive(Serialize, Deserialize)]
     struct EchoMsg(String);
     struct EchoMsg(String);
@@ -421,46 +469,45 @@ mod tests {
         }
         }
     }
     }
 
 
-    fn echo_deserializer(slice: &[u8]) -> Result<EchoMsg> {
-        from_slice(slice).map_err(|err| err.into())
-    }
-
-    fn echo_serializer(msg: &EchoMsg, buf: &mut Vec<u8>) -> Result<()> {
-        write_to(msg, buf).map_err(|err| err.into())
-    }
-
     #[tokio::test]
     #[tokio::test]
     async fn local_call() {
     async fn local_call() {
         const EXPECTED: &str = "hello";
         const EXPECTED: &str = "hello";
         let ip_addr = IpAddr::from([127, 0, 0, 1]);
         let ip_addr = IpAddr::from([127, 0, 0, 1]);
         let creds = TEST_STORE.node_creds().unwrap();
         let creds = TEST_STORE.node_creds().unwrap();
         let runtime = new_runtime(ip_addr, creds).unwrap();
         let runtime = new_runtime(ip_addr, creds).unwrap();
-        let name = runtime
-            .activate(echo, echo_deserializer, echo_serializer)
-            .await;
-        let reply = runtime.call(&name, EchoMsg(EXPECTED.into())).await.unwrap();
+        let name = runtime.activate(echo).await;
+
+        let reply = runtime
+            .call(&name, Uuid::default(), EchoMsg(EXPECTED.into()))
+            .await
+            .unwrap();
+
         assert_eq!(EXPECTED, reply.0)
         assert_eq!(EXPECTED, reply.0)
     }
     }
 
 
-    //#[tokio::test]
-    #[allow(dead_code)]
+    #[tokio::test]
     async fn remote_call() {
     async fn remote_call() {
-        btlib::log::init();
-
         const EXPECTED: &str = "hello";
         const EXPECTED: &str = "hello";
-        let ip_addr = IpAddr::from([127, 0, 0, 1]);
+        let ip_addr = IpAddr::from([127, 0, 0, 2]);
         let creds = TEST_STORE.node_creds().unwrap();
         let creds = TEST_STORE.node_creds().unwrap();
         let runtime = new_runtime(ip_addr, creds.clone()).unwrap();
         let runtime = new_runtime(ip_addr, creds.clone()).unwrap();
-        runtime
-            .activate(echo, echo_deserializer, echo_serializer)
-            .await;
+        let actor_name = runtime.activate(echo).await;
         let bind_path = Arc::new(creds.bind_path().unwrap());
         let bind_path = Arc::new(creds.bind_path().unwrap());
         let block_addr = Arc::new(BlockAddr::new(ip_addr, bind_path));
         let block_addr = Arc::new(BlockAddr::new(ip_addr, bind_path));
         let transmitter = btmsg::transmitter(block_addr, creds).await.unwrap();
         let transmitter = btmsg::transmitter(block_addr, creds).await.unwrap();
+        let buf = to_vec(&EchoMsg(EXPECTED.to_string())).unwrap();
+        let wire_msg = WireMsg {
+            to: actor_name.act_id,
+            from: Uuid::default(),
+            payload: &buf,
+        };
+
         let reply = transmitter
         let reply = transmitter
-            .call_through(Adapter(EchoMsg(EXPECTED.into())))
+            .call(wire_msg, ReplyCallback::<EchoMsg>::new())
             .await
             .await
+            .unwrap()
             .unwrap();
             .unwrap();
+
         assert_eq!(EXPECTED, reply.0);
         assert_eq!(EXPECTED, reply.0);
     }
     }
 }
 }