// Tests for hand-written ("manual") versions of the code that the `protocol!` macro is intended
// to generate. Each module declares a protocol with `protocol!`, implements its state traits, and
// then drives those states through manually written handles and actor loops.
//
// The `impl_trait_in_assoc_type` feature is required for the `type ...Fut = impl Future...`
// associated types used in the trait impls below.
#![feature(impl_trait_in_assoc_type)]
use btrun::model::*;
use btrun::test_setup;
use btrun::*;
use btlib::Result;
use btproto::protocol;
use log;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::{
    future::{ready, Future, Ready},
    sync::{
        atomic::{AtomicU8, Ordering},
        Arc,
    },
};

// NOTE(review): presumably this macro defines the `ASYNC_RT` and `RUNTIME` items used by the
// tests below (they are not defined anywhere in this file) — confirm in `btrun`.
test_setup!();

mod ping_pong {
    use super::*;
    use btlib::bterr;

    // The following code is a proof-of-concept for what types should be generated for a
    // simple ping-pong protocol:
    protocol! {
        named PingProtocol;
        let server = [Server];
        let client = [Client];
        Client -> End, >service(Server)!Ping;
        Server?Ping -> End, >Client!Ping::Reply;
    }
    //
    // In words, the protocol is described as follows.
    // 1. When the Server state receives the Ping message it returns the End state and a
    //    Ping::Reply message to be sent to the Client state.
    // 2. When the Client state receives the Ping::Reply message it returns the End state.
    //
    // The End state represents an end to the session described by the protocol. When an actor
    // transitions to the End state its function returns.
    // When a state is expecting a Reply message, an error occurs if the message is not received
    // in a timely manner.

    // Client-side states of the ping protocol: still able to send `Ping` (`Client`), or done
    // (`End`).
    enum PingClientState {
        Client(T),
        End(End),
    }

