#![feature(impl_trait_in_assoc_type)]

use std::{
    any::Any,
    collections::{hash_map, HashMap},
    fmt::Display,
    future::{ready, Future, Ready},
    marker::PhantomData,
    net::IpAddr,
    ops::DerefMut,
    pin::Pin,
    sync::Arc,
};

use btlib::{bterr, crypto::Creds, error::StringError, BlockPath, Result};
use btserde::{from_slice, to_vec, write_to};
use bttp::{DeserCallback, MsgCallback, Receiver, Replier, Transmitter};
use kernel::{kernel, SpawnReq};
use serde::{Deserialize, Serialize};
use tokio::{
    sync::{mpsc, oneshot, Mutex, RwLock},
    task::AbortHandle,
};

mod kernel;
pub mod model;

use model::*;

/// Declares a new [Runtime] which listens for messages at the given IP address and uses the given
/// [Creds]. Runtimes are intended to be created once in a process's lifetime and continue running
/// until the process exits.
///
/// The macro expands to a lazily-initialized static named `$name` of type `&'static Runtime`.
/// Note that the `$creds` expression is expanded twice (once for the [Runtime] and once for its
/// [Receiver]), so it should be cheap and free of side effects (e.g. an `Arc` clone).
///
/// # Panics
///
/// The first dereference of the static panics if either the [Runtime] or its [Receiver] cannot be
/// created, because both are unwrapped during initialization.
#[macro_export]
macro_rules! declare_runtime {
    ($name:ident, $ip_addr:expr, $creds:expr) => {
        ::lazy_static::lazy_static! {
            static ref $name: &'static $crate::Runtime = {
                ::lazy_static::lazy_static! {
                    static ref RUNTIME: $crate::Runtime = $crate::Runtime::_new($creds).unwrap();
                    // The helper is addressed through `$crate` so that callers of this macro do
                    // not need to import a hidden function into their own scope.
                    static ref RECEIVER: ::bttp::Receiver =
                        $crate::_new_receiver($ip_addr, $creds, &*RUNTIME);
                }
                // By dereferencing RECEIVER we ensure it is started.
                let _ = &*RECEIVER;
                &*RUNTIME
            };
        }
    };
}

