|
- #![feature(impl_trait_in_assoc_type)]
- use btrun::model::*;
- use btrun::test_setup;
- use btrun::*;
- use btlib::Result;
- use btproto::protocol;
- use log;
- use once_cell::sync::Lazy;
- use serde::{Deserialize, Serialize};
- use std::{
- future::{ready, Future, Ready},
- sync::{
- atomic::{AtomicU8, Ordering},
- Arc,
- },
- };
// Test scaffolding macro imported from `btrun` (see `use btrun::test_setup;`). The tests below
// use `ASYNC_RT` and `RUNTIME`, which are not defined anywhere else in this file.
// NOTE(review): presumably this macro defines those statics — confirm against `btrun::test_setup`.
test_setup!();
- mod ping_pong {
- use super::*;
- use btlib::bterr;
- // The following code is a proof-of-concept for what types should be generated for a
- // simple ping-pong protocol:
- protocol! {
- named PingProtocol;
- let server = [Server];
- let client = [Client];
- Client -> End, >service(Server)!Ping;
- Server?Ping -> End, >Client!Ping::Reply;
- }
- //
- // In words, the protocol is described as follows.
- // 1. When the Listening state receives the Ping message it returns the End state and a
- // Ping::Reply message to be sent to the SentPing state.
- // 2. When the SentPing state receives the Ping::Reply message it returns the End state.
- //
- // The End state represents an end to the session described by the protocol. When an actor
- // transitions to the End state its function returns.
- // When a state is expecting a Reply message, an error occurs if the message is not received
- // in a timely manner.
- enum PingClientState<T: Client> {
- Client(T),
- End(End),
- }
- impl<T: Client> PingClientState<T> {
- const fn name(&self) -> &'static str {
- match self {
- Self::Client(_) => "Client",
- Self::End(_) => "End",
- }
- }
- }
- struct ClientHandleManual<T: Client> {
- state: Option<PingClientState<T>>,
- client_name: ActorName,
- runtime: &'static Runtime,
- }
- impl<T: Client> ClientHandleManual<T> {
- async fn send_ping(
- mut self,
- msg: Ping,
- service: ServiceAddr,
- ) -> TransResult<Self, (ClientHandleManual<T>, T::OnSendPingReturn)> {
- let state = if let Some(state) = self.state.take() {
- state
- } else {
- return TransResult::Abort {
- from: self,
- err: bterr!("The shared state was not returned."),
- };
- };
- match state {
- PingClientState::Client(state) => {
- let result = self
- .runtime
- .call_service(
- service,
- self.client_name.clone(),
- PingProtocolMsgs::Ping(msg),
- )
- .await;
- let reply_enum = match result {
- Ok(reply_enum) => reply_enum,
- Err(err) => {
- self.state = Some(PingClientState::Client(state));
- return TransResult::Abort { from: self, err };
- }
- };
- if let PingProtocolMsgs::PingReply(reply) = reply_enum {
- match state.on_send_ping(reply).await {
- TransResult::Ok((new_state, return_var)) => {
- self.state = Some(PingClientState::End(new_state));
- TransResult::Ok((self, return_var))
- }
- TransResult::Abort { from, err } => {
- self.state = Some(PingClientState::Client(from));
- TransResult::Abort { from: self, err }
- }
- TransResult::Fatal { err } => return TransResult::Fatal { err },
- }
- } else {
- TransResult::Abort {
- from: self,
- err: bterr!("Unexpected reply type."),
- }
- }
- }
- state => {
- let err = bterr!("Can't send Ping in state {}.", state.name());
- self.state = Some(state);
- TransResult::Abort { from: self, err }
- }
- }
- }
- }
- async fn spawn_client_manual<T: Client>(
- init: T,
- runtime: &'static Runtime,
- ) -> ClientHandleManual<T> {
- let state = Some(PingClientState::Client(init));
- let client_name = runtime.spawn(None, do_nothing_actor).await.unwrap();
- ClientHandleManual {
- state,
- client_name,
- runtime,
- }
- }
- async fn register_server_manual<Init, F>(
- make_init: F,
- rt: &'static Runtime,
- id: ServiceId,
- ) -> Result<ServiceName>
- where
- Init: 'static + Server,
- F: 'static + Send + Sync + Clone + Fn() -> Init,
- {
- enum ServerState<S> {
- Server(S),
- End(End),
- }
- impl<S> Named for ServerState<S> {
- fn name(&self) -> Arc<String> {
- static SERVER_NAME: Lazy<Arc<String>> = Lazy::new(|| Arc::new("Server".into()));
- static END_NAME: Lazy<Arc<String>> = Lazy::new(|| Arc::new("End".into()));
- match self {
- Self::Server(_) => SERVER_NAME.clone(),
- Self::End(_) => END_NAME.clone(),
- }
- }
- }
- async fn server_loop<Init, F>(
- _runtime: &'static Runtime,
- make_init: F,
- mut mailbox: Mailbox<PingProtocolMsgs>,
- actor_id: ActorId,
- ) -> ActorResult
- where
- Init: 'static + Server,
- F: 'static + Send + Sync + FnOnce() -> Init,
- {
- let mut state = ServerState::Server(make_init());
- while let Some(envelope) = mailbox.recv().await {
- state = match envelope {
- Envelope::Call {
- msg,
- reply: replier,
- ..
- } => match (state, msg) {
- (ServerState::Server(listening_state), PingProtocolMsgs::Ping(msg)) => {
- match listening_state.handle_ping(msg).await {
- TransResult::Ok((new_state, reply)) => {
- let replier = replier
- .ok_or_else(|| bterr!("Reply has already been sent."))
- .unwrap();
- if let Err(_) = replier.send(PingProtocolMsgs::PingReply(reply))
- {
- return Err(ActorError::new(
- bterr!("Failed to send Ping reply."),
- ActorErrorPayload {
- actor_id,
- actor_impl: Init::actor_impl(),
- state: Init::state_name(),
- message: PingProtocolMsgKinds::Ping.name(),
- kind: TransKind::Receive,
- },
- ));
- }
- ServerState::End(new_state)
- }
- TransResult::Abort { from, err } => {
- log::warn!("Aborted transition from the {} while handling the {} message: {}", "Server", "Ping", err);
- ServerState::Server(from)
- }
- TransResult::Fatal { err } => {
- return Err(ActorError::new(
- err,
- ActorErrorPayload {
- actor_id,
- actor_impl: Init::actor_impl(),
- state: Init::state_name(),
- message: PingProtocolMsgKinds::Ping.name(),
- kind: TransKind::Receive,
- },
- ));
- }
- }
- }
- (state, _) => state,
- },
- envelope => {
- return Err(ActorError::new(
- bterr!("Unexpected envelope type: {}", envelope.name()),
- ActorErrorPayload {
- actor_id,
- actor_impl: Init::actor_impl(),
- state: state.name(),
- message: envelope.msg_name(),
- kind: TransKind::Receive,
- },
- ))
- }
- };
- if let ServerState::End(_) = state {
- break;
- }
- }
- Ok(actor_id)
- }
- rt.register::<PingProtocolMsgs, _>(id, move |runtime| {
- let make_init = make_init.clone();
- let fut = async move {
- let actor_impl = runtime
- .spawn(None, move |mailbox, act_id, runtime| {
- server_loop(runtime, make_init, mailbox, act_id)
- })
- .await
- .unwrap();
- Ok(actor_impl)
- };
- Box::pin(fut)
- })
- .await
- }
- #[derive(Serialize, Deserialize)]
- pub struct Ping;
- impl CallMsg for Ping {
- type Reply = PingReply;
- }
- #[derive(Serialize, Deserialize)]
- pub struct PingReply;
- struct ClientImpl {
- counter: Arc<AtomicU8>,
- }
- impl ClientImpl {
- fn new(counter: Arc<AtomicU8>) -> Self {
- counter.fetch_add(1, Ordering::SeqCst);
- Self { counter }
- }
- }
- impl Client for ClientImpl {
- actor_name!("ping_client");
- type OnSendPingReturn = ();
- type OnSendPingFut = impl Future<Output = TransResult<Self, (End, ())>>;
- fn on_send_ping(self, _msg: PingReply) -> Self::OnSendPingFut {
- self.counter.fetch_sub(1, Ordering::SeqCst);
- ready(TransResult::Ok((End, ())))
- }
- }
- struct ServerState {
- counter: Arc<AtomicU8>,
- }
- impl ServerState {
- fn new(counter: Arc<AtomicU8>) -> Self {
- counter.fetch_add(1, Ordering::SeqCst);
- Self { counter }
- }
- }
- impl Server for ServerState {
- actor_name!("ping_server");
- type HandlePingFut = impl Future<Output = TransResult<Self, (End, PingReply)>>;
- fn handle_ping(self, _msg: Ping) -> Self::HandlePingFut {
- self.counter.fetch_sub(1, Ordering::SeqCst);
- ready(TransResult::Ok((End, PingReply)))
- }
- }
- #[test]
- fn ping_pong_test() {
- ASYNC_RT.block_on(async {
- const SERVICE_ID: &str = "PingPongProtocolServer";
- let service_id = ServiceId::from(SERVICE_ID);
- let counter = Arc::new(AtomicU8::new(0));
- let service_name = {
- let service_counter = counter.clone();
- let make_init = move || {
- let server_counter = service_counter.clone();
- ServerState::new(server_counter)
- };
- register_server_manual(make_init, &RUNTIME, service_id.clone())
- .await
- .unwrap()
- };
- let client_handle =
- spawn_client_manual(ClientImpl::new(counter.clone()), &RUNTIME).await;
- let service_addr = ServiceAddr::new(service_name, true);
- client_handle.send_ping(Ping, service_addr).await.unwrap();
- assert_eq!(0, counter.load(Ordering::SeqCst));
- });
- }
- }
- mod travel_agency {
- use super::*;
- // Here's another protocol example. This is the Customer and Travel Agency protocol used as an
- // example in the survey paper "Behavioral Types in Programming Languages."
- // Note that the Choosing state can send messages at any time, not just in response to another
- // message because there is a transition from Choosing that doesn't use the receive operator
- // (`?`).
- protocol! {
- named TravelAgency;
- let agency = [Listening];
- let customer = [Choosing];
- Choosing -> Choosing, >service(Listening)!Query;
- Choosing -> Choosing, >service(Listening)!Accept;
- Choosing -> Choosing, >service(Listening)!Reject;
- Listening?Query -> Listening, >Choosing!Query::Reply;
- Listening?Accept -> End, >Choosing!Accept::Reply;
- Listening?Reject -> End, >Choosing!Reject::Reply;
- }
- #[derive(Serialize, Deserialize)]
- pub struct Query;
- impl CallMsg for Query {
- type Reply = ();
- }
- #[derive(Serialize, Deserialize)]
- pub struct Reject;
- impl CallMsg for Reject {
- type Reply = ();
- }
- #[derive(Serialize, Deserialize)]
- pub struct Accept;
- impl CallMsg for Accept {
- type Reply = ();
- }
- }
- #[allow(dead_code)]
- mod client_callback {
- use super::*;
- use btlib::bterr;
- use once_cell::sync::Lazy;
- use std::{marker::PhantomData, panic::panic_any, time::Duration};
- use tokio::{sync::oneshot, time::timeout};
- #[derive(Serialize, Deserialize)]
- pub struct Register {
- factor: usize,
- }
- #[derive(Serialize, Deserialize)]
- pub struct Completed {
- value: usize,
- }
- protocol! {
- named ClientCallback;
- let server = [Listening];
- let worker = [Working];
- let client = [Unregistered, Registered];
- Unregistered -> Registered, >service(Listening)!Register[Registered];
- Listening?Register[Registered] -> Listening, Working[Registered];
- Working[Registered] -> End, >Registered!Completed;
- Registered?Completed -> End;
- }
- struct UnregisteredState {
- sender: oneshot::Sender<usize>,
- }
- impl Unregistered for UnregisteredState {
- actor_name!("callback_client");
- type OnSendRegisterReturn = ();
- type OnSendRegisterRegistered = RegisteredState;
- type OnSendRegisterFut = Ready<TransResult<Self, (Self::OnSendRegisterRegistered, ())>>;
- fn on_send_register(self) -> Self::OnSendRegisterFut {
- ready(TransResult::Ok((
- RegisteredState {
- sender: self.sender,
- },
- (),
- )))
- }
- }
- struct RegisteredState {
- sender: oneshot::Sender<usize>,
- }
- impl Registered for RegisteredState {
- type HandleCompletedFut = Ready<TransResult<Self, End>>;
- fn handle_completed(self, arg: Completed) -> Self::HandleCompletedFut {
- self.sender.send(arg.value).unwrap();
- ready(TransResult::Ok(End))
- }
- }
- struct ListeningState {
- multiple: usize,
- }
- impl Listening for ListeningState {
- actor_name!("callback_server");
- type HandleRegisterListening = ListeningState;
- type HandleRegisterWorking = WorkingState;
- type HandleRegisterFut = Ready<TransResult<Self, (ListeningState, WorkingState)>>;
- fn handle_register(self, arg: Register) -> Self::HandleRegisterFut {
- let multiple = self.multiple;
- ready(TransResult::Ok((
- self,
- WorkingState {
- factor: arg.factor,
- multiple,
- },
- )))
- }
- }
- struct WorkingState {
- factor: usize,
- multiple: usize,
- }
- impl Working for WorkingState {
- actor_name!("callback_worker");
- type OnSendCompletedFut = Ready<TransResult<Self, (End, Completed)>>;
- fn on_send_completed(self) -> Self::OnSendCompletedFut {
- let value = self.multiple * self.factor;
- ready(TransResult::Ok((End, Completed { value })))
- }
- }
- use ::tokio::sync::Mutex;
- enum ClientStateManual<Init: Unregistered> {
- Unregistered(Init),
- Registered(Init::OnSendRegisterRegistered),
- End(End),
- }
- impl<Init: Unregistered> Named for ClientStateManual<Init> {
- fn name(&self) -> Arc<String> {
- static UNREGISTERED_NAME: Lazy<Arc<String>> =
- Lazy::new(|| Arc::new("Unregistered".into()));
- static REGISTERED_NAME: Lazy<Arc<String>> = Lazy::new(|| Arc::new("Registered".into()));
- static END_NAME: Lazy<Arc<String>> = Lazy::new(|| Arc::new("End".into()));
- match self {
- Self::Unregistered(_) => UNREGISTERED_NAME.clone(),
- Self::Registered(_) => REGISTERED_NAME.clone(),
- Self::End(_) => END_NAME.clone(),
- }
- }
- }
- struct ClientHandleManual<Init: Unregistered, State> {
- runtime: &'static Runtime,
- state: Arc<Mutex<Option<ClientStateManual<Init>>>>,
- name: ActorName,
- type_state: PhantomData<State>,
- }
- impl<Init: Unregistered, State> ClientHandleManual<Init, State> {
- fn new_type<NewState>(self) -> ClientHandleManual<Init, NewState> {
- ClientHandleManual {
- runtime: self.runtime,
- state: self.state,
- name: self.name,
- type_state: PhantomData,
- }
- }
- }
- impl<
- Init: Unregistered,
- State: Unregistered<OnSendRegisterRegistered = NewState>,
- NewState: Registered,
- > ClientHandleManual<Init, State>
- {
- async fn send_register(
- self,
- to: ServiceAddr,
- msg: Register,
- ) -> TransResult<
- Self,
- (
- ClientHandleManual<Init, NewState>,
- Init::OnSendRegisterReturn,
- ),
- > {
- let mut guard = self.state.lock().await;
- let state = guard
- .take()
- .unwrap_or_else(|| panic!("Logic error. The state was not returned."));
- match state {
- ClientStateManual::Unregistered(state) => match state.on_send_register().await {
- TransResult::Ok((new_state, return_var)) => {
- let msg = ClientCallbackMsgs::Register(msg);
- let result = self.runtime.send_service(to, self.name.clone(), msg).await;
- if let Err(err) = result {
- return TransResult::Fatal { err };
- }
- *guard = Some(ClientStateManual::Registered(new_state));
- drop(guard);
- TransResult::Ok((self.new_type(), return_var))
- }
- TransResult::Abort { from, err } => {
- *guard = Some(ClientStateManual::Unregistered(from));
- drop(guard);
- return TransResult::Abort { from: self, err };
- }
- TransResult::Fatal { err } => {
- return TransResult::Fatal { err };
- }
- },
- state => {
- let name = state.name();
- *guard = Some(state);
- drop(guard);
- TransResult::Abort {
- from: self,
- err: bterr!(
- "Unexpected state '{}' for '{}' method.",
- name,
- "send_register"
- ),
- }
- }
- }
- }
- }
- async fn spawn_client_manual<Init>(
- init: Init,
- runtime: &'static Runtime,
- ) -> ClientHandleManual<Init, Init>
- where
- Init: 'static + Unregistered,
- {
- let state = Arc::new(Mutex::new(Some(ClientStateManual::Unregistered(init))));
- let name = {
- let state = state.clone();
- runtime.spawn(None, move |mut mailbox, actor_id, _| async move {
- while let Some(envelope) = mailbox.recv().await {
- let mut guard = state.lock().await;
- let state = guard.take()
- .unwrap_or_else(|| panic!("Logic error. The state was not returned."));
- let new_state = match envelope {
- Envelope::Send { msg, .. } => {
- match (state, msg) {
- (ClientStateManual::Registered(curr_state), ClientCallbackMsgs::Completed(msg)) => {
- match curr_state.handle_completed(msg).await {
- TransResult::Ok(next) => ClientStateManual::<Init>::End(next),
- TransResult::Abort { from, err } => {
- log::warn!("Aborted transition from the {} state while handling the {} message: {}", "Registered", "Completed", err);
- ClientStateManual::Registered(from)
- }
- TransResult::Fatal { err } => {
- panic_any(ActorError::new(
- err,
- ActorErrorPayload {
- actor_id,
- actor_impl: Init::actor_impl(),
- state: Init::OnSendRegisterRegistered::state_name(),
- message: ClientCallbackMsgKinds::Completed.name(),
- kind: TransKind::Receive,
- }));
- }
- }
- }
- (state, msg) => {
- log::error!("Unexpected message {} in state {}.", msg.name(), state.name());
- state
- }
- }
- }
- envelope => return Err(ActorError::new(
- bterr!("Unexpected envelope type: {}", envelope.name()),
- ActorErrorPayload {
- actor_id,
- actor_impl: Init::actor_impl(),
- state: state.name(),
- message: envelope.msg_name(),
- kind: TransKind::Receive,
- }))
- };
- *guard = Some(new_state);
- if let Some(state) = &*guard {
- if let ClientStateManual::End(_) = state {
- break;
- }
- }
- }
- Ok(actor_id)
- }).await.unwrap()
- };
- ClientHandleManual {
- runtime,
- state,
- name,
- type_state: PhantomData,
- }
- }
- async fn register_server_manual<Init, F>(
- make_init: F,
- runtime: &'static Runtime,
- service_id: ServiceId,
- ) -> Result<ServiceName>
- where
- Init: 'static + Listening<HandleRegisterListening = Init>,
- F: 'static + Send + Sync + Clone + Fn() -> Init,
- {
- enum ServerState<Init: Listening> {
- Listening(Init),
- }
- impl<S: Listening> Named for ServerState<S> {
- fn name(&self) -> Arc<String> {
- static LISTENING_NAME: Lazy<Arc<String>> =
- Lazy::new(|| Arc::new("Listening".into()));
- match self {
- Self::Listening(_) => LISTENING_NAME.clone(),
- }
- }
- }
- async fn server_loop<Init, F>(
- runtime: &'static Runtime,
- make_init: F,
- mut mailbox: Mailbox<ClientCallbackMsgs>,
- actor_id: ActorId,
- ) -> ActorResult
- where
- Init: 'static + Listening<HandleRegisterListening = Init>,
- F: 'static + Send + Sync + Fn() -> Init,
- {
- let mut state = ServerState::Listening(make_init());
- while let Some(envelope) = mailbox.recv().await {
- let new_state = match envelope {
- Envelope::Send { msg, from, .. } => match (state, msg) {
- (ServerState::Listening(curr_state), ClientCallbackMsgs::Register(msg)) => {
- match curr_state.handle_register(msg).await {
- TransResult::Ok((new_state, working_state)) => {
- start_worker_manual(working_state, from, runtime).await;
- ServerState::Listening(new_state)
- }
- TransResult::Abort { from, err } => {
- log::warn!("Aborted transition from the {} state while handling the {} message: {}", "Listening", "Register", err);
- ServerState::Listening(from)
- }
- TransResult::Fatal { err } => {
- let err = ActorError::new(
- err,
- ActorErrorPayload {
- actor_id,
- actor_impl: Init::actor_impl(),
- state: Init::state_name(),
- message: ClientCallbackMsgKinds::Register.name(),
- kind: TransKind::Receive,
- },
- );
- panic_any(format!("{err}"));
- }
- }
- }
- (state, msg) => {
- log::error!(
- "Unexpected message {} in state {}.",
- msg.name(),
- state.name()
- );
- state
- }
- },
- envelope => {
- return Err(ActorError::new(
- bterr!("Unexpected envelope type: {}", envelope.name()),
- ActorErrorPayload {
- actor_id,
- actor_impl: Init::actor_impl(),
- state: state.name(),
- message: envelope.msg_name(),
- kind: TransKind::Receive,
- },
- ))
- }
- };
- state = new_state;
- }
- Ok(actor_id)
- }
- runtime
- .register::<ClientCallbackMsgs, _>(service_id, move |runtime: &'static Runtime| {
- let make_init = make_init.clone();
- let fut = async move {
- let make_init = make_init.clone();
- let actor_impl = runtime
- .spawn(None, move |mailbox, act_id, runtime| {
- server_loop(runtime, make_init, mailbox, act_id)
- })
- .await
- .unwrap();
- Ok(actor_impl)
- };
- Box::pin(fut)
- })
- .await
- }
- async fn start_worker_manual<Init>(
- init: Init,
- owned: ActorName,
- runtime: &'static Runtime,
- ) -> ActorName
- where
- Init: 'static + Working,
- {
- enum WorkerState<S: Working> {
- Working(S),
- }
- runtime
- .spawn::<ClientCallbackMsgs, _, _>(
- Some(owned.clone()),
- move |_, actor_id, _| async move {
- let msg = match init.on_send_completed().await {
- TransResult::Ok((End, msg)) => msg,
- TransResult::Abort { err, .. } | TransResult::Fatal { err } => {
- let err = ActorError::new(
- err,
- ActorErrorPayload {
- actor_id,
- actor_impl: Init::actor_impl(),
- state: Init::state_name(),
- message: ClientCallbackMsgKinds::Completed.name(),
- kind: TransKind::Send,
- },
- );
- panic_any(format!("{err}"))
- }
- };
- let from = runtime.actor_name(actor_id);
- let msg = ClientCallbackMsgs::Completed(msg);
- runtime.send(owned, from, msg).await.unwrap_or_else(|err| {
- let err = ActorError::new(
- err,
- ActorErrorPayload {
- actor_id,
- actor_impl: Init::actor_impl(),
- state: Init::state_name(),
- message: ClientCallbackMsgKinds::Completed.name(),
- kind: TransKind::Send,
- },
- );
- panic_any(format!("{err}"));
- });
- Ok(actor_id)
- },
- )
- .await
- .unwrap()
- }
- #[test]
- fn client_callback_protocol() {
- ASYNC_RT.block_on(async {
- const SERVICE_ID: &str = "ClientCallbackProtocolListening";
- let service_id = ServiceId::from(SERVICE_ID);
- let service_name = {
- let make_init = move || ListeningState { multiple: 2 };
- register_server_manual(make_init, &RUNTIME, service_id.clone())
- .await
- .unwrap()
- };
- let (sender, receiver) = oneshot::channel();
- let client_handle = spawn_client_manual(UnregisteredState { sender }, &RUNTIME).await;
- let service_addr = ServiceAddr::new(service_name, false);
- client_handle
- .send_register(service_addr, Register { factor: 21 })
- .await
- .unwrap();
- let value = timeout(Duration::from_millis(500), receiver)
- .await
- .unwrap()
- .unwrap();
- assert_eq!(42, value);
- });
- }
- }
|