    impl PingClientState {
        // Static name of the current variant, used when building error messages.
        const fn name(&self) -> &'static str {
            match self {
                Self::Client(_) => "Client",
                Self::End(_) => "End",
            }
        }
    }

    // Manually written client handle. `state` is an `Option` so it can be taken out for the
    // duration of a transition and put back afterwards. `client_name` is the actor name passed
    // to the runtime as the caller when the service is called.
    struct ClientHandleManual {
        state: Option>,
        client_name: ActorName,
        runtime: &'static Runtime,
    }

    impl ClientHandleManual {
        // Sends `msg` to `service` and waits for the reply, then feeds the `PingReply` to the
        // user's `on_send_ping` transition.
        //
        // Returns:
        // - `Ok(self)` with the state set to `End` on success.
        // - `Abort` with `self` if the call fails or the user transition aborts; the `Client`
        //   state is restored in both cases.
        // - `Abort` if the handle is not in the `Client` state; that state is put back.
        // - `Fatal`, dropping the handle, if the user transition returns `Fatal`.
        async fn send_ping(
            mut self,
            msg: Ping,
            service: ServiceAddr,
        ) -> TransResult> {
            // Take ownership of the state for the duration of the transition.
            let state = if let Some(state) = self.state.take() {
                state
            } else {
                return TransResult::Abort {
                    from: self,
                    err: bterr!("The shared state was not returned."),
                };
            };
            match state {
                PingClientState::Client(state) => {
                    let result = self
                        .runtime
                        .call_service(
                            service,
                            self.client_name.clone(),
                            PingProtocolMsgs::Ping(msg),
                        )
                        .await;
                    let reply_enum = match result {
                        Ok(reply_enum) => reply_enum,
                        Err(err) => {
                            // Put the untouched state back so the caller can retry.
                            self.state = Some(PingClientState::Client(state));
                            return TransResult::Abort { from: self, err };
                        }
                    };
                    if let PingProtocolMsgs::PingReply(reply) = reply_enum {
                        match state.on_send_ping(reply).await {
                            TransResult::Ok(new_state) => {
                                self.state = Some(PingClientState::End(new_state));
                                TransResult::Ok(self)
                            }
                            TransResult::Abort { from, err } => {
                                self.state = Some(PingClientState::Client(from));
                                TransResult::Abort { from: self, err }
                            }
                            TransResult::Fatal { err } => return TransResult::Fatal { err },
                        }
                    } else {
                        // NOTE(review): this branch does not restore `self.state` (the `Client`
                        // state is dropped), so a later `send_ping` on the returned handle aborts
                        // with "The shared state was not returned." rather than retrying.
                        TransResult::Abort {
                            from: self,
                            err: bterr!("Unexpected reply type."),
                        }
                    }
                }
                state => {
                    let err = bterr!("Can't send Ping in state {}.", state.name());
                    self.state = Some(state);
                    TransResult::Abort { from: self, err }
                }
            }
        }
    }

    // Creates a client handle in the `Client` state. A `do_nothing_actor` is spawned only to
    // obtain an actor name for the client; presumably it serves as the caller identity for
    // `call_service`.
    async fn spawn_client_manual(
        init: T,
        runtime: &'static Runtime,
    ) -> ClientHandleManual {
        let state = Some(PingClientState::Client(init));
        let client_name = runtime.spawn(None, do_nothing_actor).await.unwrap();
        ClientHandleManual {
            state,
            client_name,
            runtime,
        }
    }

    // Registers the ping server as service `id`. The closure given to `rt.register` spawns a new
    // actor running `server_loop`, initialized from `make_init`, each time the runtime invokes
    // it. Returns the result of `rt.register`; the test uses the `Ok` value as the service name
    // for `ServiceAddr::new`.
    async fn register_server_manual(
        make_init: F,
        rt: &'static Runtime,
        id: ServiceId,
    ) -> Result
    where
        Init: 'static + Server,
        F: 'static + Send + Sync + Clone + Fn() -> Init,
    {
        // Server-side states of the protocol.
        enum ServerState {
            Server(S),
            End(End),
        }

        impl Named for ServerState {
            // Names are cached in statics so each call only clones an `Arc`.
            fn name(&self) -> Arc {
                static SERVER_NAME: Lazy> = Lazy::new(|| Arc::new("Server".into()));
                static END_NAME: Lazy> = Lazy::new(|| Arc::new("End".into()));
                match self {
                    Self::Server(_) => SERVER_NAME.clone(),
                    Self::End(_) => END_NAME.clone(),
                }
            }
        }

        // Actor body for the server.
        // - Waits for a `Call` envelope carrying `Ping`, runs `handle_ping`, and replies with
        //   `PingReply`.
        // - Stops once the state reaches `End` or the mailbox is closed.
        // - Calls with any other message leave the state unchanged.
        // - Non-`Call` envelopes end the actor with an error.
        async fn server_loop(
            _runtime: &'static Runtime,
            make_init: F,
            mut mailbox: Mailbox,
            actor_id: ActorId,
        ) -> ActorResult
        where
            Init: 'static + Server,
            F: 'static + Send + Sync + FnOnce() -> Init,
        {
            let mut state = ServerState::Server(make_init());
            while let Some(envelope) = mailbox.recv().await {
                state = match envelope {
                    Envelope::Call {
                        msg,
                        reply: replier,
                        ..
                    } => match (state, msg) {
                        (ServerState::Server(listening_state), PingProtocolMsgs::Ping(msg)) => {
                            match listening_state.handle_ping(msg).await {
                                TransResult::Ok((new_state, reply)) => {
                                    // NOTE(review): this panics if the replier is missing, while
                                    // the send failure just below is reported as an
                                    // `ActorError`; consider handling both the same way.
                                    let replier = replier
                                        .ok_or_else(|| bterr!("Reply has already been sent."))
                                        .unwrap();
                                    if let Err(_) = replier.send(PingProtocolMsgs::PingReply(reply))
                                    {
                                        return Err(ActorError::new(
                                            bterr!("Failed to send Ping reply."),
                                            ActorErrorPayload {
                                                actor_id,
                                                actor_impl: Init::actor_impl(),
                                                state: Init::state_name(),
                                                message: PingProtocolMsgKinds::Ping.name(),
                                                kind: TransKind::Receive,
                                            },
                                        ));
                                    }
                                    ServerState::End(new_state)
                                }
                                TransResult::Abort { from, err } => {
                                    // Stay in the `Server` state so a later Ping can be handled.
                                    log::warn!("Aborted transition from the {} while handling the {} message: {}", "Server", "Ping", err);
                                    ServerState::Server(from)
                                }
                                TransResult::Fatal { err } => {
                                    return Err(ActorError::new(
                                        err,
                                        ActorErrorPayload {
                                            actor_id,
                                            actor_impl: Init::actor_impl(),
                                            state: Init::state_name(),
                                            message: PingProtocolMsgKinds::Ping.name(),
                                            kind: TransKind::Receive,
                                        },
                                    ));
                                }
                            }
                        }
                        // Any other (state, message) pair is ignored. The replier is dropped
                        // here without sending a reply.
                        (state, _) => state,
                    },
                    envelope => {
                        return Err(ActorError::new(
                            bterr!("Unexpected envelope type: {}", envelope.name()),
                            ActorErrorPayload {
                                actor_id,
                                actor_impl: Init::actor_impl(),
                                state: state.name(),
                                message: envelope.msg_name(),
                                kind: TransKind::Receive,
                            },
                        ))
                    }
                };
                if let ServerState::End(_) = state {
                    break;
                }
            }
            Ok(actor_id)
        }

        rt.register::(id, move |runtime| {
            // Each invocation gets its own copy of the factory.
            let make_init = make_init.clone();
            let fut = async move {
                let actor_impl = runtime
                    .spawn(None, move |mailbox, act_id, runtime| {
                        server_loop(runtime, make_init, mailbox, act_id)
                    })
                    .await
                    .unwrap();
                Ok(actor_impl)
            };
            Box::pin(fut)
        })
        .await
    }

    #[derive(Serialize, Deserialize)]
    pub struct Ping;

    impl CallMsg for Ping {
        type Reply = PingReply;
    }

    #[derive(Serialize, Deserialize)]
    pub struct PingReply;

    // Test client state.
    // - Construction increments the shared counter; handling the reply decrements it.
    // - The test can therefore check that the transition ran.
    struct ClientImpl {
        counter: Arc,
    }

    impl ClientImpl {
        fn new(counter: Arc) -> Self {
            counter.fetch_add(1, Ordering::SeqCst);
            Self { counter }
        }
    }

    impl Client for ClientImpl {
        actor_name!("ping_client");

        type OnSendPingFut = impl Future>;
        fn on_send_ping(self, _msg: PingReply) -> Self::OnSendPingFut {
            self.counter.fetch_sub(1, Ordering::SeqCst);
            ready(TransResult::Ok(End))
        }
    }

    // Test server state.
    // - Construction increments the shared counter; handling `Ping` decrements it.
    // - This struct shares its name with the enum local to `register_server_manual`; inside
    //   that function the local enum shadows it.
    struct ServerState {
        counter: Arc,
    }

    impl ServerState {
        fn new(counter: Arc) -> Self {
            counter.fetch_add(1, Ordering::SeqCst);
            Self { counter }
        }
    }

    impl Server for ServerState {
        actor_name!("ping_server");

        type HandlePingFut = impl Future>;
        fn handle_ping(self, _msg: Ping) -> Self::HandlePingFut {
            self.counter.fetch_sub(1, Ordering::SeqCst);
            ready(TransResult::Ok((End, PingReply)))
        }
    }

    // Runs one Ping/PingReply round trip.
    // - Each constructed state increments the counter and each handler decrements it.
    // - A final count of 0 means both `handle_ping` and `on_send_ping` ran, assuming the server
    //   state was constructed once.
    #[test]
    fn ping_pong_test() {
        ASYNC_RT.block_on(async {
            const SERVICE_ID: &str = "PingPongProtocolServer";
            let service_id = ServiceId::from(SERVICE_ID);
            let counter = Arc::new(AtomicU8::new(0));
            let service_name = {
                let service_counter = counter.clone();
                let make_init = move || {
                    let server_counter = service_counter.clone();
                    ServerState::new(server_counter)
                };
                register_server_manual(make_init, &RUNTIME, service_id.clone())
                    .await
                    .unwrap()
            };
            let client_handle =
                spawn_client_manual(ClientImpl::new(counter.clone()), &RUNTIME).await;
            let service_addr = ServiceAddr::new(service_name, true);
            client_handle.send_ping(Ping, service_addr).await.unwrap();

            assert_eq!(0, counter.load(Ordering::SeqCst));
        });
    }
}

