use std::collections::HashSet; use proc_macro2::{Ident, Span}; use crate::{ error::MaybeErr, parsing::{DestinationState, Message, State}, Protocol, }; impl Protocol { pub(crate) fn validate(&self) -> syn::Result<()> { self.all_states_declared_and_used() .combine(self.match_receivers_and_senders()) .combine(self.no_undeliverable_msgs()) .combine(self.valid_replies()) .into() } const UNDECLARED_STATE_ERR: &str = "State was not declared."; const UNUSED_STATE_ERR: &str = "State was declared but never used."; const END_STATE: &str = "End"; /// Verifies that every state which is used has been declared, except for the End state. fn all_states_declared_and_used(&self) -> MaybeErr { let end = Ident::new(Self::END_STATE, Span::call_site()); let declared: HashSet<&Ident> = self .states_def .states .as_ref() .iter() .chain([&end].into_iter()) .collect(); let mut used: HashSet<&Ident> = HashSet::with_capacity(declared.len()); for transition in self.transitions.iter() { let in_state = &transition.in_state; used.insert(&in_state.state_trait); used.extend(in_state.owned_states.as_ref().iter()); if let Some(in_msg) = &transition.in_msg { used.extend(in_msg.owned_states.as_ref().iter()); } for out_states in transition.out_states.as_ref().iter() { used.insert(&out_states.state_trait); used.extend(out_states.owned_states.as_ref().iter()); } // We don't have to check the states referred to in out_msgs because the // match_receivers_and_senders method ensures that each of these exists in a receiver // position. } let undeclared: MaybeErr = used .difference(&declared) .map(|ident| syn::Error::new(ident.span(), Self::UNDECLARED_STATE_ERR)) .collect(); let unused: MaybeErr = declared .difference(&used) .filter(|ident| **ident != Self::END_STATE) .map(|ident| syn::Error::new(ident.span(), Self::UNUSED_STATE_ERR)) .collect(); undeclared.combine(unused) } const UNMATCHED_SENDER_ERR: &str = "No receiver found for message type."; const UNMATCHED_RECEIVER_ERR: &str = "No sender found for message type."; const ACTIVATE_MSG: &str = "Activate"; /// Ensures that the recipient state for every sent message has a receiving transition /// defined, and every receiver has a sender (except for the Activate message which is sent /// by the runtime). fn match_receivers_and_senders(&self) -> MaybeErr { let mut senders: HashSet<(&State, &Message)> = HashSet::new(); let mut receivers: HashSet<(&State, &Message)> = HashSet::new(); for transition in self.transitions.iter() { if let Some(msg) = &transition.in_msg { receivers.insert((&transition.in_state, msg)); if msg.msg_type == Self::ACTIVATE_MSG { // The Activate message is sent by the run time, so a sender is created to // represent it. senders.insert((&transition.in_state, msg)); } } for dest in transition.out_msgs.as_ref().iter() { let dest_state = match &dest.state { DestinationState::Individual(dest_state) => dest_state, DestinationState::Service(dest_state) => dest_state, }; senders.insert((dest_state, &dest.msg)); } } let extra_senders: MaybeErr = senders .difference(&receivers) .map(|pair| syn::Error::new(pair.1.msg_type.span(), Self::UNMATCHED_SENDER_ERR)) .collect(); let extra_receivers: MaybeErr = receivers .difference(&senders) .map(|pair| syn::Error::new(pair.1.msg_type.span(), Self::UNMATCHED_RECEIVER_ERR)) .collect(); extra_senders.combine(extra_receivers) } const UNDELIVERABLE_ERR: &str = "Receiver must either be a service, an owned state, or an out state, or the message must be a reply."; /// Checks that messages are only sent to destinations which are either services, states /// which are owned by the sender, listed in the output states, or that the message is a /// reply. fn no_undeliverable_msgs(&self) -> MaybeErr { let mut err = MaybeErr::none(); for transition in self.transitions.iter() { let mut allowed_states: Option> = None; for dest in transition.out_msgs.as_ref().iter() { if dest.msg.is_reply { continue; } match &dest.state { DestinationState::Service(_) => continue, DestinationState::Individual(dest_state) => { let allowed = allowed_states.get_or_insert_with(|| { transition .out_states .as_ref() .iter() .map(|state| &state.state_trait) .chain(transition.in_state.owned_states.as_ref().iter()) .collect() }); if !allowed.contains(&dest_state.state_trait) { err = err.combine( syn::Error::new( dest_state.state_trait.span(), Self::UNDELIVERABLE_ERR, ) .into(), ); } } } } } err } const INVALID_REPLY_ERR: &str = "Replies can only be used in transitions which handle messages."; const MULTIPLE_REPLIES_ERR: &str = "Only a single reply can be sent in response to any message."; /// Verifies that replies are only sent in response to messages. fn valid_replies(&self) -> MaybeErr { let mut err = MaybeErr::none(); for transition in self.transitions.iter() { let replies: Vec<_> = transition .out_msgs .as_ref() .iter() .map(|dest| &dest.msg) .filter(|msg| msg.is_reply) .collect(); if replies.is_empty() { continue; } if replies.len() > 1 { err = err.combine( replies .iter() .map(|reply| { syn::Error::new(reply.msg_type.span(), Self::MULTIPLE_REPLIES_ERR) }) .collect(), ); } if transition.in_msg.is_none() { err = err.combine( replies .iter() .map(|reply| { syn::Error::new(reply.msg_type.span(), Self::INVALID_REPLY_ERR) }) .collect(), ); } } err } } #[cfg(test)] mod tests { use super::*; use syn::parse_str; macro_rules! assert_ok { ($maybe_err:expr) => { let result: syn::Result<()> = $maybe_err.into(); assert!(result.is_ok(), "{}", result.err().unwrap()); }; } macro_rules! assert_err { ($maybe_err:expr, $expected_msg:expr) => { let result: syn::Result<()> = $maybe_err.into(); assert!(result.is_err()); assert_eq!($expected_msg, result.err().unwrap().to_string()); }; } /// A minimal valid protocol definition. const MIN_PROTOCOL: &str = " let name = Test; let states = [Init]; Init?Activate -> End; "; #[test] fn all_states_declared_and_used_ok() { let result = parse_str::(MIN_PROTOCOL) .unwrap() .all_states_declared_and_used(); assert_ok!(result); } #[test] fn all_states_declared_and_used_end_not_used_ok() { const INPUT: &str = " let name = Test; let states = [Init]; Init?Activate -> Init; "; let result = parse_str::(INPUT) .unwrap() .all_states_declared_and_used(); assert_ok!(result); } #[test] fn all_states_declared_and_used_undeclared_err() { const INPUT: &str = " let name = Undeclared; let states = [Init]; Init?Activate -> Next; "; let result = parse_str::(INPUT) .unwrap() .all_states_declared_and_used(); assert_err!(result, Protocol::UNDECLARED_STATE_ERR); } #[test] fn all_states_declared_and_used_undeclared_out_state_owned_err() { const INPUT: &str = " let name = Undeclared; let states = [Init, Next]; Init?Activate -> Init, Next[Undeclared]; "; let result = parse_str::(INPUT) .unwrap() .all_states_declared_and_used(); assert_err!(result, Protocol::UNDECLARED_STATE_ERR); } #[test] fn all_states_declared_and_used_undeclared_in_state_owned_err() { const INPUT: &str = " let name = Undeclared; let states = [Init, Next]; Init[Undeclared]?Activate -> Next; "; let result = parse_str::(INPUT) .unwrap() .all_states_declared_and_used(); assert_err!(result, Protocol::UNDECLARED_STATE_ERR); } #[test] fn all_states_declared_and_used_unused_err() { const INPUT: &str = " let name = Unused; let states = [Init, Extra]; Init?Activate -> End; "; let result = parse_str::(INPUT) .unwrap() .all_states_declared_and_used(); assert_err!(result, Protocol::UNUSED_STATE_ERR); } #[test] fn match_receivers_and_senders_ok() { let result = parse_str::(MIN_PROTOCOL) .unwrap() .match_receivers_and_senders(); assert_ok!(result); } #[test] fn match_receivers_and_senders_send_activate_ok() { const INPUT: &str = " let name = Unbalanced; let states = [First, Second]; First?Activate -> First, >Second!Activate; Second?Activate -> Second; "; let result = parse_str::(INPUT) .unwrap() .match_receivers_and_senders(); assert_ok!(result); } #[test] fn match_receivers_and_senders_unmatched_sender_err() { const INPUT: &str = " let name = Unbalanced; let states = [Init, Other]; Init?Activate -> Init, >Other!Activate; "; let result = parse_str::(INPUT) .unwrap() .match_receivers_and_senders(); assert_err!(result, Protocol::UNMATCHED_SENDER_ERR); } #[test] fn match_receivers_and_senders_unmatched_receiver_err() { const INPUT: &str = " let name = Unbalanced; let states = [Init]; Init?NotExists -> Init; "; let result = parse_str::(INPUT) .unwrap() .match_receivers_and_senders(); assert_err!(result, Protocol::UNMATCHED_RECEIVER_ERR); } #[test] fn no_undeliverable_msgs_ok() { let result = parse_str::(MIN_PROTOCOL) .unwrap() .no_undeliverable_msgs(); assert_ok!(result); } #[test] fn no_undeliverable_msgs_reply_ok() { const INPUT: &str = " let name = Undeliverable; let states = [Listening, Client]; Listening?Msg -> Listening, >Client!Msg::Reply; "; let result = parse_str::(INPUT) .unwrap() .no_undeliverable_msgs(); assert_ok!(result); } #[test] fn no_undeliverable_msgs_service_ok() { const INPUT: &str = " let name = Undeliverable; let states = [Client, Server]; Client -> Client, >service(Server)!Msg; "; let result = parse_str::(INPUT) .unwrap() .no_undeliverable_msgs(); assert_ok!(result); } #[test] fn no_undeliverable_msgs_owned_ok() { const INPUT: &str = " let name = Undeliverable; let states = [FileClient, FileHandle]; FileClient[FileHandle] -> FileClient, >FileHandle!FileOp; "; let result = parse_str::(INPUT) .unwrap() .no_undeliverable_msgs(); assert_ok!(result); } #[test] fn no_undeliverable_msgs_err() { const INPUT: &str = " let name = Undeliverable; let states = [Client, Server]; Client -> Client, >Server!Msg; "; let result = parse_str::(INPUT) .unwrap() .no_undeliverable_msgs(); assert_err!(result, Protocol::UNDELIVERABLE_ERR); } #[test] fn valid_replies_ok() { const INPUT: &str = " let name = ValidReplies; let states = [Client, Server]; Server?Msg -> Server, >Client!Msg::Reply; "; let result = parse_str::(INPUT).unwrap().valid_replies(); assert_ok!(result); } #[test] fn valid_replies_invalid_reply_err() { const INPUT: &str = " let name = ValidReplies; let states = [Client, Server]; Client -> Client, >Server!Msg::Reply; "; let result = parse_str::(INPUT).unwrap().valid_replies(); assert_err!(result, Protocol::INVALID_REPLY_ERR); } #[test] fn valid_replies_multiple_replies_err() { const INPUT: &str = " let name = ValidReplies; let states = [Client, OtherClient, Server]; Server?Msg -> Server, >Client!Msg::Reply, OtherClient!Msg::Reply; "; let result = parse_str::(INPUT).unwrap().valid_replies(); assert_err!(result, Protocol::MULTIPLE_REPLIES_ERR); } }