Pārlūkot izejas kodu

Added code to validate protocol definitions.

Matthew Carr 1 gadu atpakaļ
vecāks
revīzija
a53fe62a1a

+ 52 - 0
crates/btproto/src/error.rs

@@ -0,0 +1,52 @@
+//! Code to assist with error handling, primarily [MaybeErr].
+
+/// A wrapper around `Option<syn::Error>` which allows errors to be easily combined and created
+/// from iterators of [syn::Error].
+#[derive(Debug, Default)]
+pub(crate) struct MaybeErr(Option<syn::Error>);
+
+impl MaybeErr {
+    pub(crate) fn none() -> Self {
+        MaybeErr(None)
+    }
+
+    pub(crate) fn combine(self, other: Self) -> Self {
+        let option = match (self.0, other.0) {
+            (Some(left), right) => fold_errs(right, left),
+            (left, Some(right)) => fold_errs(left, right),
+            _ => None,
+        };
+        MaybeErr(option)
+    }
+}
+
+impl From<MaybeErr> for Result<(), syn::Error> {
+    fn from(value: MaybeErr) -> Self {
+        if let Some(err) = value.0 {
+            Err(err)
+        } else {
+            Ok(())
+        }
+    }
+}
+
+impl From<syn::Error> for MaybeErr {
+    fn from(err: syn::Error) -> Self {
+        MaybeErr(Some(err))
+    }
+}
+
+impl FromIterator<syn::Error> for MaybeErr {
+    fn from_iter<T: IntoIterator<Item = syn::Error>>(iter: T) -> Self {
+        MaybeErr(iter.into_iter().fold(None, fold_errs))
+    }
+}
+
+fn fold_errs(accum: Option<syn::Error>, curr: syn::Error) -> Option<syn::Error> {
+    if let Some(mut accum) = accum {
+        accum.combine(curr);
+        Some(accum)
+    } else {
+        Some(curr)
+    }
+}

+ 12 - 2
crates/btproto/src/lib.rs

@@ -5,9 +5,19 @@ use syn::parse_macro_input;
 mod parsing;
 use parsing::Protocol;
 
+mod error;
 mod generation;
 mod validation;
 
+macro_rules! unwrap_or_compile_err {
+    ($result:expr) => {
+        match $result {
+            Ok(value) => value,
+            Err(err) => return err.into_compile_error().into(),
+        }
+    };
+}
+
 /// Generates types for the parties participating in a messaging protocol.
 /// The grammar recognized by this macro is given below in the dialect of Extended Backus-Naur Form
 /// recognized by the `llgen` tool:
@@ -27,6 +37,6 @@ mod validation;
 #[proc_macro]
 pub fn protocol(input: TokenStream) -> TokenStream {
     let input = parse_macro_input!(input as Protocol);
-    input.validate().unwrap();
-    TokenStream::from(input.generate())
+    unwrap_or_compile_err!(input.validate());
+    input.generate().into()
 }

+ 13 - 8
crates/btproto/src/parsing.rs

@@ -50,7 +50,7 @@ pub(crate) struct StatesDef {
 }
 
 impl StatesDef {
-    const ARRAY_IDENT_ERR: &str = "invalid states array identifier";
+    const ARRAY_IDENT_ERR: &str = "invalid states array identifier. Expected 'states'.";
 }
 
 impl Parse for StatesDef {
@@ -69,7 +69,8 @@ impl Parse for StatesDef {
     }
 }
 
-#[cfg_attr(test, derive(Debug, PartialEq))]
+#[cfg_attr(test, derive(Debug))]
+#[derive(Hash, PartialEq, Eq)]
 pub(crate) struct IdentArray(Punctuated<Ident, Token![,]>);
 
 impl IdentArray {
@@ -133,7 +134,8 @@ impl Parse for Transition {
     }
 }
 
-#[cfg_attr(test, derive(Debug, PartialEq))]
+#[cfg_attr(test, derive(Debug))]
+#[derive(Hash, PartialEq, Eq)]
 pub(crate) struct State {
     pub(crate) state_trait: Ident,
     pub(crate) owned_states: IdentArray,
@@ -232,7 +234,7 @@ impl Parse for Dest {
 
 #[cfg_attr(test, derive(Debug, PartialEq))]
 pub(crate) enum DestinationState {
-    Service(Ident),
+    Service(State),
     Individual(State),
 }
 
@@ -259,14 +261,18 @@ impl Parse for DestinationState {
             if let Some(extra_dest) = dest_states.next() {
                 return Err(syn::Error::new(extra_dest.span(), Self::MULTI_STATE_ERR));
             }
-            Ok(DestinationState::Service(dest_state))
+            Ok(DestinationState::Service(State {
+                state_trait: dest_state,
+                owned_states: IdentArray::empty(),
+            }))
         } else {
             Ok(DestinationState::Individual(input.parse()?))
         }
     }
 }
 