mod travel_agency {
    use super::*;

    // Here's another protocol example. This is the Customer and Travel Agency protocol used as an
    // example in the survey paper "Behavioral Types in Programming Languages."
    // Note that the Choosing state can send messages at any time, not just in response to another
    // message, because there is a transition from Choosing that doesn't use the receive operator
    // (`?`).
    protocol! {
        named TravelAgency;
        let agency = [Listening];
        let customer = [Choosing];
        Choosing -> Choosing, >service(Listening)!Query;
        Choosing -> Choosing, >service(Listening)!Accept;
        Choosing -> Choosing, >service(Listening)!Reject;
        Listening?Query -> Listening, >Choosing!Query::Reply;
        Listening?Accept -> End, >Choosing!Accept::Reply;
        Listening?Reject -> End, >Choosing!Reject::Reply;
    }

    // Message types for the protocol above. Each reply carries no data (`()`). This module only
    // declares the protocol; it contains no test.
    #[derive(Serialize, Deserialize)]
    pub struct Query;

    impl CallMsg for Query {
        type Reply = ();
    }

    #[derive(Serialize, Deserialize)]
    pub struct Reject;

    impl CallMsg for Reject {
        type Reply = ();
    }

    #[derive(Serialize, Deserialize)]
    pub struct Accept;

