use std::{collections::HashSet, hash::Hash}; use proc_macro2::{Ident, Span}; use btrun::End; use crate::{ error::{self, MaybeErr}, model::{MsgInfo, ProtocolModel}, parsing::{DestinationState, GetSpan, State}, }; impl ProtocolModel { #[allow(dead_code)] pub(crate) fn validate(&self) -> syn::Result<()> { self.all_states_declared_and_used() .combine(self.receivers_and_senders_matched()) .combine(self.no_undeliverable_msgs()) .combine(self.replies_expected()) .combine(self.no_unobservable_states()) .into() } /// Verifies that every state which is declared is actually used. fn all_states_declared_and_used(&self) -> MaybeErr { let end = Ident::new(End::ident(), Span::call_site()); let mut declared: HashSet<&Ident> = HashSet::new(); declared.insert(&end); for actor_def in self.def().actor_defs.iter() { for state in actor_def.states.as_ref().iter() { declared.insert(state); } } let mut used: HashSet<&Ident> = HashSet::with_capacity(declared.len()); for transition in self.def().transitions.iter() { let in_state = &transition.in_state; used.insert(&in_state.state_trait); used.extend(in_state.owned_states().map(|ident| ident.as_ref())); if let Some(in_msg) = transition.in_msg() { used.extend(in_msg.owned_states().map(|ident| ident.as_ref())); } for out_states in transition.out_states.as_ref().iter() { used.insert(&out_states.state_trait); used.extend(out_states.owned_states().map(|ident| ident.as_ref())); } // We don't have to check the states referred to in out_msgs because the // receivers_and_senders_matched method ensures that each of these exists in a receiver // position. } let undeclared: MaybeErr = used .difference(&declared) .map(|ident| syn::Error::new(ident.span(), error::msgs::UNDECLARED_STATE)) .collect(); let unused: MaybeErr = declared .difference(&used) .filter(|ident| **ident != End::ident()) .map(|ident| syn::Error::new(ident.span(), error::msgs::UNUSED_STATE)) .collect(); undeclared.combine(unused) } /// Ensures that the recipient state for every sent message has a receiving transition /// defined, and every receiver has a sender. Note that each message isn't required to have a /// unique sender or a unique receiver, just that at least one of each much be defined. fn receivers_and_senders_matched<'s>(&'s self) -> MaybeErr { /// Represents a message sender or receiver. /// /// This type is essentially just a tuple of references, but was created so a [Hash] /// implementation could be defined. #[cfg_attr(test, derive(Debug))] struct MsgEndpoint<'a> { state: &'a State, msg_info: &'a MsgInfo, } impl<'a> MsgEndpoint<'a> { fn new(state: &'a State, msg_info: &'a MsgInfo) -> Self { Self { state, msg_info } } } impl<'a> PartialEq for MsgEndpoint<'a> { fn eq(&self, other: &Self) -> bool { self.state.state_trait == other.state.state_trait && self.msg_info.def().msg_type == self.msg_info.def().msg_type && self.msg_info.is_reply() == self.msg_info.is_reply() } } impl<'a> Eq for MsgEndpoint<'a> {} impl<'a> Hash for MsgEndpoint<'a> { fn hash(&self, state: &mut H) { self.state.state_trait.hash(state); self.msg_info.def().msg_type.hash(state); self.msg_info.is_reply().hash(state); } } #[cfg(test)] impl<'a> std::fmt::Display for MsgEndpoint<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "({}, {})", self.state.state_trait, self.msg_info.msg_name() ) } } let msgs = self.msg_lookup(); let mut outgoing: HashSet> = HashSet::new(); let mut incoming: HashSet> = HashSet::new(); for actor in self.actors_iter() { let methods = actor .states() .values() .flat_map(|state| state.methods().values()); for method in methods { let transition = method.def(); if let Some(msg) = transition.in_msg() { let msg_info = msgs.lookup(msg); incoming.insert(MsgEndpoint::new(&transition.in_state, msg_info)); } for dest in transition.out_msgs.as_ref().iter() { let dest_state = match &dest.state { DestinationState::Individual(dest_state) => dest_state, DestinationState::Service(dest_state) => dest_state, }; let msg_info = self.msg_lookup().lookup(&dest.msg); outgoing.insert(MsgEndpoint::new(dest_state, msg_info)); if actor.is_client() { if let Some(reply) = msg_info.reply() { incoming.insert(MsgEndpoint::new(&transition.in_state, reply)); } } } } } let extra_senders: MaybeErr = outgoing .difference(&incoming) .map(|endpoint| { syn::Error::new( endpoint.msg_info.def().span(), error::msgs::UNMATCHED_OUTGOING, ) }) .collect(); let extra_receivers: MaybeErr = incoming .difference(&outgoing) .map(|endpoint| { syn::Error::new( endpoint.msg_info.def().span(), error::msgs::UNMATCHED_INCOMING, ) }) .collect(); extra_senders.combine(extra_receivers) } /// Checks that messages are only sent to destinations which are either services, states /// which are owned by the sender, listed in the output states, or that the message is a /// reply. fn no_undeliverable_msgs(&self) -> MaybeErr { let mut err = MaybeErr::none(); for transition in self.def().transitions.iter() { let mut allowed_states: Option> = None; for dest in transition.out_msgs.as_ref().iter() { if dest.msg.is_reply() { continue; } match &dest.state { DestinationState::Service(_) => continue, DestinationState::Individual(dest_state) => { let owned_states = transition .in_state .owned_states() .map(|ident| ident.as_ref()); let allowed = allowed_states.get_or_insert_with(|| { transition .out_states .as_ref() .iter() .map(|state| state.state_trait.as_ref()) .chain(owned_states) .collect() }); if !allowed.contains(dest_state.state_trait.as_ref()) { err = err.combine( syn::Error::new( dest_state.state_trait.span(), error::msgs::UNDELIVERABLE, ) .into(), ); } } } } } err } /// Verifies that exactly one reply is sent in response to a previously sent message. fn replies_expected(&self) -> MaybeErr { let mut err = MaybeErr::none(); for transition in self.def().transitions.iter() { let replies: Vec<_> = transition .out_msgs .as_ref() .iter() .map(|dest| &dest.msg) .filter(|msg| msg.is_reply()) .collect(); if replies.is_empty() { continue; } if replies.len() > 1 { err = err.combine( replies .iter() .map(|reply| { syn::Error::new(reply.msg_type.span(), error::msgs::MULTIPLE_REPLIES) }) .collect(), ); } if transition.in_msg().is_none() { err = err.combine( replies .iter() .map(|reply| { syn::Error::new(reply.msg_type.span(), error::msgs::INVALID_REPLY) }) .collect(), ); } } err } /// Checks that there are no client states which are only receiving replies. Such states can't /// be observed, because the methods which sent the original messages will return their replies. fn no_unobservable_states(&self) -> MaybeErr { self.actors_iter() .filter(|actor| actor.is_client()) .flat_map(|actor| actor.states().values()) .filter(|state| { state.methods().values().all(|method| { if let Some(in_msg) = method.def().in_msg() { in_msg.is_reply() } else { false } }) }) .map(|state| syn::Error::new(state.span(), error::msgs::UNOBSERVABLE_STATE)) .collect() } } #[cfg(test)] mod tests { use super::*; use crate::{ error::{assert_err, assert_ok}, parsing::{ActorDef, Dest, Message, NameDef, Protocol, Transition}, }; #[test] fn all_states_declared_and_used_ok() { let input = ProtocolModel::new(Protocol::minimal()).unwrap(); let result = input.all_states_declared_and_used(); assert_ok(result); } #[test] fn all_states_declared_and_used_end_not_used_ok() { const STATE_NAME: &str = "Init"; let input = ProtocolModel::new(Protocol::new( NameDef::new("Test"), [ActorDef::new("actor", [STATE_NAME])], [Transition::new( State::new(STATE_NAME, []), Some(Message::new("Activate", false, [])), [State::new(STATE_NAME, [])], [], )], )) .unwrap(); let result = input.all_states_declared_and_used(); assert_ok(result); } #[test] fn all_states_declared_and_used_undeclared_err() { let input = ProtocolModel::new(Protocol::new( NameDef::new("Undeclared"), [ActorDef::new("actor", ["Init"])], [Transition::new( State::new("Init", []), Some(Message::new("Activate", false, [])), [State::new("Next", [])], [], )], )) .unwrap(); let result = input.all_states_declared_and_used(); assert_err(result, error::msgs::UNDECLARED_STATE); } #[test] fn all_states_declared_and_used_undeclared_out_state_owned_err() { let input = ProtocolModel::new(Protocol::new( NameDef::new("Undeclared"), [ActorDef::new("actor", ["Init", "Next"])], [Transition::new( State::new("Init", []), Some(Message::new("Activate", false, [])), [State::new("Init", []), State::new("Next", ["Undeclared"])], [], )], )) .unwrap(); let result = input.all_states_declared_and_used(); assert_err(result, error::msgs::UNDECLARED_STATE); } #[test] fn all_states_declared_and_used_undeclared_in_state_owned_err() { let input = ProtocolModel::new(Protocol::new( NameDef::new("Undeclared"), [ActorDef::new("actor", ["Init", "Next"])], [Transition::new( State::new("Init", ["Undeclared"]), Some(Message::new("Activate", false, [])), [State::new("Next", [])], [], )], )) .unwrap(); let result = input.all_states_declared_and_used(); assert_err(result, error::msgs::UNDECLARED_STATE); } #[test] fn all_states_declared_and_used_unused_err() { let input = ProtocolModel::new(Protocol::new( NameDef::new("Unused"), [ActorDef::new("actor", ["Init", "Extra"])], [Transition::new( State::new("Init", []), Some(Message::new("Activate", false, [])), [State::new("End", [])], [], )], )) .unwrap(); let result = input.all_states_declared_and_used(); assert_err(result, error::msgs::UNUSED_STATE); } #[test] fn receivers_and_senders_matched_ok() { let input = ProtocolModel::new(Protocol::minimal()).unwrap(); let result = input.receivers_and_senders_matched(); assert_ok(result); } #[test] fn receivers_and_senders_call_msg_reused_for_non_call_ok() { let input = ProtocolModel::new(Protocol::new( NameDef::new("OwnedTypes"), [ ActorDef::new("server", ["Listening"]), ActorDef::new("client", ["Client"]), ActorDef::new("file", ["FileInit", "Opened"]), ActorDef::new("file_handle", ["FileHandle"]), ], [ Transition::new( State::new("Client", []), None, [ State::new("Client", []), State::new("FileHandle", ["Opened"]), ], [Dest::new( DestinationState::Service(State::new("Listening", [])), Message::new("Open", false, []), )], ), Transition::new( State::new("Listening", []), Some(Message::new("Open", false, [])), [State::new("Listening", []), State::new("FileInit", [])], [ Dest::new( DestinationState::Individual(State::new("Client", [])), Message::new("Open", true, ["Opened"]), ), // Note that the same "Open" message is being used here, but that the // FileInit state does not send a reply. This should be allowed. Dest::new( DestinationState::Individual(State::new("FileInit", [])), Message::new("Open", false, []), ), ], ), Transition::new( State::new("FileInit", []), Some(Message::new("Open", false, [])), [State::new("Opened", [])], [], ), ], )) .unwrap(); let result = input.receivers_and_senders_matched(); assert_ok(result); } #[test] fn receivers_and_senders_matched_unmatched_sender_err() { let input = ProtocolModel::new(Protocol::new( NameDef::new("Unbalanced"), [ActorDef::new("actor", ["Init"])], [Transition::new( State::new("Init", []), None, [State::new("Init", [])], [Dest::new( DestinationState::Service(State::new("Init", [])), Message::new("Msg", false, []), )], )], )) .unwrap(); let result = input.receivers_and_senders_matched(); assert_err(result, error::msgs::UNMATCHED_OUTGOING); } #[test] fn receivers_and_senders_matched_unmatched_receiver_err() { let input = ProtocolModel::new(Protocol::new( NameDef::new("Unbalanced"), [ActorDef::new("actor", ["Init"])], [Transition::new( State::new("Init", []), Some(Message::new("NotExists", false, [])), [State::new("Init", [])], [], )], )) .unwrap(); let result = input.receivers_and_senders_matched(); assert_err(result, error::msgs::UNMATCHED_INCOMING); } #[test] fn receivers_and_senders_matched_servers_must_explicitly_receive_replies_err() { // Only client actors are allowed to implicitly receive replies. let input = ProtocolModel::new(Protocol::new( NameDef::new("Conversation"), [ ActorDef::new("alice", ["Alice"]), ActorDef::new("bob", ["Bob"]), ], [ Transition::new( State::new("Alice", []), None, [State::new("Alice", [])], [Dest::new( DestinationState::Service(State::new("Bob", [])), Message::new("Greeting", false, []), )], ), // Notice that because Bob only has transitions which handle messages, bob is a // server actor. Transition::new( State::new("Bob", []), Some(Message::new("Greeting", false, [])), [State::new("Bob", [])], [Dest::new( DestinationState::Individual(State::new("Alice", [])), Message::new("Query", false, []), )], ), // Alice is sending a Query::Reply to Bob, but because he does not have a // transition which accepts that message type, this will be an unmatched outgoing // error. Transition::new( State::new("Alice", []), Some(Message::new("Query", false, [])), [State::new("End", [])], [Dest::new( DestinationState::Individual(State::new("Bob", [])), Message::new("Query", true, []), )], ), ], )) .unwrap(); let result = input.receivers_and_senders_matched(); assert_err(result, error::msgs::UNMATCHED_OUTGOING); } #[test] fn no_undeliverable_msgs_ok() { let input = ProtocolModel::new(Protocol::minimal()).unwrap(); let result = input.no_undeliverable_msgs(); assert_ok(result); } #[test] fn no_undeliverable_msgs_reply_ok() { let input = ProtocolModel::new(Protocol::new( NameDef::new("Undeliverable"), [ActorDef::new("actor", ["Listening", "Client"])], [Transition::new( State::new("Listening", []), Some(Message::new("Msg", false, [])), [State::new("Listening", [])], [Dest::new( DestinationState::Individual(State::new("Client", [])), Message::new("Msg", true, []), )], )], )) .unwrap(); let result = input.no_undeliverable_msgs(); assert_ok(result); } #[test] fn no_undeliverable_msgs_service_ok() { let input = ProtocolModel::new(Protocol::new( NameDef::new("Undeliverable"), [ActorDef::new("actor", ["Client", "Server"])], [Transition::new( State::new("Client", []), None, [State::new("Client", [])], [Dest::new( DestinationState::Service(State::new("Server", [])), Message::new("Msg", false, []), )], )], )) .unwrap(); let result = input.no_undeliverable_msgs(); assert_ok(result); } #[test] fn no_undeliverable_msgs_owned_ok() { let input = ProtocolModel::new(Protocol::new( NameDef::new("Undeliverable"), [ActorDef::new("actor", ["FileClient", "FileHandle"])], [Transition::new( State::new("FileClient", ["FileHandle"]), None, [State::new("FileClient", [])], [Dest::new( DestinationState::Individual(State::new("FileHandle", [])), Message::new("FileOp", false, []), )], )], )) .unwrap(); let result = input.no_undeliverable_msgs(); assert_ok(result); } #[test] fn no_undeliverable_msgs_err() { let input = ProtocolModel::new(Protocol::new( NameDef::new("Undeliverable"), [ActorDef::new("actor", ["Client", "Server"])], [Transition::new( State::new("Client", []), None, [State::new("Client", [])], [Dest::new( DestinationState::Individual(State::new("Server", [])), Message::new("Msg", false, []), )], )], )) .unwrap(); let result = input.no_undeliverable_msgs(); assert_err(result, error::msgs::UNDELIVERABLE); } #[test] fn replies_expected_ok() { let input = ProtocolModel::new(Protocol::new( NameDef::new("ValidReplies"), [ActorDef::new("actor", ["Client", "Server"])], [Transition::new( State::new("Server", []), Some(Message::new("Msg", false, [])), [State::new("Server", [])], [Dest::new( DestinationState::Individual(State::new("Client", [])), Message::new("Msg", true, []), )], )], )) .unwrap(); let result = input.replies_expected(); assert_ok(result); } #[test] fn replies_expected_invalid_reply_err() { let input = ProtocolModel::new(Protocol::new( NameDef::new("ValidReplies"), [ActorDef::new("actor", ["Client", "Server"])], [Transition::new( State::new("Client", []), None, [State::new("Client", [])], [Dest::new( DestinationState::Individual(State::new("Server", [])), Message::new("Msg", true, []), )], )], )) .unwrap(); let result = input.replies_expected(); assert_err(result, error::msgs::INVALID_REPLY); } #[test] fn replies_expected_multiple_replies_err() { let input = ProtocolModel::new(Protocol::new( NameDef::new("ValidReplies"), [ActorDef::new("actor", ["Client", "OtherClient", "Server"])], [Transition::new( State::new("Server", []), Some(Message::new("Msg", false, [])), [State::new("Server", [])], [ Dest::new( DestinationState::Individual(State::new("Client", [])), Message::new("Msg", true, []), ), Dest::new( DestinationState::Individual(State::new("OtherClient", [])), Message::new("Msg", true, []), ), ], )], )) .unwrap(); let result = input.replies_expected(); assert_err(result, error::msgs::MULTIPLE_REPLIES); } }