#![feature(impl_trait_in_assoc_type)] use btrun::*; use btlib::{ crypto::{ConcreteCreds, CredStore, CredsPriv}, log::BuilderExt, Result, }; use btlib_tests::TEST_STORE; use btproto::protocol; use btserde::to_vec; use bttp::{BlockAddr, Transmitter}; use ctor::ctor; use lazy_static::lazy_static; use log; use serde::{Deserialize, Serialize}; use std::{ future::{ready, Future, Ready}, net::{IpAddr, Ipv4Addr}, sync::{ atomic::{AtomicU8, Ordering}, Arc, }, }; use tokio::{runtime::Builder, sync::mpsc}; use uuid::Uuid; const RUNTIME_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); lazy_static! { static ref RUNTIME_CREDS: Arc = TEST_STORE.node_creds().unwrap(); } declare_runtime!(RUNTIME, RUNTIME_ADDR, RUNTIME_CREDS.clone()); lazy_static! { /// A tokio async runtime. /// /// When the `#[tokio::test]` attribute is used on a test, a new current thread runtime /// is created for each test /// (source: https://docs.rs/tokio/latest/tokio/attr.test.html#current-thread-runtime). /// This creates a problem, because the first test thread to access the `RUNTIME` static /// will initialize its `Receiver` in its runtime, which will stop running at the end of /// the test. Hence subsequent tests will not be able to send remote messages to this /// `Runtime`. /// /// By creating a single async runtime which is used by all of the tests, we can avoid this /// problem. static ref ASYNC_RT: tokio::runtime::Runtime = Builder::new_current_thread() .enable_all() .build() .unwrap(); } /// The log level to use when running tests. const LOG_LEVEL: &str = "warn"; #[ctor] fn ctor() { std::env::set_var("RUST_LOG", format!("{},quinn=WARN", LOG_LEVEL)); env_logger::Builder::from_default_env().btformat().init(); } #[derive(Serialize, Deserialize)] struct EchoMsg(String); impl CallMsg for EchoMsg { type Reply = EchoMsg; } async fn echo( _rt: &'static Runtime, mut mailbox: mpsc::Receiver>, _act_id: Uuid, ) { while let Some(envelope) = mailbox.recv().await { let (msg, kind) = envelope.split(); match kind { EnvelopeKind::Call { reply } => { let replier = reply.unwrap_or_else(|| panic!("The reply has already been sent.")); if let Err(_) = replier.send(msg) { panic!("failed to send reply"); } } _ => panic!("Expected EchoMsg to be a Call Message."), } } } #[test] fn local_call() { ASYNC_RT.block_on(async { const EXPECTED: &str = "hello"; let name = RUNTIME.spawn(echo).await; let from = ActorName::new(name.path().clone(), Uuid::default()); let reply = RUNTIME .call(name.clone(), from, EchoMsg(EXPECTED.into())) .await .unwrap(); assert_eq!(EXPECTED, reply.0); RUNTIME.take(&name).await.unwrap(); }) } #[test] fn remote_call() { ASYNC_RT.block_on(async { const EXPECTED: &str = "hello"; let actor_name = RUNTIME.spawn(echo).await; let bind_path = Arc::new(RUNTIME_CREDS.bind_path().unwrap()); let block_addr = Arc::new(BlockAddr::new(RUNTIME_ADDR, bind_path)); let transmitter = Transmitter::new(block_addr, RUNTIME_CREDS.clone()) .await .unwrap(); let buf = to_vec(&EchoMsg(EXPECTED.to_string())).unwrap(); let wire_msg = WireMsg::new( actor_name.clone(), RUNTIME.actor_name(Uuid::default()), &buf, ); let reply = transmitter .call(wire_msg, ReplyCallback::::new()) .await .unwrap() .unwrap(); assert_eq!(EXPECTED, reply.0); RUNTIME.take(&actor_name).await.unwrap(); }); } /// Tests the `num_running` method. /// /// This test uses its own runtime and so can use the `#[tokio::test]` attribute. #[tokio::test] async fn num_running() { declare_runtime!( LOCAL_RT, // This needs to be different from the address where `RUNTIME` is listening. IpAddr::from([127, 0, 0, 2]), TEST_STORE.node_creds().unwrap() ); assert_eq!(0, LOCAL_RT.num_running().await); let name = LOCAL_RT.spawn(echo).await; assert_eq!(1, LOCAL_RT.num_running().await); LOCAL_RT.take(&name).await.unwrap(); assert_eq!(0, LOCAL_RT.num_running().await); } mod ping_pong { use super::*; use btlib::bterr; // The following code is a proof-of-concept for what types should be generated for a // simple ping-pong protocol: protocol! { named PingProtocol; let server = [Server]; let client = [Client]; Client -> End, >service(Server)!Ping; Server?Ping -> End, >Client!Ping::Reply; } // // In words, the protocol is described as follows. // 1. The ClientInit state receives the Activate message. It returns the SentPing state and a // Ping message to be sent to the Listening state. // 2. The ServerInit state receives the Activate message. It returns the Listening state. // 3. When the Listening state receives the Ping message it returns the End state and a // Ping::Reply message to be sent to the SentPing state. // 4. When the SentPing state receives the Ping::Reply message it returns the End state. // // The End state represents an end to the session described by the protocol. When an actor // transitions to the End state its function returns. // The generated actor implementation is the sender of the Activate message. // When a state is expecting a Reply message, an error occurs if the message is not received // in a timely manner. enum PingClientState { Client(T), End(End), } impl PingClientState { const fn name(&self) -> &'static str { match self { Self::Client(_) => "Client", Self::End(_) => "End", } } } struct ClientHandle { state: Option>, runtime: &'static Runtime, } impl ClientHandle { async fn send_ping(&mut self, mut msg: Ping, service: ServiceAddr) -> Result { let state = self .state .take() .ok_or_else(|| bterr!("State was not returned."))?; let (new_state, result) = match state { PingClientState::Client(state) => { let (new_state, _) = state.on_send_ping(&mut msg).await?; let new_state = PingClientState::End(new_state); let result = self .runtime .call_service(service, PingProtocolMsgs::Ping(msg)) .await; (new_state, result) } state => { let result = Err(bterr!("Can't send Ping in state {}.", state.name())); (state, result) } }; self.state = Some(new_state); let reply = result?; match reply { PingProtocolMsgs::PingReply(reply) => Ok(reply), msg => Err(bterr!( "Unexpected message type sent in reply: {}", msg.name() )), } } } async fn spawn_client(init: T, runtime: &'static Runtime) -> ClientHandle { let state = Some(PingClientState::Client(init)); ClientHandle { state, runtime } } async fn register_server( make_init: F, rt: &'static Runtime, id: ServiceId, ) -> Result where Init: 'static + Server, F: 'static + Send + Sync + Clone + Fn() -> Init, { enum ServerState { Server(S), End(End), } async fn server_loop( _runtime: &'static Runtime, make_init: F, mut mailbox: Mailbox, _act_id: Uuid, ) where Init: 'static + Server, F: 'static + Send + Sync + FnOnce() -> Init, { let mut state = ServerState::Server(make_init()); while let Some(envelope) = mailbox.recv().await { let (msg, msg_kind) = envelope.split(); state = match (state, msg) { (ServerState::Server(listening_state), PingProtocolMsgs::Ping(msg)) => { let (new_state, reply) = listening_state.handle_ping(msg).await.unwrap(); match msg_kind { EnvelopeKind::Call { reply: replier } => { let replier = replier.expect("The reply has already been sent."); if let Err(_) = replier.send(PingProtocolMsgs::PingReply(reply)) { panic!("Failed to send Ping reply."); } ServerState::End(new_state) } _ => panic!("'Ping' was expected to be a Call message."), } } (state, _) => state, }; if let ServerState::End(_) = state { break; } } } rt.register::(id, move |runtime| { let make_init = make_init.clone(); let fut = async move { let actor_name = runtime .spawn(move |_, mailbox, act_id| { server_loop(runtime, make_init, mailbox, act_id) }) .await; Ok(actor_name) }; Box::pin(fut) }) .await } #[derive(Serialize, Deserialize)] pub struct Ping; impl CallMsg for Ping { type Reply = PingReply; } #[derive(Serialize, Deserialize)] pub struct PingReply; struct ClientState { counter: Arc, } impl ClientState { fn new(counter: Arc) -> Self { counter.fetch_add(1, Ordering::SeqCst); Self { counter } } } impl Client for ClientState { type OnSendPingFut = impl Future>; fn on_send_ping(self, _msg: &mut Ping) -> Self::OnSendPingFut { self.counter.fetch_sub(1, Ordering::SeqCst); ready(Ok((End, PingReply))) } } struct ServerState { counter: Arc, } impl ServerState { fn new(counter: Arc) -> Self { counter.fetch_add(1, Ordering::SeqCst); Self { counter } } } impl Server for ServerState { type HandlePingFut = impl Future>; fn handle_ping(self, _msg: Ping) -> Self::HandlePingFut { self.counter.fetch_sub(1, Ordering::SeqCst); ready(Ok((End, PingReply))) } } #[test] fn ping_pong_test() { ASYNC_RT.block_on(async { const SERVICE_ID: &str = "PingPongProtocolServer"; let service_id = ServiceId::from(SERVICE_ID); let counter = Arc::new(AtomicU8::new(0)); let service_name = { let service_counter = counter.clone(); let make_init = move || { let server_counter = service_counter.clone(); ServerState::new(server_counter) }; register_server(make_init, &RUNTIME, service_id.clone()) .await .unwrap() }; let mut client_handle = spawn_client(ClientState::new(counter.clone()), &RUNTIME).await; let service_addr = ServiceAddr::new(service_name, true); client_handle.send_ping(Ping, service_addr).await.unwrap(); assert_eq!(0, counter.load(Ordering::SeqCst)); RUNTIME.take_service(&service_id).await.unwrap(); }); } } mod travel_agency { use super::*; // Here's another protocol example. This is the Customer and Travel Agency protocol used as an // example in the survey paper "Behavioral Types in Programming Languages." // Note that the Choosing state can send messages at any time, not just in response to another // message because there is a transition from Choosing that doesn't use the receive operator // (`?`). protocol! { named TravelAgency; let agency = [Listening]; let customer = [Choosing]; Choosing -> Choosing, >service(Listening)!Query; Choosing -> Choosing, >service(Listening)!Accept; Choosing -> Choosing, >service(Listening)!Reject; Listening?Query -> Listening, >Choosing!Query::Reply; Choosing?Query::Reply -> Choosing; Listening?Accept -> End, >Choosing!Accept::Reply; Choosing?Accept::Reply -> End; Listening?Reject -> End, >Choosing!Reject::Reply; Choosing?Reject::Reply -> End; } #[derive(Serialize, Deserialize)] pub struct Query; impl CallMsg for Query { type Reply = (); } #[derive(Serialize, Deserialize)] pub struct Reject; impl CallMsg for Reject { type Reply = (); } #[derive(Serialize, Deserialize)] pub struct Accept; impl CallMsg for Accept { type Reply = (); } } #[allow(dead_code)] mod client_callback { use super::*; use std::time::Duration; use tokio::{sync::oneshot, time::timeout}; #[derive(Serialize, Deserialize)] pub struct Register { factor: usize, } #[derive(Serialize, Deserialize)] pub struct Completed { value: usize, } protocol! { named ClientCallback; let server = [Listening]; let worker = [Working]; let client = [Unregistered, Registered]; Unregistered -> Registered, >service(Listening)!Register[Registered]; Listening?Register[Registered] -> Listening, Working[Registered]; Working[Registered] -> End, >Registered!Completed; Registered?Completed -> End; } struct UnregisteredState { sender: oneshot::Sender, } impl Unregistered for UnregisteredState { type OnSendRegisterRegistered = RegisteredState; type OnSendRegisterFut = Ready>; fn on_send_register(self, _arg: &mut Register) -> Self::OnSendRegisterFut { ready(Ok(RegisteredState { sender: self.sender, })) } } struct RegisteredState { sender: oneshot::Sender, } impl Registered for RegisteredState { type HandleCompletedFut = Ready>; fn handle_completed(self, arg: Completed) -> Self::HandleCompletedFut { self.sender.send(arg.value).unwrap(); ready(Ok(End)) } } struct ListeningState { multiple: usize, } impl Listening for ListeningState { type HandleRegisterListening = ListeningState; type HandleRegisterWorking = WorkingState; type HandleRegisterFut = Ready>; fn handle_register(self, arg: Register) -> Self::HandleRegisterFut { let multiple = self.multiple; ready(Ok(( self, WorkingState { factor: arg.factor, multiple, }, ))) } } struct WorkingState { factor: usize, multiple: usize, } impl Working for WorkingState { type OnSendCompletedFut = Ready>; fn on_send_completed(self) -> Self::OnSendCompletedFut { let value = self.multiple * self.factor; ready(Ok((End, Completed { value }))) } } use ::tokio::sync::Mutex; enum ClientState { Unregistered(Init), Registered(Init::OnSendRegisterRegistered), End(End), } impl ClientState { pub fn name(&self) -> &'static str { match self { Self::Unregistered(_) => "Unregistered", Self::Registered(_) => "Registered", Self::End(_) => "End", } } } struct ClientHandle { runtime: &'static Runtime, state: Arc>>>, name: ActorName, } impl ClientHandle { async fn send_register(&self, to: ServiceAddr, mut msg: Register) -> Result<()> { let mut guard = self.state.lock().await; let state = guard .take() .unwrap_or_else(|| panic!("Logic error. The state was not returned.")); let new_state = match state { ClientState::Unregistered(state) => { let new_state = state.on_send_register(&mut msg).await?; let msg = ClientCallbackMsgs::Register(msg); self.runtime .send_service(to, self.name.clone(), msg) .await?; // QUESTION: Should `on_send_register` be required to return the previous state // if it encounters an error? ClientState::Registered(new_state) } state => state, }; *guard = Some(new_state); Ok(()) } } async fn spawn_client(init: Init, runtime: &'static Runtime) -> ClientHandle where Init: 'static + Unregistered, { let state = Arc::new(Mutex::new(Some(ClientState::Unregistered(init)))); let name = { let state = state.clone(); runtime.spawn(move |_, mut mailbox, _act_id| async move { while let Some(envelope) = mailbox.recv().await { let mut guard = state.lock().await; let state = guard.take() .unwrap_or_else(|| panic!("Logic error. The state was not returned.")); let (msg, _kind) = envelope.split(); let new_state = match (state, msg) { (ClientState::Registered(curr_state), ClientCallbackMsgs::Completed(msg)) => { match curr_state.handle_completed(msg).await { Ok(next) => ClientState::::End(next), Err(err) => { log::error!("Failed to handle 'Completed' message in 'Registered' state: {err}"); panic!("We can't transition to a new state because we gave ownership away and that method failed!") } } } (state, msg) => { log::error!("Unexpected message '{}' in state '{}'.", msg.name(), state.name()); state } }; *guard = Some(new_state); } }).await }; ClientHandle { runtime, state, name, } } async fn register_server( make_init: F, runtime: &'static Runtime, service_id: ServiceId, ) -> Result where Init: 'static + Listening, F: 'static + Send + Sync + Clone + Fn() -> Init, { enum ServerState { Listening(S), } impl ServerState { fn name(&self) -> &'static str { match self { Self::Listening(_) => "Listening", } } } async fn server_loop( runtime: &'static Runtime, make_init: F, mut mailbox: Mailbox, _act_id: Uuid, ) where Init: 'static + Listening, F: 'static + Send + Sync + Fn() -> Init, { let mut state = ServerState::Listening(make_init()); while let Some(envelope) = mailbox.recv().await { let (msg, msg_kind) = envelope.split(); let new_state = match (state, msg) { (ServerState::Listening(curr_state), ClientCallbackMsgs::Register(msg)) => { match curr_state.handle_register(msg).await { Ok((new_state, working_state)) => { if let EnvelopeKind::Send { from, .. } = msg_kind { start_worker(working_state, from, runtime).await; } else { log::error!("Expected Register to be a Send message."); } ServerState::Listening(new_state) } Err(err) => { log::error!("Failed to handle the Register message: {err}"); todo!("Need to recover the previous state from err.") } } } (state, msg) => { log::error!( "Unexpected message '{}' in state '{}'.", msg.name(), state.name() ); state } }; state = new_state; } } runtime .register::(service_id, move |runtime: &'static Runtime| { let make_init = make_init.clone(); let fut = async move { let make_init = make_init.clone(); let actor_name = runtime .spawn(move |_, mailbox, act_id| { server_loop(runtime, make_init, mailbox, act_id) }) .await; Ok(actor_name) }; Box::pin(fut) }) .await } async fn start_worker( init: Init, owned: ActorName, runtime: &'static Runtime, ) -> ActorName where Init: 'static + Working, { enum WorkerState { Working(S), } runtime .spawn::(move |_, _, act_id| async move { let msg = match init.on_send_completed().await { Ok((End, msg)) => msg, Err(err) => { log::error!("Failed to send Completed message: {err}"); return; } }; let from = runtime.actor_name(act_id); let msg = ClientCallbackMsgs::Completed(msg); if let Err(err) = runtime.send(owned, from, msg).await { log::error!("Failed to send Completed message: {err}"); } }) .await } #[test] fn client_callback_protocol() { ASYNC_RT.block_on(async { const SERVICE_ID: &str = "ClientCallbackProtocolListening"; let service_id = ServiceId::from(SERVICE_ID); let service_name = { let make_init = move || ListeningState { multiple: 2 }; register_server(make_init, &RUNTIME, service_id.clone()) .await .unwrap() }; let (sender, receiver) = oneshot::channel(); let client_handle = spawn_client(UnregisteredState { sender }, &RUNTIME).await; let service_addr = ServiceAddr::new(service_name, false); client_handle .send_register(service_addr, Register { factor: 21 }) .await .unwrap(); let value = timeout(Duration::from_millis(500), receiver) .await .unwrap() .unwrap(); assert_eq!(42, value); }); } }