    impl CallMsg for Accept {
        type Reply = ();
    }
}

// A protocol in which the server delegates work to a separate worker actor.
// 1. The client registers with the `Listening` service, sending a `factor`.
// 2. The server stays `Listening` and starts a `Working` worker that knows the registered
//    client.
// 3. The worker sends `Completed { value: multiple * factor }` back to the client.
// 4. The client reaches `End` after handling `Completed`.
#[allow(dead_code)]
mod client_callback {
    use super::*;
    use btlib::bterr;
    use once_cell::sync::Lazy;
    use std::{marker::PhantomData, panic::panic_any, time::Duration};
    use tokio::{sync::oneshot, time::timeout};

    #[derive(Serialize, Deserialize)]
    pub struct Register {
        factor: usize,
    }

    #[derive(Serialize, Deserialize)]
    pub struct Completed {
        value: usize,
    }

    protocol! {
        named ClientCallback;
        let server = [Listening];
        let worker = [Working];
        let client = [Unregistered, Registered];
        Unregistered -> Registered, >service(Listening)!Register[Registered];
        Listening?Register[Registered] -> Listening, Working[Registered];
        Working[Registered] -> End, >Registered!Completed;
        Registered?Completed -> End;
    }

    // Initial client state. It holds the oneshot sender that will eventually deliver the
    // computed value to the test.
    struct UnregisteredState {
        sender: oneshot::Sender,
    }

    impl Unregistered for UnregisteredState {
        actor_name!("callback_client");