/// Creates the [Receiver] which feeds network messages into the given [Runtime].
///
/// This function is not intended to be called by downstream crates. It is only public so that
/// [declare_runtime] can reach it.
///
/// # Panics
///
/// Panics if the [Receiver] cannot be created.
#[doc(hidden)]
pub fn _new_receiver<C>(ip_addr: IpAddr, creds: Arc<C>, runtime: &'static Runtime) -> Receiver
where
    C: 'static + Send + Sync + Creds,
{
    let callback = RuntimeCallback::new(runtime);
    Receiver::new(ip_addr, creds, callback).unwrap()
}

/// Type used to implement an actor's mailbox.
///
/// Each item received is an [Envelope] wrapping a message of type `T`.
pub type Mailbox<T> = mpsc::Receiver<Envelope<T>>;

/// An actor runtime.
///
/// Actors can be activated by the runtime and execute autonomously until they return. Running
/// actors can be sent messages using the `send` method, which does not wait for a response from the
/// recipient.
If a reply is needed, then `call` can be used, which returns a future that will /// be ready when the reply has been received. pub struct Runtime { path: Arc, handles: RwLock>, peers: RwLock, Transmitter>>, registry: RwLock>, kernel_sender: mpsc::Sender, } impl Runtime { /// The size of the buffer to use for the channel between [Runtime] and [kernel] used for /// spawning tasks. const SPAWN_REQ_BUF_SZ: usize = 16; /// This method is not intended to be called directly by downstream crates. Use the macro /// [declare_runtime] to create a [Runtime]. /// /// If you create a non-static [Runtime], your process will panic when it is dropped. #[doc(hidden)] pub fn _new(creds: Arc) -> Result { let path = Arc::new(creds.bind_path()?); let (sender, receiver) = mpsc::channel(Self::SPAWN_REQ_BUF_SZ); tokio::task::spawn(kernel(receiver)); Ok(Runtime { path, handles: RwLock::new(HashMap::new()), peers: RwLock::new(HashMap::new()), registry: RwLock::new(HashMap::new()), kernel_sender: sender, }) } pub fn path(&self) -> &Arc { &self.path } /// Returns the number of actors that are currently executing in this [Runtime]. pub async fn num_running(&self) -> usize { let guard = self.handles.read().await; guard.len() } /// Sends a message to the actor identified by the given [ActorName]. pub async fn send( &self, to: ActorName, from: ActorName, msg: T, ) -> Result<()> { if to.path().as_ref() == self.path.as_ref() { let guard = self.handles.read().await; if let Some(handle) = guard.get(&to.act_id()) { handle.send(from, msg).await } else { Err(bterr!("invalid actor name")) } } else { let guard = self.peers.read().await; if let Some(peer) = guard.get(to.path()) { let buf = to_vec(&msg)?; let wire_msg = WireMsg { to, from, payload: &buf, }; peer.send(wire_msg).await } else { todo!("Discover the network location of the recipient runtime and connect to it.") } } } /// Sends a message to the service identified by [ServiceName]. 
pub async fn send_service( &'static self, to: ServiceAddr, from: ActorName, msg: T, ) -> Result<()> { if to.path().as_ref() == self.path.as_ref() { let actor_id = self.service_provider(&to).await?; let handles = self.handles.read().await; if let Some(handle) = handles.get(&actor_id) { handle.send(from, msg).await } else { panic!( "Service record '{}' had a non-existent actor with ID '{}'.", to.service_id(), actor_id ); } } else { todo!("Send the message to an appropriate peer.") } } /// Sends a message to the actor identified by the given [ActorName] and returns a future which /// is ready when a reply has been received. pub async fn call( &self, to: ActorName, from: ActorName, msg: T, ) -> Result { if to.path().as_ref() == self.path.as_ref() { let guard = self.handles.read().await; if let Some(handle) = guard.get(&to.act_id()) { handle.call_through(msg).await } else { Err(bterr!("invalid actor name")) } } else { let guard = self.peers.read().await; if let Some(peer) = guard.get(to.path()) { let buf = to_vec(&msg)?; let wire_msg = WireMsg { to, from, payload: &buf, }; peer.call(wire_msg, ReplyCallback::::new()).await? } else { todo!("Use the filesystem to find the address of the recipient and connect to it.") } } } /// Calls a service identified by [ServiceName]. 
pub async fn call_service( &'static self, to: ServiceAddr, msg: T, ) -> Result { if to.path().as_ref() == self.path.as_ref() { let actor_id = self.service_provider(&to).await?; let handles = self.handles.read().await; if let Some(handle) = handles.get(&actor_id) { handle.call_through(msg).await } else { panic!( "Service record '{}' had a non-existent actor with ID '{}'.", to.service_id(), actor_id ); } } else { todo!("Send the message to an appropriate peer.") } } fn service_not_registered_err(id: &ServiceId) -> btlib::Error { bterr!("Service is not registered: '{id}'") } async fn service_provider(&'static self, to: &ServiceAddr) -> Result { let actor_id = { let registry = self.registry.read().await; if let Some(record) = registry.get(to.service_id()) { record.actor_ids.first().copied() } else { return Err(Self::service_not_registered_err(to.service_id())); } }; let actor_id = if let Some(actor_id) = actor_id { actor_id } else { let mut registry = self.registry.write().await; if let Some(record) = registry.get_mut(to.service_id()) { // It's possible that another thread got the write lock before us and they // already spawned an actor. if record.actor_ids.is_empty() { let spawner = record.spawner.as_ref(); let actor_name = spawner(self).await?; let actor_id = actor_name.act_id(); record.actor_ids.push(actor_id); actor_id } else { record.actor_ids[0] } } else { return Err(Self::service_not_registered_err(to.service_id())); } }; Ok(actor_id) } /// Spawns a new actor using the given activator function and returns a handle to it. 
pub async fn spawn(&'static self, activator: F) -> ActorName where Msg: 'static + CallMsg, Fut: 'static + Send + Future, F: FnOnce(&'static Runtime, Mailbox, ActorId) -> Fut, { let mut guard = self.handles.write().await; let act_id = { let mut act_id = ActorId::new(); while guard.contains_key(&act_id) { act_id = ActorId::new(); } act_id }; let act_name = self.actor_name(act_id); let (tx, rx) = mpsc::channel::>(MAILBOX_LIMIT); // The deliverer closure is responsible for deserializing messages received over the wire // and delivering them to the actor's mailbox, as well as sending replies to call messages. let deliverer = { let buffer = Arc::new(Mutex::new(Vec::::new())); let tx = tx.clone(); let act_name = act_name.clone(); move |envelope: WireEnvelope| { let (wire_msg, replier) = envelope.into_parts(); let result = from_slice(wire_msg.payload); let buffer = buffer.clone(); let tx = tx.clone(); let act_name = act_name.clone(); let fut: FutureResult = Box::pin(async move { let msg = result?; if let Some(mut replier) = replier { let (envelope, rx) = Envelope::new_call(msg); tx.send(envelope).await.map_err(|_| { bterr!("failed to deliver message. Recipient may have halted.") })?; match rx.await { Ok(reply) => { let mut guard = buffer.lock().await; guard.clear(); write_to(&reply, guard.deref_mut())?; let wire_reply = WireReply::Ok(&guard); replier.reply(wire_reply).await } Err(err) => replier.reply_err(err.to_string(), None).await, } } else { tx.send(Envelope::new_send(act_name, msg)) .await .map_err(|_| { bterr!("failed to deliver message. 
Recipient may have halted.") }) } }); fut } }; let (req, receiver) = SpawnReq::new(activator(self, rx, act_id)); self.kernel_sender .send(req) .await .unwrap_or_else(|err| panic!("The kernel has panicked: {err}")); let handle = receiver .await .unwrap_or_else(|err| panic!("Kernel failed to send abort handle: {err}")); let actor_handle = ActorHandle::new(handle, tx, deliverer); guard.insert(act_id, actor_handle); act_name } /// Registers a service activation closure for [ServiceId]. An error is returned if the /// [ServiceId] has already been registered. pub async fn register(&self, id: ServiceId, spawner: F) -> Result where Msg: 'static + CallMsg, F: 'static + Send + Sync + Fn(&'static Runtime) -> Pin>>>, { let mut guard = self.registry.write().await; match guard.entry(id.clone()) { hash_map::Entry::Vacant(entry) => { entry.insert(ServiceRecord::new(spawner)); Ok(ServiceName::new(self.path().clone(), id.clone())) } hash_map::Entry::Occupied(_) => { log::info!("Updated registration for service '{id}'."); Ok(ServiceName::new(self.path().clone(), id)) } } } /// Removes the registration for the service with the given ID. /// /// If a vector reference is given in `service_providers`, the service providers which /// are part of the deregistered service are appended to it. Otherwise, their /// handles are dropped and their tasks are aborted. /// /// A [RuntimeError::BadServiceId] error is returned if there is no service registration with /// the given ID in this runtime. pub async fn deregister( &self, id: &ServiceId, service_providers: Option<&mut Vec>, ) -> Result<()> { let record = { let mut registry = self.registry.write().await; if let Some(record) = registry.remove(id) { record } else { return Err(RuntimeError::BadServiceId(id.clone()).into()); } }; let mut handles = self.handles.write().await; let removed = record .actor_ids .into_iter() .flat_map(|act_id| handles.remove(&act_id)); // If a vector was provided, we put all the removed service providers in it. 
Otherwise // we just drop them. if let Some(service_providers) = service_providers { service_providers.extend(removed); } else { for _ in removed {} } Ok(()) } /// Returns the [ActorHandle] for the actor with the given name. /// /// If there is no such actor in this runtime then a [RuntimeError::BadActorName] error is /// returned. /// /// Note that the actor will be aborted when the given handle is dropped (unless it has already /// returned when the handle is dropped), and no further messages will be delivered to it by /// this runtime. pub async fn take(&self, name: &ActorName) -> Result { if name.path().as_ref() == self.path.as_ref() { let mut guard = self.handles.write().await; if let Some(handle) = guard.remove(&name.act_id()) { Ok(handle) } else { Err(RuntimeError::BadActorName(name.clone()).into()) } } else { Err(RuntimeError::BadActorName(name.clone()).into()) } } /// Returns the name of the actor in this runtime with the given actor ID. pub fn actor_name(&self, act_id: ActorId) -> ActorName { ActorName::new(self.path.clone(), act_id) } } impl Drop for Runtime { fn drop(&mut self) { panic!("A Runtime was dropped. Panicking to avoid undefined behavior."); } } /// Closure type used to spawn new service providers. type Spawner = Box Pin>>>>; struct ServiceRecord { spawner: Spawner, actor_ids: Vec, } impl ServiceRecord { fn new(spawner: F) -> Self where F: 'static + Send + Sync + Fn(&'static Runtime) -> Pin>>>, { Self { spawner: Box::new(spawner), actor_ids: Vec::new(), } } } #[derive(Debug, Clone, PartialEq, Eq)] pub enum RuntimeError { BadActorName(ActorName), BadServiceId(ServiceId), } impl Display for RuntimeError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::BadActorName(name) => write!(f, "bad actor name: {name}"), Self::BadServiceId(service_id) => { write!(f, "service ID is not registered: {service_id}") } } } } impl std::error::Error for RuntimeError {} /// Deserializes replies sent over the wire. 
struct ReplyCallback { _phantom: PhantomData, } impl ReplyCallback { fn new() -> Self { Self { _phantom: PhantomData, } } } impl Default for ReplyCallback { fn default() -> Self { Self::new() } } impl DeserCallback for ReplyCallback { type Arg<'de> = WireReply<'de> where T: 'de; type Return = Result; type CallFut<'de> = Ready where T: 'de, T::Reply: 'de; fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> { let result = match arg { WireReply::Ok(slice) => from_slice(slice).map_err(|err| err.into()), WireReply::Err(msg) => Err(StringError::new(msg.to_string()).into()), }; ready(result) } } struct SendReplyCallback { replier: Option, } impl SendReplyCallback { fn new(replier: Replier) -> Self { Self { replier: Some(replier), } } } impl DeserCallback for SendReplyCallback { type Arg<'de> = WireReply<'de>; type Return = Result<()>; type CallFut<'de> = impl 'de + Future; fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> { async move { if let Some(mut replier) = self.replier.take() { replier.reply(arg).await } else { Ok(()) } } } } /// This struct implements the server callback for network messages. #[derive(Clone)] struct RuntimeCallback { rt: &'static Runtime, } impl RuntimeCallback { fn new(rt: &'static Runtime) -> Self { Self { rt } } async fn deliver_local(&self, msg: WireMsg<'_>, replier: Option) -> Result<()> { let guard = self.rt.handles.read().await; if let Some(handle) = guard.get(&msg.to.act_id()) { let envelope = if let Some(replier) = replier { WireEnvelope::Call { msg, replier } } else { WireEnvelope::Send { msg } }; (handle.deliverer)(envelope).await } else { Err(bterr!("invalid actor name: {}", msg.to)) } } async fn route_msg(&self, msg: WireMsg<'_>, replier: Option) -> Result<()> { let guard = self.rt.peers.read().await; if let Some(tx) = guard.get(msg.to.path()) { if let Some(replier) = replier { let callback = SendReplyCallback::new(replier); tx.call(msg, callback).await? 
} else { tx.send(msg).await } } else { Err(bterr!( "unable to deliver message to peer at '{}'", msg.to.path() )) } } } impl MsgCallback for RuntimeCallback { type Arg<'de> = WireMsg<'de>; type CallFut<'de> = impl 'de + Future>; fn call<'de>(&'de self, arg: bttp::MsgReceived>) -> Self::CallFut<'de> { async move { let (_, body, replier) = arg.into_parts(); if body.to.path() == self.rt.path() { self.deliver_local(body, replier).await } else { self.route_msg(body, replier).await } } } } /// The maximum number of messages which can be kept in an actor's mailbox. const MAILBOX_LIMIT: usize = 32; /// The type of messages sent over the wire between runtimes. #[derive(Serialize, Deserialize)] struct WireMsg<'a> { to: ActorName, from: ActorName, payload: &'a [u8], } impl<'a> WireMsg<'a> { #[allow(dead_code)] fn new(to: ActorName, from: ActorName, payload: &'a [u8]) -> Self { Self { to, from, payload } } } impl<'a> bttp::CallMsg<'a> for WireMsg<'a> { type Reply<'r> = WireReply<'r>; } impl<'a> bttp::SendMsg<'a> for WireMsg<'a> {} #[derive(Serialize, Deserialize)] enum WireReply<'a> { Ok(&'a [u8]), Err(&'a str), } /// A wrapper around [WireMsg] which indicates whether a call or send was executed. enum WireEnvelope<'de> { Send { msg: WireMsg<'de> }, Call { msg: WireMsg<'de>, replier: Replier }, } impl<'de> WireEnvelope<'de> { fn into_parts(self) -> (WireMsg<'de>, Option) { match self { Self::Send { msg } => (msg, None), Self::Call { msg, replier } => (msg, Some(replier)), } } } pub enum EnvelopeKind { Call { reply: Option>, }, Send { from: ActorName, }, } impl EnvelopeKind { pub fn name(&self) -> &'static str { match self { Self::Call { .. } => "Call", Self::Send { .. } => "Send", } } } /// Wrapper around a message type `T` which indicates who the message is from and, if the message /// was dispatched with `call`, provides a channel to reply to it. 
pub struct Envelope { msg: T, kind: EnvelopeKind, } impl Envelope { pub fn new(msg: T, kind: EnvelopeKind) -> Self { Self { msg, kind } } /// Creates a new envelope containing the given message which does not expect a reply. fn new_send(from: ActorName, msg: T) -> Self { Self { kind: EnvelopeKind::Send { from }, msg, } } /// Creates a new envelope containing the given message which expects exactly one reply. fn new_call(msg: T) -> (Self, oneshot::Receiver) { let (tx, rx) = oneshot::channel::(); let envelope = Self { kind: EnvelopeKind::Call { reply: Some(tx) }, msg, }; (envelope, rx) } /// Returns the name of the actor which sent this message. pub fn from(&self) -> Option<&ActorName> { match &self.kind { EnvelopeKind::Send { from } => Some(from), _ => None, } } /// Returns a reference to the message in this envelope. pub fn msg(&self) -> &T { &self.msg } /// Sends a reply to this message. /// /// If this message is not expecting a reply, or if this message has already been replied to, /// then an error is returned. pub fn reply(&mut self, reply: T::Reply) -> Result<()> { match &mut self.kind { EnvelopeKind::Call { reply: tx } => { if let Some(tx) = tx.take() { tx.send(reply).map_err(|_| bterr!("Failed to send reply.")) } else { Err(bterr!("Reply has already been sent.")) } } _ => Err(bterr!("Can't reply to '{}' messages.", self.kind.name())), } } /// Returns true if this message expects a reply and it has not already been replied to. pub fn needs_reply(&self) -> bool { matches!(&self.kind, EnvelopeKind::Call { .. 
}) } pub fn split(self) -> (T, EnvelopeKind) { (self.msg, self.kind) } } type FutureResult = Pin>>>; pub struct ActorHandle { handle: AbortHandle, sender: Box, deliverer: Box) -> FutureResult>, } impl ActorHandle { fn new(handle: AbortHandle, sender: mpsc::Sender>, deliverer: F) -> Self where T: 'static + CallMsg, F: 'static + Send + Sync + Fn(WireEnvelope<'_>) -> FutureResult, { Self { handle, sender: Box::new(sender), deliverer: Box::new(deliverer), } } fn sender(&self) -> Result<&mpsc::Sender>> { self.sender .downcast_ref::>>() .ok_or_else(|| bterr!("Attempt to send message as the wrong type.")) } /// Sends a message to the actor represented by this handle. pub async fn send(&self, from: ActorName, msg: T) -> Result<()> { let sender = self.sender()?; sender .send(Envelope::new_send(from, msg)) .await .map_err(|_| bterr!("failed to enqueue message"))?; Ok(()) } pub async fn call_through(&self, msg: T) -> Result { let sender = self.sender()?; let (envelope, rx) = Envelope::new_call(msg); sender .send(envelope) .await .map_err(|_| bterr!("failed to enqueue call"))?; let reply = rx.await?; Ok(reply) } pub fn abort(&self) { self.handle.abort(); } } impl Drop for ActorHandle { fn drop(&mut self) { self.abort(); } } /// Sets up variable declarations and logging configuration to facilitate testing with a [Runtime]. #[macro_export] macro_rules! test_setup { () => { const RUNTIME_ADDR: ::std::net::IpAddr = ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1)); lazy_static! { static ref RUNTIME_CREDS: ::std::sync::Arc<::btlib::crypto::ConcreteCreds> = { let test_store = &::btlib_tests::TEST_STORE; ::btlib::crypto::CredStore::node_creds(test_store).unwrap() }; } declare_runtime!(RUNTIME, RUNTIME_ADDR, RUNTIME_CREDS.clone()); lazy_static! { /// A tokio async runtime. 
/// /// When the `#[tokio::test]` attribute is used on a test, a new current thread runtime /// is created for each test /// (source: https://docs.rs/tokio/latest/tokio/attr.test.html#current-thread-runtime). /// This creates a problem, because the first test thread to access the `RUNTIME` static /// will initialize its `Receiver` in its runtime, which will stop running at the end of /// the test. Hence subsequent tests will not be able to send remote messages to this /// `Runtime`. /// /// By creating a single async runtime which is used by all of the tests, we can avoid this /// problem. static ref ASYNC_RT: tokio::runtime::Runtime = ::tokio::runtime::Builder ::new_current_thread() .enable_all() .build() .unwrap(); } /// The log level to use when running tests. const LOG_LEVEL: &str = "warn"; #[::ctor::ctor] #[allow(non_snake_case)] fn ctor() { ::std::env::set_var("RUST_LOG", format!("{},quinn=WARN", LOG_LEVEL)); let mut builder = ::env_logger::Builder::from_default_env(); ::btlib::log::BuilderExt::btformat(&mut builder).init(); } }; } #[cfg(test)] pub mod test { use super::*; use btlib::crypto::{CredStore, CredsPriv}; use btlib_tests::TEST_STORE; use bttp::BlockAddr; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use crate::CallMsg; test_setup!(); #[derive(Serialize, Deserialize)] struct EchoMsg(String); impl CallMsg for EchoMsg { type Reply = EchoMsg; } async fn echo( _rt: &'static Runtime, mut mailbox: mpsc::Receiver>, _act_id: ActorId, ) { while let Some(envelope) = mailbox.recv().await { let (msg, kind) = envelope.split(); match kind { EnvelopeKind::Call { reply } => { let replier = reply.unwrap_or_else(|| panic!("The reply has already been sent.")); if let Err(_) = replier.send(msg) { panic!("failed to send reply"); } } _ => panic!("Expected EchoMsg to be a Call Message."), } } } #[test] fn local_call() { ASYNC_RT.block_on(async { const EXPECTED: &str = "hello"; let name = RUNTIME.spawn(echo).await; let from = 
ActorName::new(name.path().clone(), ActorId::new()); let reply = RUNTIME .call(name.clone(), from, EchoMsg(EXPECTED.into())) .await .unwrap(); assert_eq!(EXPECTED, reply.0); RUNTIME.take(&name).await.unwrap(); }) } /// Tests the `num_running` method. /// /// This test uses its own runtime and so can use the `#[tokio::test]` attribute. #[tokio::test] async fn num_running() { declare_runtime!( LOCAL_RT, // This needs to be different from the address where `RUNTIME` is listening. IpAddr::from([127, 0, 0, 2]), TEST_STORE.node_creds().unwrap() ); assert_eq!(0, LOCAL_RT.num_running().await); let name = LOCAL_RT.spawn(echo).await; assert_eq!(1, LOCAL_RT.num_running().await); LOCAL_RT.take(&name).await.unwrap(); assert_eq!(0, LOCAL_RT.num_running().await); } #[test] fn remote_call() { ASYNC_RT.block_on(async { const EXPECTED: &str = "hello"; let actor_name = RUNTIME.spawn(echo).await; let bind_path = Arc::new(RUNTIME_CREDS.bind_path().unwrap()); let block_addr = Arc::new(BlockAddr::new(RUNTIME_ADDR, bind_path)); let transmitter = Transmitter::new(block_addr, RUNTIME_CREDS.clone()) .await .unwrap(); let buf = to_vec(&EchoMsg(EXPECTED.to_string())).unwrap(); let wire_msg = WireMsg::new(actor_name.clone(), RUNTIME.actor_name(ActorId::new()), &buf); let reply = transmitter .call(wire_msg, ReplyCallback::::new()) .await .unwrap() .unwrap(); assert_eq!(EXPECTED, reply.0); RUNTIME.take(&actor_name).await.unwrap(); }); } }