-#[cfg_attr(test, derive(Debug, PartialEq))]
+#[cfg_attr(test, derive(Debug))]
+#[derive(Hash, PartialEq, Eq)]
 pub(crate) struct Message {
     pub(crate) msg_type: Ident,
     pub(crate) is_reply: bool,
@@ -714,8 +720,7 @@ let states = [{}];
     #[test]
     fn destination_state_parse_service() {
         const EXPECTED_DEST_STATE: &str = "Listening";
-        let expected =
-            DestinationState::Service(Ident::new(EXPECTED_DEST_STATE, Span::call_site()));
+        let expected = DestinationState::Service(State::new(EXPECTED_DEST_STATE, iter::empty()));
         let input = format!("service({EXPECTED_DEST_STATE})");
 
         let actual = parse_str::<DestinationState>(&input).unwrap();

+ 454 - 3
crates/btproto/src/validation.rs

@@ -1,8 +1,459 @@
-use super::Protocol;
+use std::collections::HashSet;
+
+use proc_macro2::{Ident, Span};
+
+use crate::{
+    error::MaybeErr,
+    parsing::{DestinationState, Message, State},
+    Protocol,
+};
 
 impl Protocol {
     pub(crate) fn validate(&self) -> syn::Result<()> {
-        // TODO: Validate input.
-        Ok(())
+        self.all_states_declared_and_used()
+            .combine(self.match_receivers_and_senders())
+            .combine(self.no_undeliverable_msgs())
+            .combine(self.valid_replies())
+            .into()
+    }
+
+    const UNDECLARED_STATE_ERR: &str = "State was not declared.";
+    const UNUSED_STATE_ERR: &str = "State was declared but never used.";
+    const END_STATE: &str = "End";
+
+    /// Verifies that every state which is used has been declared, except for the End state.
+    fn all_states_declared_and_used(&self) -> MaybeErr {
+        let end = Ident::new(Self::END_STATE, Span::call_site());
+        let declared: HashSet<&Ident> = self
+            .states_def
+            .states
+            .as_ref()
+            .iter()
+            .chain([&end].into_iter())
+            .collect();
+        let mut used: HashSet<&Ident> = HashSet::with_capacity(declared.len());
+        for transition in self.transitions.iter() {
+            let in_state = &transition.in_state;
+            used.insert(&in_state.state_trait);
+            used.extend(in_state.owned_states.as_ref().iter());
+            if let Some(in_msg) = &transition.in_msg {
+                used.extend(in_msg.owned_states.as_ref().iter());
+            }
+            for out_states in transition.out_states.as_ref().iter() {
+                used.insert(&out_states.state_trait);
+                used.extend(out_states.owned_states.as_ref().iter());
+            }
+            // We don't have to check the states referred to in out_msgs because the
+            // match_receivers_and_senders method ensures that each of these exists in a receiver
+            // position.
+        }
+        let undeclared: MaybeErr = used
+            .difference(&declared)
+            .map(|ident| syn::Error::new(ident.span(), Self::UNDECLARED_STATE_ERR))
+            .collect();
+        let unused: MaybeErr = declared
+            .difference(&used)
+            .filter(|ident| **ident != Self::END_STATE)
+            .map(|ident| syn::Error::new(ident.span(), Self::UNUSED_STATE_ERR))
+            .collect();
+        undeclared.combine(unused)
+    }
+
+    const UNMATCHED_SENDER_ERR: &str = "No receiver found for message type.";
+    const UNMATCHED_RECEIVER_ERR: &str = "No sender found for message type.";
+    const ACTIVATE_MSG: &str = "Activate";
+
+    /// Ensures that the recipient state for every sent message has a receiving transition
+    /// defined, and every receiver has a sender (except for the Activate message which is sent
+    /// by the runtime).
+    fn match_receivers_and_senders(&self) -> MaybeErr {
+        let mut senders: HashSet<(&State, &Message)> = HashSet::new();
+        let mut receivers: HashSet<(&State, &Message)> = HashSet::new();
+        for transition in self.transitions.iter() {
+            if let Some(msg) = &transition.in_msg {
+                receivers.insert((&transition.in_state, msg));
+                if msg.msg_type == Self::ACTIVATE_MSG {
+                    // The Activate message is sent by the run time, so a sender is created to
+                    // represent it.
+                    senders.insert((&transition.in_state, msg));
+                }
+            }
+            for dest in transition.out_msgs.as_ref().iter() {
+                let dest_state = match &dest.state {
+                    DestinationState::Individual(dest_state) => dest_state,
+                    DestinationState::Service(dest_state) => dest_state,
+                };
+                senders.insert((dest_state, &dest.msg));
+            }
+        }
+        let extra_senders: MaybeErr = senders
+            .difference(&receivers)
+            .map(|pair| syn::Error::new(pair.1.msg_type.span(), Self::UNMATCHED_SENDER_ERR))
+            .collect();
+        let extra_receivers: MaybeErr = receivers
+            .difference(&senders)
+            .map(|pair| syn::Error::new(pair.1.msg_type.span(), Self::UNMATCHED_RECEIVER_ERR))
+            .collect();
+        extra_senders.combine(extra_receivers)
+    }
+
+    const UNDELIVERABLE_ERR: &str =
+        "Receiver must either be a service, an owned state, or an out state, or the message must be a reply.";
+
+    /// Checks that messages are only sent to destinations which are either services, states
+    /// which are owned by the sender, listed in the output states, or that the message is a
+    /// reply.
+    fn no_undeliverable_msgs(&self) -> MaybeErr {
+        let mut err = MaybeErr::none();
+        for transition in self.transitions.iter() {
+            let mut allowed_states: Option<HashSet<&Ident>> = None;
+            for dest in transition.out_msgs.as_ref().iter() {
+                if dest.msg.is_reply {
+                    continue;
+                }
+                match &dest.state {
+                    DestinationState::Service(_) => continue,
+                    DestinationState::Individual(dest_state) => {
+                        let allowed = allowed_states.get_or_insert_with(|| {
+                            transition
+                                .out_states
+                                .as_ref()
+                                .iter()
+                                .map(|state| &state.state_trait)
+                                .chain(transition.in_state.owned_states.as_ref().iter())
+                                .collect()
+                        });
+                        if !allowed.contains(&dest_state.state_trait) {
+                            err = err.combine(
+                                syn::Error::new(
+                                    dest_state.state_trait.span(),
+                                    Self::UNDELIVERABLE_ERR,
+                                )
+                                .into(),
+                            );
+                        }
+                    }
+                }
+            }
+        }
+        err
+    }
+
+    const INVALID_REPLY_ERR: &str =
+        "Replies can only be used in transitions which handle messages.";
+    const MULTIPLE_REPLIES_ERR: &str =
+        "Only a single reply can be sent in response to any message.";
+
+    /// Verifies that replies are only sent in response to messages.
+    fn valid_replies(&self) -> MaybeErr {
+        let mut err = MaybeErr::none();
+        for transition in self.transitions.iter() {
+            let replies: Vec<_> = transition
+                .out_msgs
+                .as_ref()
+                .iter()
+                .map(|dest| &dest.msg)
+                .filter(|msg| msg.is_reply)
+                .collect();
+            if replies.is_empty() {
+                continue;
+            }
+            if replies.len() > 1 {
+                err = err.combine(
+                    replies
+                        .iter()
+                        .map(|reply| {
+                            syn::Error::new(reply.msg_type.span(), Self::MULTIPLE_REPLIES_ERR)
+                        })
+                        .collect(),
+                );
+            }
+            if transition.in_msg.is_none() {
+                err = err.combine(
+                    replies
+                        .iter()
+                        .map(|reply| {
+                            syn::Error::new(reply.msg_type.span(), Self::INVALID_REPLY_ERR)
+                        })
+                        .collect(),
+                );
+            }
+        }
+        err
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use syn::parse_str;
+
+    macro_rules! assert_ok {
+        ($maybe_err:expr) => {
+            let result: syn::Result<()> = $maybe_err.into();
+            assert!(result.is_ok(), "{}", result.err().unwrap());
+        };
+    }
+
+    macro_rules! assert_err {
+        ($maybe_err:expr, $expected_msg:expr) => {
+            let result: syn::Result<()> = $maybe_err.into();
+            assert!(result.is_err());
+            assert_eq!($expected_msg, result.err().unwrap().to_string());
+        };
+    }
+
+    /// A minimal valid protocol definition.
+    const MIN_PROTOCOL: &str = "
+let name = Test;
+let states = [Init];
+Init?Activate -> End;
+";
+
+    #[test]
+    fn all_states_declared_and_used_ok() {
+        let result = parse_str::<Protocol>(MIN_PROTOCOL)
+            .unwrap()
+            .all_states_declared_and_used();
+
+        assert_ok!(result);
+    }
+
+    #[test]
+    fn all_states_declared_and_used_end_not_used_ok() {
+        const INPUT: &str = "
+let name = Test;
+let states = [Init];
+Init?Activate -> Init;
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .all_states_declared_and_used();
+
+        assert_ok!(result);
+    }
+
+    #[test]
+    fn all_states_declared_and_used_undeclared_err() {
+        const INPUT: &str = "
+let name = Undeclared;
+let states = [Init];
+Init?Activate -> Next;
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .all_states_declared_and_used();
+
+        assert_err!(result, Protocol::UNDECLARED_STATE_ERR);
+    }
+
+    #[test]
+    fn all_states_declared_and_used_undeclared_out_state_owned_err() {
+        const INPUT: &str = "
+let name = Undeclared;
+let states = [Init, Next];
+Init?Activate -> Init, Next[Undeclared];
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .all_states_declared_and_used();
+
+        assert_err!(result, Protocol::UNDECLARED_STATE_ERR);
+    }
+
+    #[test]
+    fn all_states_declared_and_used_undeclared_in_state_owned_err() {
+        const INPUT: &str = "
+let name = Undeclared;
+let states = [Init, Next];
+Init[Undeclared]?Activate -> Next;
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .all_states_declared_and_used();
+
+        assert_err!(result, Protocol::UNDECLARED_STATE_ERR);
+    }
+
+    #[test]
+    fn all_states_declared_and_used_unused_err() {
+        const INPUT: &str = "
+let name = Unused;
+let states = [Init, Extra];
+Init?Activate -> End;
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .all_states_declared_and_used();
+
+        assert_err!(result, Protocol::UNUSED_STATE_ERR);
+    }
+
+    #[test]
+    fn match_receivers_and_senders_ok() {
+        let result = parse_str::<Protocol>(MIN_PROTOCOL)
+            .unwrap()
+            .match_receivers_and_senders();
+
+        assert_ok!(result);
+    }
+
+    #[test]
+    fn match_receivers_and_senders_send_activate_ok() {
+        const INPUT: &str = "
+let name = Unbalanced;
+let states = [First, Second];
+First?Activate -> First, >Second!Activate;
+Second?Activate -> Second;
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .match_receivers_and_senders();
+
+        assert_ok!(result);
+    }
+
+    #[test]
+    fn match_receivers_and_senders_unmatched_sender_err() {
+        const INPUT: &str = "
+let name = Unbalanced;
+let states = [Init, Other];
+Init?Activate -> Init, >Other!Activate;
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .match_receivers_and_senders();
+
+        assert_err!(result, Protocol::UNMATCHED_SENDER_ERR);
+    }
+
+    #[test]
+    fn match_receivers_and_senders_unmatched_receiver_err() {
+        const INPUT: &str = "
+let name = Unbalanced;
+let states = [Init];
+Init?NotExists -> Init;
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .match_receivers_and_senders();
+
+        assert_err!(result, Protocol::UNMATCHED_RECEIVER_ERR);
+    }
+
+    #[test]
+    fn no_undeliverable_msgs_ok() {
+        let result = parse_str::<Protocol>(MIN_PROTOCOL)
+            .unwrap()
+            .no_undeliverable_msgs();
+
+        assert_ok!(result);
+    }
+
+    #[test]
+    fn no_undeliverable_msgs_reply_ok() {
+        const INPUT: &str = "
+let name = Undeliverable;
+let states = [Listening, Client];
+Listening?Msg -> Listening, >Client!Msg::Reply;
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .no_undeliverable_msgs();
+
+        assert_ok!(result);
+    }
+
+    #[test]
+    fn no_undeliverable_msgs_service_ok() {
+        const INPUT: &str = "
+let name = Undeliverable;
+let states = [Client, Server];
+Client -> Client, >service(Server)!Msg;
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .no_undeliverable_msgs();
+
+        assert_ok!(result);
+    }
+
+    #[test]
+    fn no_undeliverable_msgs_owned_ok() {
+        const INPUT: &str = "
+let name = Undeliverable;
+let states = [FileClient, FileHandle];
+FileClient[FileHandle] -> FileClient, >FileHandle!FileOp;
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .no_undeliverable_msgs();
+
+        assert_ok!(result);
+    }
+
+    #[test]
+    fn no_undeliverable_msgs_err() {
+        const INPUT: &str = "
+let name = Undeliverable;
+let states = [Client, Server];
+Client -> Client, >Server!Msg;
+";
+
+        let result = parse_str::<Protocol>(INPUT)
+            .unwrap()
+            .no_undeliverable_msgs();
+
+        assert_err!(result, Protocol::UNDELIVERABLE_ERR);
+    }
+
+    #[test]
+    fn valid_replies_ok() {
+        const INPUT: &str = "
+let name = ValidReplies;
+let states = [Client, Server];
+Server?Msg -> Server, >Client!Msg::Reply;
+";
+
+        let result = parse_str::<Protocol>(INPUT).unwrap().valid_replies();
+
+        assert_ok!(result);
+    }
+
+    #[test]
+    fn valid_replies_invalid_reply_err() {
+        const INPUT: &str = "
+let name = ValidReplies;
+let states = [Client, Server];
+Client -> Client, >Server!Msg::Reply;
+";
+
+        let result = parse_str::<Protocol>(INPUT).unwrap().valid_replies();
+
+        assert_err!(result, Protocol::INVALID_REPLY_ERR);
+    }
+
+    #[test]
+    fn valid_replies_multiple_replies_err() {
+        const INPUT: &str = "
+let name = ValidReplies;
+let states = [Client, OtherClient, Server];
+Server?Msg -> Server, >Client!Msg::Reply, OtherClient!Msg::Reply;
+";
+
+        let result = parse_str::<Protocol>(INPUT).unwrap().valid_replies();
+
+        assert_err!(result, Protocol::MULTIPLE_REPLIES_ERR);
     }
 }

+ 5 - 4
crates/btrun/src/fs_proto.rs

@@ -17,7 +17,7 @@ protocol! {
     let states = [
         ServerInit, Listening,
         Client,
-        FileInit, Open, Opened,
+        FileInit, Opened,
         FileHandle,
     ];
     ServerInit?Activate -> Listening;
@@ -27,15 +27,16 @@ protocol! {
     Client?Query::Reply -> Client;
 
     Client -> Client, >service(Listening)!Open;
-    Listening?Open -> Listening, Opened, >Client!Open::Reply[Opened];
+    Listening?Open -> Listening, FileInit, >Client!Open::Reply[Opened], FileInit!Open;
     Client?Open::Reply[Opened] -> Client, FileHandle[Opened];
 
     FileInit?Activate -> FileInit;
     FileInit?Open -> Opened;
+    //PoopedPants?Notification -> Changing;
 
     FileHandle[Opened] -> FileHandle[Opened], >Opened!FileOp;
-    Opened?FileOp -> Opened, >Client!FileOp::Reply;
-    FileHandle?FileOp::Reply -> FileClient;
+    Opened?FileOp -> Opened, >FileHandle!FileOp::Reply;
+    FileHandle?FileOp::Reply -> FileHandle;
 
     FileHandle[Opened] -> End, >Opened!Close;
     Opened?Close -> End;

+ 2 - 2
crates/btrun/src/lib.rs

@@ -816,7 +816,7 @@ mod tests {
             ClientInit, SentPing,
             ServerInit, Listening,
         ];
-        ClientInit?Activate -> SentPing, >Listening!Ping;
+        ClientInit?Activate -> SentPing, >service(Listening)!Ping;
         ServerInit?Activate -> Listening;
         Listening?Ping -> End, >SentPing!Ping::Reply;
         SentPing?Ping::Reply -> End;
@@ -1031,7 +1031,7 @@ mod tests {
             Choosing,
         ];
         AgencyInit?Activate -> Listening;
-        Choosing -> Choosing, >Listening!Query, Listening!Accept, Listening!Reject;
+        Choosing -> Choosing, >service(Listening)!Query, service(Listening)!Accept, service(Listening)!Reject;
         Listening?Query -> Listening, >Choosing!Query::Reply;
         Choosing?Query::Reply -> Choosing;
         Listening?Accept -> End, >Choosing!Accept::Reply;