        type OnSendRegisterRegistered = RegisteredState;
        type OnSendRegisterFut = Ready>;
        // Moves the sender into the `Registered` state.
        fn on_send_register(self) -> Self::OnSendRegisterFut {
            ready(TransResult::Ok(RegisteredState {
                sender: self.sender,
            }))
        }
    }

    struct RegisteredState {
        sender: oneshot::Sender,
    }

    impl Registered for RegisteredState {
        type HandleCompletedFut = Ready>;
        // Forwards the computed value to the test through the oneshot channel. Panics if the
        // receiver has been dropped.
        fn handle_completed(self, arg: Completed) -> Self::HandleCompletedFut {
            self.sender.send(arg.value).unwrap();
            ready(TransResult::Ok(End))
        }
    }

    // Server state; `multiple` is copied into every worker it creates.
    struct ListeningState {
        multiple: usize,
    }

    impl Listening for ListeningState {
        actor_name!("callback_server");

        type HandleRegisterListening = ListeningState;
        type HandleRegisterWorking = WorkingState;
        type HandleRegisterFut = Ready>;
        // Keeps listening (returns `self`) and produces a worker state carrying the client's
        // factor.
        fn handle_register(self, arg: Register) -> Self::HandleRegisterFut {
            let multiple = self.multiple;
            ready(TransResult::Ok((
                self,
                WorkingState {
                    factor: arg.factor,
                    multiple,
                },
            )))
        }
    }

    struct WorkingState {
        factor: usize,
        multiple: usize,
    }

    impl Working for WorkingState {
        actor_name!("callback_worker");

        type OnSendCompletedFut = Ready>;
        // Computes `multiple * factor` and emits it as the `Completed` message.
        fn on_send_completed(self) -> Self::OnSendCompletedFut {
            let value = self.multiple * self.factor;
            ready(TransResult::Ok((End, Completed { value })))
        }
    }

    // An async-aware mutex: guards on it are held across `.await` points below.
    use ::tokio::sync::Mutex;

    // Client-side states.
    // - The client state lives behind a shared mutex.
    // - Both the handle (sending side) and the client's actor loop (receiving side) advance it.
    enum ClientStateManual {
        Unregistered(Init),
        Registered(Init::OnSendRegisterRegistered),
        End(End),
    }

    impl Named for ClientStateManual {
        fn name(&self) -> Arc {
            static UNREGISTERED_NAME: Lazy> =
                Lazy::new(|| Arc::new("Unregistered".into()));
            static REGISTERED_NAME: Lazy> =
                Lazy::new(|| Arc::new("Registered".into()));
            static END_NAME: Lazy> = Lazy::new(|| Arc::new("End".into()));
            match self {
                Self::Unregistered(_) => UNREGISTERED_NAME.clone(),
                Self::Registered(_) => REGISTERED_NAME.clone(),
                Self::End(_) => END_NAME.clone(),
            }
        }
    }

    // Typestate client handle. `type_state` is a zero-sized marker recording which state the
    // handle believes it is in. `state` is the protocol state shared with the client's actor
    // loop; it is `None` only while some task has taken it during a transition.
    struct ClientHandleManual {
        runtime: &'static Runtime,
        state: Arc>>>,
        name: ActorName,
        type_state: PhantomData,
    }

    impl ClientHandleManual {
        // Re-types the handle to a new typestate, keeping the same runtime, shared state and
        // actor name.
        fn new_type(self) -> ClientHandleManual {
            ClientHandleManual {
                runtime: self.runtime,
                state: self.state,
                name: self.name,
                type_state: PhantomData,
            }
        }
    }

