Explorar o código

* Modified the `activate` method to pass a pointer to the `Runtime`
to new actors.
* Added the requirement that a `Runtime` must have a `'static` lifetime, so that
a simple reference can be passed to actors.

Matthew Carr hai 1 ano
pai
achega
49538e26a9
Modificáronse 3 ficheiros con 199 adicións e 56 borrados
  1. 2 0
      Cargo.lock
  2. 3 1
      crates/btrun/Cargo.toml
  3. 194 55
      crates/btrun/src/lib.rs

+ 2 - 0
Cargo.lock

@@ -461,6 +461,8 @@ dependencies = [
  "ctor",
  "env_logger",
  "futures",
+ "lazy_static",
+ "log",
  "serde",
  "strum",
  "tokio",

+ 3 - 1
crates/btrun/Cargo.toml

@@ -9,13 +9,15 @@ edition = "2021"
 btlib = { path = "../btlib" }
 btmsg = { path = "../btmsg" }
 btserde = { path = "../btserde" }
-tokio = { version = "1.23.0", features = ["rt"] }
+tokio = { version = "1.23.0", features = ["rt-multi-thread"] }
 futures = "0.3.25"
 serde = { version = "^1.0.136", features = ["derive"] }
 uuid = { version = "1.3.3", features = ["v4", "fast-rng", "macro-diagnostics", "serde"] }
 anyhow = { version = "1.0.66", features = ["std", "backtrace"] }
 bytes = "1.3.0"
 strum = { version = "^0.24.0", features = ["derive"] }
+lazy_static = { version = "1.4.0" }
+log = "0.4.17"
 
 [dev-dependencies]
 btlib-tests = { path = "../btlib-tests" }

+ 194 - 55
crates/btrun/src/lib.rs

@@ -3,6 +3,7 @@
 use std::{
     any::Any,
     collections::HashMap,
+    fmt::Display,
     future::{ready, Future, Ready},
     marker::PhantomData,
     net::IpAddr,
@@ -21,9 +22,21 @@ use tokio::{
 };
 use uuid::Uuid;
 
-/// Creates a new [Runtime] instance which listens at the given IP address and which uses the given
-/// credentials.
-pub fn new_runtime<C: 'static + Send + Sync + Creds>(
+/// Declares a new [Runtime] which listens for messages at the given IP address and uses the given
+/// [Creds]. Runtimes are intended to be created once in a process's lifetime and continue running
+/// until the process exits.
+#[macro_export]
+macro_rules! declare_runtime {
+    ($name:ident, $ip_addr:expr, $creds:expr) => {
+        ::lazy_static::lazy_static! {
+            static ref $name: Runtime =  _new_runtime($ip_addr, $creds).unwrap();
+        }
+    };
+}
+
+///  This function is not intended to be called directly by downstream crates. Use the macro
+/// [declare_runtime] to create a [Runtime] instead.
+pub fn _new_runtime<C: 'static + Send + Sync + Creds>(
     ip_addr: IpAddr,
     creds: Arc<C>,
 ) -> Result<Runtime> {
@@ -41,10 +54,10 @@ pub fn new_runtime<C: 'static + Send + Sync + Creds>(
 
 /// An actor runtime.
 ///
-/// Actors can be activated by the runtime and execute autonomously until they halt. Running actors
-/// can be sent messages using the `send` method, which does not wait for a response from the
-/// recipient. If a reply is needed, then `call` can be used, which returns a future that will not
-/// be ready until the reply has been received.
+/// Actors can be activated by the runtime and execute autonomously until they return. Running
+/// actors can be sent messages using the `send` method, which does not wait for a response from the
+/// recipient. If a reply is needed, then `call` can be used, which returns a future that will
+/// be ready when the reply has been received.
 pub struct Runtime {
     _rx: Receiver,
     path: Arc<BlockPath>,
@@ -57,6 +70,12 @@ impl Runtime {
         &self.path
     }
 
+    /// Returns the number of actors that are currently executing in this [Runtime].
+    pub async fn num_running(&self) -> usize {
+        let guard = self.handles.read().await;
+        guard.len()
+    }
+
     /// Sends a message to the actor identified by the given [ActorName].
     pub async fn send<T: 'static + SendMsg>(
         &self,
@@ -128,11 +147,11 @@ impl Runtime {
     }
 
     /// Activates a new actor using the given activator function and returns a handle to it.
-    pub async fn activate<Msg, F, Fut>(&self, activator: F) -> ActorName
+    pub async fn activate<Msg, F, Fut>(&'static self, activator: F) -> ActorName
     where
         Msg: 'static + CallMsg,
         Fut: 'static + Send + Future<Output = ()>,
-        F: FnOnce(mpsc::Receiver<Envelope<Msg>>, Uuid) -> Fut,
+        F: FnOnce(&'static Runtime, mpsc::Receiver<Envelope<Msg>>, Uuid) -> Fut,
     {
         let mut guard = self.handles.write().await;
         let act_id = {
@@ -179,7 +198,7 @@ impl Runtime {
                 fut
             }
         };
-        let handle = tokio::task::spawn(activator(rx, act_id));
+        let handle = tokio::task::spawn(activator(self, rx, act_id));
         let actor_handle = ActorHandle::new(handle, tx, deliverer);
         guard.insert(act_id, actor_handle);
         ActorName::new(self.path.clone(), act_id)
@@ -200,8 +219,50 @@ impl Runtime {
     {
         todo!()
     }
+
+    /// Returns the [ActorHandle] for the actor with the given name.
+    ///
+    /// If there is no such actor in this runtime then a [RuntimeError::BadActorName] error is
+    /// returned.
+    ///
+    /// Note that the actor will be aborted when the given handle is dropped (unless it has already
+    /// returned when the handle is dropped), and no further messages will be delivered to it by
+    /// this runtime.
+    pub async fn take(&self, name: &ActorName) -> Result<ActorHandle> {
+        if name.path == self.path {
+            let mut guard = self.handles.write().await;
+            if let Some(handle) = guard.remove(&name.act_id) {
+                Ok(handle)
+            } else {
+                Err(RuntimeError::BadActorName(name.clone()).into())
+            }
+        } else {
+            Err(RuntimeError::BadActorName(name.clone()).into())
+        }
+    }
+}
+
+impl Drop for Runtime {
+    fn drop(&mut self) {
+        panic!("A Runtime was dropped. Panicking to avoid undefined behavior.");
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum RuntimeError {
+    BadActorName(ActorName),
 }
 
+impl Display for RuntimeError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::BadActorName(name) => write!(f, "bad actor name: {name}"),
+        }
+    }
+}
+
+impl std::error::Error for RuntimeError {}
+
 /// Deserializes replies sent over the wire.
 struct ReplyCallback<T> {
     _phantom: PhantomData<T>,
@@ -307,6 +368,12 @@ impl ActorName {
     }
 }
 
+impl Display for ActorName {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}@{}", self.act_id, self.path)
+    }
+}
+
 /// Trait for messages which expect exactly one reply.
 pub trait CallMsg: Serialize + DeserializeOwned + Send + Sync {
     /// The reply type expected for this message.
@@ -383,8 +450,8 @@ impl<T: CallMsg> Envelope<T> {
 
 type FutureResult = Pin<Box<dyn Send + Future<Output = Result<()>>>>;
 
-struct ActorHandle {
-    _handle: JoinHandle<()>,
+pub struct ActorHandle {
+    handle: Option<JoinHandle<()>>,
     sender: Box<dyn Send + Sync + Any>,
     deliverer: Box<dyn Send + Sync + Fn(WireEnvelope<'_>) -> FutureResult>,
 }
@@ -396,7 +463,7 @@ impl ActorHandle {
         F: 'static + Send + Sync + Fn(WireEnvelope<'_>) -> FutureResult,
     {
         Self {
-            _handle: handle,
+            handle: Some(handle),
             sender: Box::new(sender),
             deliverer: Box::new(deliverer),
         }
@@ -408,7 +475,7 @@ impl ActorHandle {
             .ok_or_else(|| bterr!("unexpected message type"))
     }
 
-    async fn send<T: 'static + SendMsg>(&self, msg: T) -> Result<()> {
+    pub async fn send<T: 'static + SendMsg>(&self, msg: T) -> Result<()> {
         let sender = self.sender()?;
         sender
             .send(Envelope::send(msg))
@@ -417,7 +484,7 @@ impl ActorHandle {
         Ok(())
     }
 
-    async fn call_through<T: 'static + CallMsg>(&self, msg: T) -> Result<T::Reply> {
+    pub async fn call_through<T: 'static + CallMsg>(&self, msg: T) -> Result<T::Reply> {
         let sender = self.sender()?;
         let (envelope, rx) = Envelope::call(msg);
         sender
@@ -427,6 +494,25 @@ impl ActorHandle {
         let reply = rx.await?;
         Ok(reply)
     }
+
+    pub async fn returned(&mut self) -> Result<()> {
+        if let Some(handle) = self.handle.take() {
+            handle.await?;
+        }
+        Ok(())
+    }
+
+    pub fn abort(&mut self) {
+        if let Some(handle) = self.handle.take() {
+            handle.abort();
+        }
+    }
+}
+
+impl Drop for ActorHandle {
+    fn drop(&mut self) {
+        self.abort();
+    }
 }
 
 #[cfg(test)]
@@ -434,21 +520,48 @@ mod tests {
     use super::*;
 
     use btlib::{
-        crypto::{CredStore, CredsPriv},
+        crypto::{ConcreteCreds, CredStore, CredsPriv},
         log::BuilderExt,
     };
     use btlib_tests::TEST_STORE;
     use btmsg::BlockAddr;
     use btserde::to_vec;
     use ctor::ctor;
-    use std::net::IpAddr;
+    use lazy_static::lazy_static;
+    use std::net::{IpAddr, Ipv4Addr};
+    use tokio::runtime::Builder;
+
+    const RUNTIME_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
+    lazy_static! {
+        static ref RUNTIME_CREDS: Arc<ConcreteCreds> = TEST_STORE.node_creds().unwrap();
+    }
+    declare_runtime!(RUNTIME, RUNTIME_ADDR, RUNTIME_CREDS.clone());
+
+    lazy_static! {
+        /// A tokio async runtime.
+        ///
+        /// When the `#[tokio::test]` attribute is used on a test, a new current thread runtime
+        /// is created for each test
+        /// (source: https://docs.rs/tokio/latest/tokio/attr.test.html#current-thread-runtime).
+        /// This creates a problem, because the first test thread to access the `RUNTIME` static
+        /// will initialize its `Receiver` in its runtime, which will stop running at the end of
+        /// the test. Hence subsequent tests will not be able to send remote messages to this
+        /// `Runtime`.
+        ///
+        /// By creating a single async runtime which is used by all of the tests, we can avoid this
+        /// problem.
+        static ref ASYNC_RT: tokio::runtime::Runtime = Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .unwrap();
+    }
 
     /// The log level to use when running tests.
     const LOG_LEVEL: &str = "warn";
 
     #[ctor]
     fn ctor() {
-        std::env::set_var("RUST_LOG", LOG_LEVEL);
+        std::env::set_var("RUST_LOG", format!("{},quinn=WARN", LOG_LEVEL));
         env_logger::Builder::from_default_env().btformat().init();
     }
 
@@ -459,7 +572,11 @@ mod tests {
         type Reply = EchoMsg;
     }
 
-    async fn echo(mut mailbox: mpsc::Receiver<Envelope<EchoMsg>>, _act_id: Uuid) {
+    async fn echo(
+        _rt: &'static Runtime,
+        mut mailbox: mpsc::Receiver<Envelope<EchoMsg>>,
+        _act_id: Uuid,
+    ) {
         while let Some(msg) = mailbox.recv().await {
             if let Envelope::Call { msg, reply } = msg {
                 if let Err(_) = reply.send(msg) {
@@ -469,45 +586,67 @@ mod tests {
         }
     }
 
-    #[tokio::test]
-    async fn local_call() {
-        const EXPECTED: &str = "hello";
-        let ip_addr = IpAddr::from([127, 0, 0, 1]);
-        let creds = TEST_STORE.node_creds().unwrap();
-        let runtime = new_runtime(ip_addr, creds).unwrap();
-        let name = runtime.activate(echo).await;
-
-        let reply = runtime
-            .call(&name, Uuid::default(), EchoMsg(EXPECTED.into()))
-            .await
-            .unwrap();
+    #[test]
+    fn local_call() {
+        ASYNC_RT.block_on(async {
+            const EXPECTED: &str = "hello";
+            let name = RUNTIME.activate(echo).await;
 
-        assert_eq!(EXPECTED, reply.0)
-    }
+            let reply = RUNTIME
+                .call(&name, Uuid::default(), EchoMsg(EXPECTED.into()))
+                .await
+                .unwrap();
 
-    #[tokio::test]
-    async fn remote_call() {
-        const EXPECTED: &str = "hello";
-        let ip_addr = IpAddr::from([127, 0, 0, 2]);
-        let creds = TEST_STORE.node_creds().unwrap();
-        let runtime = new_runtime(ip_addr, creds.clone()).unwrap();
-        let actor_name = runtime.activate(echo).await;
-        let bind_path = Arc::new(creds.bind_path().unwrap());
-        let block_addr = Arc::new(BlockAddr::new(ip_addr, bind_path));
-        let transmitter = Transmitter::new(block_addr, creds).await.unwrap();
-        let buf = to_vec(&EchoMsg(EXPECTED.to_string())).unwrap();
-        let wire_msg = WireMsg {
-            to: actor_name.act_id,
-            from: Uuid::default(),
-            payload: &buf,
-        };
+            assert_eq!(EXPECTED, reply.0);
 
-        let reply = transmitter
-            .call(wire_msg, ReplyCallback::<EchoMsg>::new())
-            .await
-            .unwrap()
-            .unwrap();
+            RUNTIME.take(&name).await.unwrap();
+        })
+    }
 
-        assert_eq!(EXPECTED, reply.0);
+    #[test]
+    fn remote_call() {
+        ASYNC_RT.block_on(async {
+            const EXPECTED: &str = "hello";
+            let actor_name = RUNTIME.activate(echo).await;
+            let bind_path = Arc::new(RUNTIME_CREDS.bind_path().unwrap());
+            let block_addr = Arc::new(BlockAddr::new(RUNTIME_ADDR, bind_path));
+            let transmitter = Transmitter::new(block_addr, RUNTIME_CREDS.clone())
+                .await
+                .unwrap();
+            let buf = to_vec(&EchoMsg(EXPECTED.to_string())).unwrap();
+            let wire_msg = WireMsg {
+                to: actor_name.act_id,
+                from: Uuid::default(),
+                payload: &buf,
+            };
+
+            let reply = transmitter
+                .call(wire_msg, ReplyCallback::<EchoMsg>::new())
+                .await
+                .unwrap()
+                .unwrap();
+
+            assert_eq!(EXPECTED, reply.0);
+
+            RUNTIME.take(&actor_name).await.unwrap();
+        });
+    }
+
+    /// Tests the `num_running` method.
+    ///
+    /// This test uses its own runtime and so can use the `#[tokio::test]` attribute.
+    #[tokio::test]
+    async fn num_running() {
+        declare_runtime!(
+            LOCAL_RT,
+            // This needs to be different from the address where `RUNTIME` is listening.
+            IpAddr::from([127, 0, 0, 2]),
+            TEST_STORE.node_creds().unwrap()
+        );
+        assert_eq!(0, LOCAL_RT.num_running().await);
+        let name = LOCAL_RT.activate(echo).await;
+        assert_eq!(1, LOCAL_RT.num_running().await);
+        LOCAL_RT.take(&name).await.unwrap();
+        assert_eq!(0, LOCAL_RT.num_running().await);
     }
 }