123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459 |
- use std::collections::HashSet;
- use proc_macro2::{Ident, Span};
- use crate::{
- error::MaybeErr,
- parsing::{DestinationState, Message, State},
- Protocol,
- };
- impl Protocol {
- pub(crate) fn validate(&self) -> syn::Result<()> {
- self.all_states_declared_and_used()
- .combine(self.match_receivers_and_senders())
- .combine(self.no_undeliverable_msgs())
- .combine(self.valid_replies())
- .into()
- }
- const UNDECLARED_STATE_ERR: &str = "State was not declared.";
- const UNUSED_STATE_ERR: &str = "State was declared but never used.";
- const END_STATE: &str = "End";
- /// Verifies that every state which is used has been declared, except for the End state.
- fn all_states_declared_and_used(&self) -> MaybeErr {
- let end = Ident::new(Self::END_STATE, Span::call_site());
- let declared: HashSet<&Ident> = self
- .states_def
- .states
- .as_ref()
- .iter()
- .chain([&end].into_iter())
- .collect();
- let mut used: HashSet<&Ident> = HashSet::with_capacity(declared.len());
- for transition in self.transitions.iter() {
- let in_state = &transition.in_state;
- used.insert(&in_state.state_trait);
- used.extend(in_state.owned_states.as_ref().iter());
- if let Some(in_msg) = &transition.in_msg {
- used.extend(in_msg.owned_states.as_ref().iter());
- }
- for out_states in transition.out_states.as_ref().iter() {
- used.insert(&out_states.state_trait);
- used.extend(out_states.owned_states.as_ref().iter());
- }
- // We don't have to check the states referred to in out_msgs because the
- // match_receivers_and_senders method ensures that each of these exists in a receiver
- // position.
- }
- let undeclared: MaybeErr = used
- .difference(&declared)
- .map(|ident| syn::Error::new(ident.span(), Self::UNDECLARED_STATE_ERR))
- .collect();
- let unused: MaybeErr = declared
- .difference(&used)
- .filter(|ident| **ident != Self::END_STATE)
- .map(|ident| syn::Error::new(ident.span(), Self::UNUSED_STATE_ERR))
- .collect();
- undeclared.combine(unused)
- }
- const UNMATCHED_SENDER_ERR: &str = "No receiver found for message type.";
- const UNMATCHED_RECEIVER_ERR: &str = "No sender found for message type.";
- const ACTIVATE_MSG: &str = "Activate";
- /// Ensures that the recipient state for every sent message has a receiving transition
- /// defined, and every receiver has a sender (except for the Activate message which is sent
- /// by the runtime).
- fn match_receivers_and_senders(&self) -> MaybeErr {
- let mut senders: HashSet<(&State, &Message)> = HashSet::new();
- let mut receivers: HashSet<(&State, &Message)> = HashSet::new();
- for transition in self.transitions.iter() {
- if let Some(msg) = &transition.in_msg {
- receivers.insert((&transition.in_state, msg));
- if msg.msg_type == Self::ACTIVATE_MSG {
- // The Activate message is sent by the run time, so a sender is created to
- // represent it.
- senders.insert((&transition.in_state, msg));
- }
- }
- for dest in transition.out_msgs.as_ref().iter() {
- let dest_state = match &dest.state {
- DestinationState::Individual(dest_state) => dest_state,
- DestinationState::Service(dest_state) => dest_state,
- };
- senders.insert((dest_state, &dest.msg));
- }
- }
- let extra_senders: MaybeErr = senders
- .difference(&receivers)
- .map(|pair| syn::Error::new(pair.1.msg_type.span(), Self::UNMATCHED_SENDER_ERR))
- .collect();
- let extra_receivers: MaybeErr = receivers
- .difference(&senders)
- .map(|pair| syn::Error::new(pair.1.msg_type.span(), Self::UNMATCHED_RECEIVER_ERR))
- .collect();
- extra_senders.combine(extra_receivers)
- }
- const UNDELIVERABLE_ERR: &str =
- "Receiver must either be a service, an owned state, or an out state, or the message must be a reply.";
- /// Checks that messages are only sent to destinations which are either services, states
- /// which are owned by the sender, listed in the output states, or that the message is a
- /// reply.
- fn no_undeliverable_msgs(&self) -> MaybeErr {
- let mut err = MaybeErr::none();
- for transition in self.transitions.iter() {
- let mut allowed_states: Option<HashSet<&Ident>> = None;
- for dest in transition.out_msgs.as_ref().iter() {
- if dest.msg.is_reply {
- continue;
- }
- match &dest.state {
- DestinationState::Service(_) => continue,
- DestinationState::Individual(dest_state) => {
- let allowed = allowed_states.get_or_insert_with(|| {
- transition
- .out_states
- .as_ref()
- .iter()
- .map(|state| &state.state_trait)
- .chain(transition.in_state.owned_states.as_ref().iter())
- .collect()
- });
- if !allowed.contains(&dest_state.state_trait) {
- err = err.combine(
- syn::Error::new(
- dest_state.state_trait.span(),
- Self::UNDELIVERABLE_ERR,
- )
- .into(),
- );
- }
- }
- }
- }
- }
- err
- }
- const INVALID_REPLY_ERR: &str =
- "Replies can only be used in transitions which handle messages.";
- const MULTIPLE_REPLIES_ERR: &str =
- "Only a single reply can be sent in response to any message.";
- /// Verifies that replies are only sent in response to messages.
- fn valid_replies(&self) -> MaybeErr {
- let mut err = MaybeErr::none();
- for transition in self.transitions.iter() {
- let replies: Vec<_> = transition
- .out_msgs
- .as_ref()
- .iter()
- .map(|dest| &dest.msg)
- .filter(|msg| msg.is_reply)
- .collect();
- if replies.is_empty() {
- continue;
- }
- if replies.len() > 1 {
- err = err.combine(
- replies
- .iter()
- .map(|reply| {
- syn::Error::new(reply.msg_type.span(), Self::MULTIPLE_REPLIES_ERR)
- })
- .collect(),
- );
- }
- if transition.in_msg.is_none() {
- err = err.combine(
- replies
- .iter()
- .map(|reply| {
- syn::Error::new(reply.msg_type.span(), Self::INVALID_REPLY_ERR)
- })
- .collect(),
- );
- }
- }
- err
- }
- }
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_str;

    /// Asserts that a `MaybeErr` converts into `Ok(())`. On failure, the panic message is the
    /// error's display text.
    macro_rules! assert_ok {
        ($maybe_err:expr) => {{
            let outcome: syn::Result<()> = $maybe_err.into();
            if let Err(error) = outcome {
                panic!("{}", error);
            }
        }};
    }

    /// Asserts that a `MaybeErr` converts into an error whose display text is `$expected_msg`.
    macro_rules! assert_err {
        ($maybe_err:expr, $expected_msg:expr) => {{
            let outcome: syn::Result<()> = $maybe_err.into();
            match outcome {
                Ok(()) => panic!("expected an error, but the check passed"),
                Err(error) => assert_eq!($expected_msg, error.to_string()),
            }
        }};
    }

    /// Parses `input` as a [`Protocol`], panicking if the text does not parse.
    fn parse_protocol(input: &str) -> Protocol {
        parse_str(input).unwrap()
    }

    /// A minimal valid protocol definition.
    const MIN_PROTOCOL: &str = "
        let name = Test;
        let states = [Init];
        Init?Activate -> End;
    ";

    #[test]
    fn all_states_declared_and_used_ok() {
        let protocol = parse_protocol(MIN_PROTOCOL);
        assert_ok!(protocol.all_states_declared_and_used());
    }

    #[test]
    fn all_states_declared_and_used_end_not_used_ok() {
        let protocol = parse_protocol(
            "
            let name = Test;
            let states = [Init];
            Init?Activate -> Init;
            ",
        );
        assert_ok!(protocol.all_states_declared_and_used());
    }

    #[test]
    fn all_states_declared_and_used_undeclared_err() {
        let protocol = parse_protocol(
            "
            let name = Undeclared;
            let states = [Init];
            Init?Activate -> Next;
            ",
        );
        assert_err!(
            protocol.all_states_declared_and_used(),
            Protocol::UNDECLARED_STATE_ERR
        );
    }

    #[test]
    fn all_states_declared_and_used_undeclared_out_state_owned_err() {
        let protocol = parse_protocol(
            "
            let name = Undeclared;
            let states = [Init, Next];
            Init?Activate -> Init, Next[Undeclared];
            ",
        );
        assert_err!(
            protocol.all_states_declared_and_used(),
            Protocol::UNDECLARED_STATE_ERR
        );
    }

    #[test]
    fn all_states_declared_and_used_undeclared_in_state_owned_err() {
        let protocol = parse_protocol(
            "
            let name = Undeclared;
            let states = [Init, Next];
            Init[Undeclared]?Activate -> Next;
            ",
        );
        assert_err!(
            protocol.all_states_declared_and_used(),
            Protocol::UNDECLARED_STATE_ERR
        );
    }

    #[test]
    fn all_states_declared_and_used_unused_err() {
        let protocol = parse_protocol(
            "
            let name = Unused;
            let states = [Init, Extra];
            Init?Activate -> End;
            ",
        );
        assert_err!(
            protocol.all_states_declared_and_used(),
            Protocol::UNUSED_STATE_ERR
        );
    }

    #[test]
    fn match_receivers_and_senders_ok() {
        let protocol = parse_protocol(MIN_PROTOCOL);
        assert_ok!(protocol.match_receivers_and_senders());
    }

    #[test]
    fn match_receivers_and_senders_send_activate_ok() {
        let protocol = parse_protocol(
            "
            let name = Unbalanced;
            let states = [First, Second];
            First?Activate -> First, >Second!Activate;
            Second?Activate -> Second;
            ",
        );
        assert_ok!(protocol.match_receivers_and_senders());
    }

    #[test]
    fn match_receivers_and_senders_unmatched_sender_err() {
        let protocol = parse_protocol(
            "
            let name = Unbalanced;
            let states = [Init, Other];
            Init?Activate -> Init, >Other!Activate;
            ",
        );
        assert_err!(
            protocol.match_receivers_and_senders(),
            Protocol::UNMATCHED_SENDER_ERR
        );
    }

    #[test]
    fn match_receivers_and_senders_unmatched_receiver_err() {
        let protocol = parse_protocol(
            "
            let name = Unbalanced;
            let states = [Init];
            Init?NotExists -> Init;
            ",
        );
        assert_err!(
            protocol.match_receivers_and_senders(),
            Protocol::UNMATCHED_RECEIVER_ERR
        );
    }

    #[test]
    fn no_undeliverable_msgs_ok() {
        let protocol = parse_protocol(MIN_PROTOCOL);
        assert_ok!(protocol.no_undeliverable_msgs());
    }

    #[test]
    fn no_undeliverable_msgs_reply_ok() {
        let protocol = parse_protocol(
            "
            let name = Undeliverable;
            let states = [Listening, Client];
            Listening?Msg -> Listening, >Client!Msg::Reply;
            ",
        );
        assert_ok!(protocol.no_undeliverable_msgs());
    }

    #[test]
    fn no_undeliverable_msgs_service_ok() {
        let protocol = parse_protocol(
            "
            let name = Undeliverable;
            let states = [Client, Server];
            Client -> Client, >service(Server)!Msg;
            ",
        );
        assert_ok!(protocol.no_undeliverable_msgs());
    }

    #[test]
    fn no_undeliverable_msgs_owned_ok() {
        let protocol = parse_protocol(
            "
            let name = Undeliverable;
            let states = [FileClient, FileHandle];
            FileClient[FileHandle] -> FileClient, >FileHandle!FileOp;
            ",
        );
        assert_ok!(protocol.no_undeliverable_msgs());
    }

    #[test]
    fn no_undeliverable_msgs_err() {
        let protocol = parse_protocol(
            "
            let name = Undeliverable;
            let states = [Client, Server];
            Client -> Client, >Server!Msg;
            ",
        );
        assert_err!(protocol.no_undeliverable_msgs(), Protocol::UNDELIVERABLE_ERR);
    }

    #[test]
    fn valid_replies_ok() {
        let protocol = parse_protocol(
            "
            let name = ValidReplies;
            let states = [Client, Server];
            Server?Msg -> Server, >Client!Msg::Reply;
            ",
        );
        assert_ok!(protocol.valid_replies());
    }

    #[test]
    fn valid_replies_invalid_reply_err() {
        let protocol = parse_protocol(
            "
            let name = ValidReplies;
            let states = [Client, Server];
            Client -> Client, >Server!Msg::Reply;
            ",
        );
        assert_err!(protocol.valid_replies(), Protocol::INVALID_REPLY_ERR);
    }

    #[test]
    fn valid_replies_multiple_replies_err() {
        let protocol = parse_protocol(
            "
            let name = ValidReplies;
            let states = [Client, OtherClient, Server];
            Server?Msg -> Server, >Client!Msg::Reply, OtherClient!Msg::Reply;
            ",
        );
        assert_err!(protocol.valid_replies(), Protocol::MULTIPLE_REPLIES_ERR);
    }
}
|