    impl<
        Init: Unregistered,
        State: Unregistered,
        NewState: Registered,
    > ClientHandleManual {
        // Runs the `Unregistered -> Registered` transition, then sends `Register` to the service
        // at `to`, using this client's actor name as the sender.
        //
        // The mutex guard is held for the whole transition, including the send. The client's
        // actor loop therefore cannot process a `Completed` message before the state has been
        // updated to `Registered`.
        //
        // NOTE(review): on the `?` after `send_service` and on the `Fatal` arm, the function
        // returns before `*guard = Some(new_state)`. The shared state is left `None`, so the
        // actor loop's next message panics with "Logic error. The state was not returned."
        //
        // NOTE(review): on `Abort`, or if the state was not `Unregistered`, this still returns
        // `Ok` with a handle re-typed to the `Registered` typestate, although the shared state
        // did not advance.
        async fn send_register(
            self,
            to: ServiceAddr,
            msg: Register,
        ) -> Result> {
            {
                let mut guard = self.state.lock().await;
                let state = guard
                    .take()
                    .unwrap_or_else(|| panic!("Logic error. The state was not returned."));
                let new_state = match state {
                    ClientStateManual::Unregistered(state) => {
                        match state.on_send_register().await {
                            TransResult::Ok(new_state) => {
                                let msg = ClientCallbackMsgs::Register(msg);
                                self.runtime
                                    .send_service(to, self.name.clone(), msg)
                                    .await?;
                                ClientStateManual::Registered(new_state)
                            }
                            TransResult::Abort { from, err } => {
                                log::warn!(
                                    "Aborted transition from the {} state: {}",
                                    "Unregistered",
                                    err
                                );
                                ClientStateManual::Unregistered(from)
                            }
                            TransResult::Fatal { err } => {
                                return Err(err);
                            }
                        }
                    }
                    state => state,
                };
                *guard = Some(new_state);
            }
            Ok(self.new_type())
        }
    }

    // Spawns the client's actor and returns a handle in the initial typestate.
    //
    // Actor loop behavior:
    // - It locks the shared state for each envelope and handles `Completed` while `Registered`.
    // - Other messages are logged and leave the state unchanged.
    // - It exits once the state reaches `End`.
    // - A non-`Send` envelope ends the actor with an error. The shared state is not restored
    //   first, so it is left `None`.
    // - A `Fatal` from `handle_completed` panics with an `ActorError` payload. The server and
    //   worker below panic with a formatted `String` instead.
    async fn spawn_client_manual(
        init: Init,
        runtime: &'static Runtime,
    ) -> ClientHandleManual
    where
        Init: 'static + Unregistered,
    {
        let state = Arc::new(Mutex::new(Some(ClientStateManual::Unregistered(init))));
        let name = {
            let state = state.clone();
            runtime
                .spawn(None, move |mut mailbox, actor_id, _| async move {
                    while let Some(envelope) = mailbox.recv().await {
                        // The guard is held while the message is handled (including the awaited
                        // `handle_completed`), serializing with `send_register`.
                        let mut guard = state.lock().await;
                        let state = guard
                            .take()
                            .unwrap_or_else(|| panic!("Logic error. The state was not returned."));
                        let new_state = match envelope {
                            Envelope::Send { msg, .. } => {
                                match (state, msg) {
                                    (ClientStateManual::Registered(curr_state), ClientCallbackMsgs::Completed(msg)) => {
                                        match curr_state.handle_completed(msg).await {
                                            TransResult::Ok(next) => ClientStateManual::::End(next),
                                            TransResult::Abort { from, err } => {
                                                log::warn!("Aborted transition from the {} state while handling the {} message: {}", "Registered", "Completed", err);
                                                ClientStateManual::Registered(from)
                                            }
                                            TransResult::Fatal { err } => {
                                                panic_any(ActorError::new(
                                                    err,
                                                    ActorErrorPayload {
                                                        actor_id,
                                                        actor_impl: Init::actor_impl(),
                                                        state: Init::OnSendRegisterRegistered::state_name(),
                                                        message: ClientCallbackMsgKinds::Completed.name(),
                                                        kind: TransKind::Receive,
                                                    }));
                                            }
                                        }
                                    }
                                    (state, msg) => {
                                        log::error!("Unexpected message {} in state {}.", msg.name(), state.name());
                                        state
                                    }
                                }
                            }
                            envelope => return Err(ActorError::new(
                                bterr!("Unexpected envelope type: {}", envelope.name()),
                                ActorErrorPayload {
                                    actor_id,
                                    actor_impl: Init::actor_impl(),
                                    state: state.name(),
                                    message: envelope.msg_name(),
                                    kind: TransKind::Receive,
                                }))
                        };
                        *guard = Some(new_state);
                        if let Some(state) = &*guard {
                            if let ClientStateManual::End(_) = state {
                                break;
                            }
                        }
                    }
                    Ok(actor_id)
                })
                .await
                .unwrap()
        };
        ClientHandleManual {
            runtime,
            state,
            name,
            type_state: PhantomData,
        }
    }

    // Registers the `Listening` server as service `service_id`. The closure given to
    // `runtime.register` spawns a new actor running `server_loop`, initialized from
    // `make_init`, each time the runtime invokes it.
    async fn register_server_manual(
        make_init: F,
        runtime: &'static Runtime,
        service_id: ServiceId,
    ) -> Result
    where
        Init: 'static + Listening,
        F: 'static + Send + Sync + Clone + Fn() -> Init,
    {
        // The server has a single state; it never transitions to `End`.
        enum ServerState {
            Listening(Init),
        }

        impl Named for ServerState {
            fn name(&self) -> Arc {
                static LISTENING_NAME: Lazy> =
                    Lazy::new(|| Arc::new("Listening".into()));
                match self {
                    Self::Listening(_) => LISTENING_NAME.clone(),
                }
            }
        }

        // Actor body for the server. It runs until the mailbox is closed.
        // - For each `Register` sent to it, it runs `handle_register` and starts a worker owned
        //   by the sender.
        // - The worker is awaited into existence before the next envelope is read.
        // - Other messages are logged; non-`Send` envelopes end the actor with an error.
        async fn server_loop(
            runtime: &'static Runtime,
            make_init: F,
            mut mailbox: Mailbox,
            actor_id: ActorId,
        ) -> ActorResult
        where
            Init: 'static + Listening,
            F: 'static + Send + Sync + Fn() -> Init,
        {
            let mut state = ServerState::Listening(make_init());
            while let Some(envelope) = mailbox.recv().await {
                let new_state = match envelope {
                    Envelope::Send { msg, from, .. } => match (state, msg) {
                        (ServerState::Listening(curr_state), ClientCallbackMsgs::Register(msg)) => {
                            match curr_state.handle_register(msg).await {
                                TransResult::Ok((new_state, working_state)) => {
                                    // `from` is the registering client's actor name; the worker
                                    // sends its result there.
                                    start_worker_manual(working_state, from, runtime).await;
                                    ServerState::Listening(new_state)
                                }
                                TransResult::Abort { from, err } => {
                                    log::warn!("Aborted transition from the {} state while handling the {} message: {}", "Listening", "Register", err);
                                    ServerState::Listening(from)
                                }
                                TransResult::Fatal { err } => {
                                    let err = ActorError::new(
                                        err,
                                        ActorErrorPayload {
                                            actor_id,
                                            actor_impl: Init::actor_impl(),
                                            state: Init::state_name(),
                                            message: ClientCallbackMsgKinds::Register.name(),
                                            kind: TransKind::Receive,
                                        },
                                    );
                                    panic_any(format!("{err}"));
                                }
                            }
                        }
                        (state, msg) => {
                            log::error!(
                                "Unexpected message {} in state {}.",
                                msg.name(),
                                state.name()
                            );
                            state
                        }
                    },
                    envelope => {
                        return Err(ActorError::new(
                            bterr!("Unexpected envelope type: {}", envelope.name()),
                            ActorErrorPayload {
                                actor_id,
                                actor_impl: Init::actor_impl(),
                                state: state.name(),
                                message: envelope.msg_name(),
                                kind: TransKind::Receive,
                            },
                        ))
                    }
                };
                state = new_state;
            }
            Ok(actor_id)
        }

        runtime
            .register::(service_id, move |runtime: &'static Runtime| {
                let make_init = make_init.clone();
                let fut = async move {
                    // NOTE(review): this second clone looks redundant; the value cloned just
                    // above is already owned by this `async move` block.
                    let make_init = make_init.clone();
                    let actor_impl = runtime
                        .spawn(None, move |mailbox, act_id, runtime| {
                            server_loop(runtime, make_init, mailbox, act_id)
                        })
                        .await
                        .unwrap();
                    Ok(actor_impl)
                };
                Box::pin(fut)
            })
            .await
    }

    // Spawns a one-shot worker actor and returns its name.
    // - `owned` is passed as the first argument of `spawn` (presumably the worker's owner) and
    //   is also the destination of the `Completed` message.
    // - The worker runs `on_send_completed` once, sends the result to `owned`, and finishes.
    // - `Abort` and `Fatal` from the transition are both treated as fatal (the actor panics), as
    //   is a failed send.
    async fn start_worker_manual(
        init: Init,
        owned: ActorName,
        runtime: &'static Runtime,
    ) -> ActorName
    where
        Init: 'static + Working,
    {
        // NOTE(review): unused in this function; kept only as part of the sketch (the module
        // allows dead code).
        enum WorkerState {
            Working(S),
        }

        runtime
            .spawn::(
                Some(owned.clone()),
                move |_, actor_id, _| async move {
                    let msg = match init.on_send_completed().await {
                        TransResult::Ok((End, msg)) => msg,
                        TransResult::Abort { err, .. } | TransResult::Fatal { err } => {
                            let err = ActorError::new(
                                err,
                                ActorErrorPayload {
                                    actor_id,
                                    actor_impl: Init::actor_impl(),
                                    state: Init::state_name(),
                                    message: ClientCallbackMsgKinds::Completed.name(),
                                    kind: TransKind::Send,
                                },
                            );
                            panic_any(format!("{err}"))
                        }
                    };
                    let from = runtime.actor_name(actor_id);
                    let msg = ClientCallbackMsgs::Completed(msg);
                    runtime.send(owned, from, msg).await.unwrap_or_else(|err| {
                        let err = ActorError::new(
                            err,
                            ActorErrorPayload {
                                actor_id,
                                actor_impl: Init::actor_impl(),
                                state: Init::state_name(),
                                message: ClientCallbackMsgKinds::Completed.name(),
                                kind: TransKind::Send,
                            },
                        );
                        panic_any(format!("{err}"));
                    });
                    Ok(actor_id)
                },
            )
            .await
            .unwrap()
    }

    // End-to-end run: the server multiplies by 2 and the client registers with factor 21. The
    // worker's `Completed` message must reach the client and be forwarded through the oneshot
    // channel within 500 ms.
    #[test]
    fn client_callback_protocol() {
        ASYNC_RT.block_on(async {
            const SERVICE_ID: &str = "ClientCallbackProtocolListening";
            let service_id = ServiceId::from(SERVICE_ID);
            let service_name = {
                let make_init = move || ListeningState { multiple: 2 };
                register_server_manual(make_init, &RUNTIME, service_id.clone())
                    .await
                    .unwrap()
            };
            let (sender, receiver) = oneshot::channel();
            let client_handle = spawn_client_manual(UnregisteredState { sender }, &RUNTIME).await;
            let service_addr = ServiceAddr::new(service_name, false);
            client_handle
                .send_register(service_addr, Register { factor: 21 })
                .await
                .unwrap();

            let value = timeout(Duration::from_millis(500), receiver)
                .await
                .unwrap()
                .unwrap();

            assert_eq!(42, value);
        });
    }
}