Explorar o código

* Added new validation method to btproto.
* Modified the syntax types to better keep track of the
Span occupied by the text they represent.

Matthew Carr hai 1 ano
pai
achega
b77fcd72e1

+ 11 - 10
crates/btproto/src/generation.rs

@@ -24,7 +24,7 @@ impl Protocol {
                 msgs.insert(msg);
             }
         }
-        let variants = msgs.iter().map(|msg| msg.variant());
+        let variants = msgs.iter().map(|msg| msg.ident());
         let msg_types = msgs.iter().map(|msg| msg.type_tokens());
         let enum_name = format_ident!("{}Msgs", self.name_def.name);
         quote! {
@@ -44,9 +44,10 @@ impl Protocol {
         }
         let mut tokens = TokenStream::new();
         for (trait_ident, transitions) in traits {
+            let transition_tokens = transitions.iter().map(|x| x.generate_tokens());
             quote! {
                 pub trait #trait_ident {
-                    #( #transitions )*
+                    #( #transition_tokens )*
                 }
             }
             .to_tokens(&mut tokens);
@@ -66,7 +67,7 @@ impl Message {
             let msg_type = &self.msg_type;
             quote! { #msg_type }
         };
-        if self.is_reply {
+        if self.is_reply() {
             quote! {
                 <#msg_type as ::btrun::CallMsg>::Reply
             }
@@ -78,15 +79,16 @@ impl Message {
     }
 }
 
-impl ToTokens for Transition {
-    fn to_tokens(&self, tokens: &mut TokenStream) {
+impl Transition {
+    /// Generates the tokens for the code which implements this transition.
+    fn generate_tokens(&self) -> TokenStream {
         let (msg_arg, method_ident) = if let Some(msg) = &self.in_msg {
             let msg_type = if msg.msg_type == Activate::ident() {
                 quote! { ::btrun::Activate }
             } else {
                 msg.msg_type.to_token_stream()
             };
-            let method_ident = format_ident!("handle_{}", msg.variant().pascal_to_snake());
+            let method_ident = format_ident!("handle_{}", msg.ident().pascal_to_snake());
             let msg_arg = quote! { , msg: #msg_type };
             (msg_arg, method_ident)
         } else {
@@ -96,7 +98,7 @@ impl ToTokens for Transition {
                 .as_ref()
                 .iter()
                 .fold(Option::<String>::None, |accum, curr| {
-                    let msg_name = curr.msg.variant().pascal_to_snake();
+                    let msg_name = curr.msg.ident().pascal_to_snake();
                     if let Some(mut accum) = accum {
                         accum.push('_');
                         accum.push_str(&msg_name);
@@ -131,12 +133,11 @@ impl ToTokens for Transition {
         let output_decls = output_pairs.iter().map(|(decl, _)| decl);
         let output_types = output_pairs.iter().map(|(_, output_type)| output_type);
         let future_name = format_ident!("{}Fut", method_type_prefix);
-        let transition = quote! {
+        quote! {
             #( #output_decls )*
             type #future_name: ::std::future::Future<Output = Result<( #( #output_types ),* )>>;
             fn #method_ident(self #msg_arg) -> Self::#future_name;
-        };
-        tokens.extend(transition);
+        }
     }
 }
 

+ 2 - 2
crates/btproto/src/lib.rs

@@ -24,11 +24,11 @@ macro_rules! unwrap_or_compile_err {
 /// recognized by the `llgen` tool:
 ///
 /// ```ebnf
-/// protocol : name_def states_def transition* ;
+/// protocol : name_def states_def ( transition ';' )* ;
 /// name_def : "let" "name" '=' Ident ';' ;
 /// states_def : "let" "states" '=' ident_array ';' ;
 /// ident_array : '[' Ident ( ',' Ident )* ','? ']' ;
-/// transition : state ( '?' message )?  "->" states_list ( '>' dest_list )? ';' ;
+/// transition : state ( '?' message )?  "->" states_list ( '>' dest_list )? ;
 /// state : Ident ident_array? ;
 /// states_list : state ( ',' state )* ','? ;
 /// dest_list : dest ( ',' dest )* ;

+ 254 - 44
crates/btproto/src/parsing.rs

@@ -1,7 +1,11 @@
 //! Types for parsing the protocol grammar.
 
+use proc_macro2::Span;
 use quote::format_ident;
-use syn::{bracketed, parenthesized, parse::Parse, punctuated::Punctuated, token, Ident, Token};
+use syn::{
+    bracketed, parenthesized, parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Ident,
+    Token,
+};
 
 /// This type represents the top-level production for the protocol grammar.
 #[cfg_attr(test, derive(Debug, PartialEq))]
@@ -12,7 +16,7 @@ pub(crate) struct Protocol {
 }
 
 impl Parse for Protocol {
-    /// protocol : name_def states_def transition* ;
+    /// protocol : name_def states_def ( transition ';' )* ;
     fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
         Ok(Protocol {
             name_def: input.parse()?,
@@ -22,9 +26,22 @@ impl Parse for Protocol {
     }
 }
 
+impl GetSpan for Protocol {
+    fn span(&self) -> Span {
+        self.name_def
+            .span()
+            .left_join(&self.states_def)
+            .left_join(punctuated_span(&self.transitions))
+    }
+}
+
 #[cfg_attr(test, derive(Debug, PartialEq))]
 pub(crate) struct NameDef {
+    let_token: Token![let],
+    name_ident: Ident,
+    eq_token: Token![=],
     pub(crate) name: Ident,
+    semi_token: Token![;],
 }
 
 impl NameDef {
@@ -35,14 +52,24 @@ impl NameDef {
 impl Parse for NameDef {
     /// name_def : "let" "name" '=' Ident ';' ;
     fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
-        input.parse::<Token![let]>()?;
-        if Ident::parse(input)? != Self::NAME_IDENT {
-            return Err(input.error(Self::NAME_IDENT_ERR));
-        }
-        input.parse::<Token![=]>()?;
-        let name = Ident::parse(input)?;
-        input.parse::<Token![;]>()?;
-        Ok(NameDef { name })
+        Ok(NameDef {
+            let_token: input.parse()?,
+            name_ident: check_ident(input.parse()?, Self::NAME_IDENT, Self::NAME_IDENT_ERR)?,
+            eq_token: input.parse()?,
+            name: input.parse()?,
+            semi_token: input.parse()?,
+        })
+    }
+}
+
+impl GetSpan for NameDef {
+    fn span(&self) -> Span {
+        self.let_token
+            .span()
+            .left_join(self.name_ident.span())
+            .left_join(self.eq_token.span())
+            .left_join(self.name.span())
+            .left_join(self.semi_token.span())
     }
 }
 
@@ -72,6 +99,12 @@ impl Parse for StatesDef {
     }
 }
 
+impl GetSpan for StatesDef {
+    fn span(&self) -> Span {
+        self.states.span()
+    }
+}
+
 #[cfg_attr(test, derive(Debug))]
 #[derive(Hash, PartialEq, Eq)]
 pub(crate) struct IdentArray(Punctuated<Ident, Token![,]>);
@@ -84,6 +117,12 @@ impl IdentArray {
     }
 }
 
+impl GetSpan for IdentArray {
+    fn span(&self) -> Span {
+        self.0.span()
+    }
+}
+
 impl Parse for IdentArray {
     /// ident_array : '[' Ident ( ',' Ident )* ','? ']' ;
     fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
@@ -107,12 +146,14 @@ impl AsRef<Punctuated<Ident, Token![,]>> for IdentArray {
 pub(crate) struct Transition {
     pub(crate) in_state: State,
     pub(crate) in_msg: Option<Message>,
+    arrow: Token![->],
     pub(crate) out_states: StatesList,
+    redirect: Option<Token![>]>,
     pub(crate) out_msgs: DestList,
 }
 
 impl Parse for Transition {
-    /// transition : state ( '?' message )?  "->" states_list ( '>' dest_list )? ';' ;
+    /// transition : state ( '?' message )?  "->" states_list ( '>' dest_list )? ;
     fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
         let in_state = State::parse(input)?;
         let in_msg = if input.parse::<Token![?]>().is_ok() {
@@ -120,23 +161,36 @@ impl Parse for Transition {
         } else {
             None
         };
-        input.parse::<Token![->]>()?;
+        let arrow = input.parse::<Token![->]>()?;
         let out_states = StatesList::parse(input)?;
-        let out_msgs = if input.parse::<Token![>]>().is_ok() {
-            DestList::parse(input)?
+        let (redirect, out_msgs) = if let Ok(redirect) = input.parse::<Token![>]>() {
+            (Some(redirect), DestList::parse(input)?)
         } else {
-            DestList::empty()
+            (None, DestList::empty())
         };
-        // Note that we must not eat the semicolon because the Punctuated parser expects it.
         Ok(Self {
             in_state,
             in_msg,
+            arrow,
             out_states,
+            redirect,
             out_msgs,
         })
     }
 }
 
+impl GetSpan for Transition {
+    fn span(&self) -> Span {
+        self.in_state
+            .span()
+            .left_join(self.in_msg.as_ref())
+            .left_join(self.arrow.span())
+            .left_join(&self.out_states)
+            .left_join(self.redirect.as_ref().map(|x| x.span()))
+            .left_join(&self.out_msgs)
+    }
+}
+
 #[cfg_attr(test, derive(Debug))]
 #[derive(Hash, PartialEq, Eq)]
 pub(crate) struct State {
@@ -160,6 +214,12 @@ impl Parse for State {
     }
 }
 
+impl GetSpan for State {
+    fn span(&self) -> Span {
+        self.state_trait.span().left_join(&self.owned_states)
+    }
+}
+
 #[cfg_attr(test, derive(Debug, PartialEq))]
 pub(crate) struct StatesList(Punctuated<State, Token![,]>);
 
@@ -184,6 +244,12 @@ impl Parse for StatesList {
     }
 }
 
+impl GetSpan for StatesList {
+    fn span(&self) -> Span {
+        punctuated_span(&self.0)
+    }
+}
+
 impl AsRef<Punctuated<State, Token![,]>> for StatesList {
     fn as_ref(&self) -> &Punctuated<State, Token![,]> {
         &self.0
@@ -213,6 +279,12 @@ impl Parse for DestList {
     }
 }
 
+impl GetSpan for DestList {
+    fn span(&self) -> Span {
+        punctuated_span(&self.0)
+    }
+}
+
 impl AsRef<Punctuated<Dest, Token![,]>> for DestList {
     fn as_ref(&self) -> &Punctuated<Dest, Token![,]> {
         &self.0
@@ -235,6 +307,12 @@ impl Parse for Dest {
     }
 }
 
+impl GetSpan for Dest {
+    fn span(&self) -> Span {
+        self.state.span().left_join(&self.msg)
+    }
+}
+
 #[cfg_attr(test, derive(Debug, PartialEq))]
 pub(crate) enum DestinationState {
     Service(State),
@@ -274,42 +352,53 @@ impl Parse for DestinationState {
     }
 }
 
+impl GetSpan for DestinationState {
+    fn span(&self) -> Span {
+        let state = match self {
+            Self::Service(state) => state,
+            Self::Individual(state) => state,
+        };
+        state.span()
+    }
+}
+
 #[cfg_attr(test, derive(Debug))]
 #[derive(Hash, PartialEq, Eq)]
 pub(crate) struct Message {
     pub(crate) msg_type: Ident,
-    pub(crate) is_reply: bool,
+    reply_part: Option<MessageReplyPart>,
     pub(crate) owned_states: IdentArray,
-    variant: Option<Ident>,
+    ident: Option<Ident>,
 }
 
 impl Message {
-    const REPLY_ERR: &str = "expected 'Reply'";
-
-    /// Returns the name of the message enum variant to enclose this message type.
-    pub(crate) fn variant(&self) -> &Ident {
-        if let Some(variant) = &self.variant {
+    /// Returns the identifier to use when naming types and variants after this message.
+    pub(crate) fn ident(&self) -> &Ident {
+        if let Some(variant) = &self.ident {
             variant
         } else {
             &self.msg_type
         }
     }
+
+    /// Returns true if and only if this message is a reply.
+    pub(crate) fn is_reply(&self) -> bool {
+        self.reply_part.is_some()
+    }
 }
 
 impl Parse for Message {
     /// message : Ident ( "::" "Reply" )? ident_array? ;
     fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
         let msg_type = Ident::parse(input)?;
-        let is_reply = input.peek(Token![::]);
-        let variant = if is_reply {
-            input.parse::<Token![::]>()?;
-            let reply = Ident::parse(input)?;
-            if reply != "Reply" {
-                return Err(syn::Error::new(reply.span(), Self::REPLY_ERR));
-            }
-            Some(format_ident!("{msg_type}Reply"))
+        let (reply_part, ident) = if input.peek(Token![::]) {
+            let reply_part = input.parse()?;
+            (
+                Some(reply_part),
+                Some(format_ident!("{msg_type}{}", MessageReplyPart::REPLY_IDENT)),
+            )
         } else {
-            None
+            (None, None)
         };
         let owned_states = if input.peek(token::Bracket) {
             IdentArray::parse(input)?
@@ -318,13 +407,115 @@ impl Parse for Message {
         };
         Ok(Self {
             msg_type,
-            is_reply,
+            reply_part,
             owned_states,
-            variant,
+            ident,
+        })
+    }
+}
+
+impl GetSpan for Message {
+    fn span(&self) -> Span {
+        let mut span = self.msg_type.span();
+        if let Some(reply_part) = &self.reply_part {
+            span = span.left_join(reply_part.span());
+        }
+        span.left_join(self.owned_states.span())
+    }
+}
+
+#[cfg_attr(test, derive(Debug))]
+#[derive(Hash, PartialEq, Eq)]
+struct MessageReplyPart {
+    colons: Token![::],
+    reply: Ident,
+}
+
+impl MessageReplyPart {
+    const REPLY_ERR: &str = "expected 'Reply'";
+    const REPLY_IDENT: &str = "Reply";
+}
+
+impl Parse for MessageReplyPart {
+    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+        Ok(Self {
+            colons: input.parse()?,
+            reply: check_ident(input.parse()?, Self::REPLY_IDENT, Self::REPLY_ERR)?,
         })
     }
 }
 
+impl GetSpan for MessageReplyPart {
+    fn span(&self) -> Span {
+        self.colons.span().left_join(self.reply.span())
+    }
+}
+
+/// Verifies that an ident has the expected string contents and returns it if it does. An error is
+/// returned containing the given error message if it does not.
+fn check_ident(ident: Ident, expected: &str, err_msg: &str) -> syn::Result<Ident> {
+    if ident == expected {
+        Ok(ident)
+    } else {
+        Err(syn::Error::new(ident.span(), err_msg))
+    }
+}
+
+/// Trait for types which represent a region of source code. This is similar to the [Spanned] trait
+/// in the [syn] crate. A new trait was needed because [Spanned] is sealed.
+pub(crate) trait GetSpan {
+    /// Returns the [Span] covering the source code represented by this syntax value.
+    fn span(&self) -> Span;
+}
+
+impl GetSpan for Span {
+    fn span(&self) -> Span {
+        *self
+    }
+}
+
+impl<'a, T: GetSpan> GetSpan for &'a T {
+    fn span(&self) -> Span {
+        (*self).span()
+    }
+}
+
+/// Returns the span of a punctuated sequence of values which implement [GetSpan].
+fn punctuated_span<T: GetSpan, P>(arg: &Punctuated<T, P>) -> Span {
+    let mut iter = arg.into_iter();
+    let mut span = if let Some(state) = iter.next() {
+        state.span()
+    } else {
+        return Span::call_site();
+    };
+    for state in iter {
+        span = span.left_join(state.span());
+    }
+    span
+}
+
+trait LeftJoin<Rhs> {
+    /// Attempts to join two [GetSpan] values, but if the result of the join is `None`, then just
+    /// the left span is returned.
+    fn left_join(&self, other: Rhs) -> Span;
+}
+
+impl<R: GetSpan> LeftJoin<R> for Span {
+    fn left_join(&self, other: R) -> Span {
+        self.join(other.span()).unwrap_or(*self)
+    }
+}
+
+impl<R: GetSpan> LeftJoin<Option<R>> for Span {
+    fn left_join(&self, other: Option<R>) -> Span {
+        if let Some(other) = other {
+            self.left_join(other)
+        } else {
+            *self
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -394,7 +585,13 @@ let states = [{}];
 
     impl NameDef {
         fn new(name: &str) -> Self {
-            Self { name: ident(name) }
+            Self {
+                let_token: Token![let](Span::call_site()),
+                name_ident: ident("name"),
+                eq_token: Token![=](Span::call_site()),
+                name: ident(name),
+                semi_token: Token![;](Span::call_site()),
+            }
         }
     }
 
@@ -411,7 +608,7 @@ let states = [{}];
 
     #[test]
     fn name_def_wrong_ident_err() {
-        let result = parse_str::<NameDef>("let nam = Samson;");
+        let result = parse_str::<NameDef>("let nmae = Shodan;");
 
         assert!(result.is_err());
         let err_str = result.err().unwrap().to_string();
@@ -508,11 +705,19 @@ let states = [{}];
             out_states: impl Iterator<Item = State>,
             out_msgs: impl Iterator<Item = Dest>,
         ) -> Self {
+            let out_msgs = DestList(out_msgs.collect());
+            let redirect = if out_msgs.as_ref().is_empty() {
+                None
+            } else {
+                Some(Token![>](Span::call_site()))
+            };
             Self {
                 in_state,
                 in_msg,
+                arrow: Token![->](Span::call_site()),
                 out_states: StatesList(out_states.collect()),
-                out_msgs: DestList(out_msgs.collect()),
+                redirect,
+                out_msgs,
             }
         }
     }
@@ -769,16 +974,21 @@ let states = [{}];
             is_reply: bool,
             owned_states: impl Iterator<Item = &'static str>,
         ) -> Self {
-            let variant = if is_reply {
-                Some(format_ident!("{}Reply", msg_type))
+            let (reply_part, ident_field) = if is_reply {
+                let reply_part = MessageReplyPart {
+                    colons: Token![::](Span::call_site()),
+                    reply: ident(MessageReplyPart::REPLY_IDENT),
+                };
+                let variant = format_ident!("{}{}", msg_type, MessageReplyPart::REPLY_IDENT);
+                (Some(reply_part), Some(variant))
             } else {
-                None
+                (None, None)
             };
             Self {
                 msg_type: ident(msg_type),
-                is_reply,
+                reply_part,
                 owned_states: IdentArray::new(owned_states),
-                variant,
+                ident: ident_field,
             }
         }
     }
@@ -796,7 +1006,7 @@ let states = [{}];
         );
 
         assert_eq!(actual.msg_type, EXPECTED_MSG_TYPE);
-        assert_eq!(actual.is_reply, EXPECTED_IS_REPLY);
+        assert_eq!(actual.is_reply(), EXPECTED_IS_REPLY);
         assert_eq!(actual.owned_states.0.len(), EXPECTED_OWNED_STATES.len());
         assert_eq!(actual.owned_states.0[0], EXPECTED_OWNED_STATES[0]);
         assert_eq!(actual.owned_states.0[1], EXPECTED_OWNED_STATES[1]);
@@ -833,7 +1043,7 @@ let states = [{}];
 
         assert!(result.is_err());
         let err_str = result.err().unwrap().to_string();
-        assert_eq!(Message::REPLY_ERR, err_str);
+        assert_eq!(MessageReplyPart::REPLY_ERR, err_str);
     }
 
     #[test]

+ 55 - 3
crates/btproto/src/validation.rs

@@ -6,7 +6,7 @@ use btrun::{Activate, End};
 
 use crate::{
     error::MaybeErr,
-    parsing::{DestinationState, Message, State},
+    parsing::{DestinationState, GetSpan, Message, State},
     Protocol,
 };
 
@@ -16,6 +16,7 @@ impl Protocol {
             .combine(self.match_receivers_and_senders())
             .combine(self.no_undeliverable_msgs())
             .combine(self.valid_replies())
+            .combine(self.msg_sent_or_received())
             .into()
     }
 
@@ -108,7 +109,7 @@ impl Protocol {
         for transition in self.transitions.iter() {
             let mut allowed_states: Option<HashSet<&Ident>> = None;
             for dest in transition.out_msgs.as_ref().iter() {
-                if dest.msg.is_reply {
+                if dest.msg.is_reply() {
                     continue;
                 }
                 match &dest.state {
@@ -153,7 +154,7 @@ impl Protocol {
                 .as_ref()
                 .iter()
                 .map(|dest| &dest.msg)
-                .filter(|msg| msg.is_reply)
+                .filter(|msg| msg.is_reply())
                 .collect();
             if replies.is_empty() {
                 continue;
@@ -181,6 +182,21 @@ impl Protocol {
         }
         err
     }
+
+    const NO_MSG_SENT_OR_RECEIVED_ERR: &str = "A transition must send or receive a message.";
+
+    /// Verifies that every transition either receives or sends a message. The rationale behind
+    /// this is that if no message is sent or received, then the state transition is unobserved by
+    /// other actors, and so should not be in a protocol.
+    fn msg_sent_or_received(&self) -> MaybeErr {
+        self.transitions
+            .iter()
+            .filter(|transition| {
+                transition.in_msg.is_none() && transition.out_msgs.as_ref().is_empty()
+            })
+            .map(|transition| syn::Error::new(transition.span(), Self::NO_MSG_SENT_OR_RECEIVED_ERR))
+            .collect()
+    }
 }
 
 #[cfg(test)]
@@ -456,4 +472,40 @@ Server?Msg -> Server, >Client!Msg::Reply, OtherClient!Msg::Reply;
 
         assert_err!(result, Protocol::MULTIPLE_REPLIES_ERR);
     }
+
+    #[test]
+    fn msg_sent_or_received_msg_received_ok() {
+        const INPUT: &str = "
+let name = Test;
+let states = [Init];
+Init?Activate -> End;
+";
+        let result = parse_str::<Protocol>(INPUT).unwrap().msg_sent_or_received();
+
+        assert_ok!(result);
+    }
+
+    #[test]
+    fn msg_sent_or_received_msg_sent_ok() {
+        const INPUT: &str = "
+let name = Test;
+let states = [First, Second];
+First -> First, >Second!Msg;
+";
+        let result = parse_str::<Protocol>(INPUT).unwrap().msg_sent_or_received();
+
+        assert_ok!(result);
+    }
+
+    #[test]
+    fn msg_sent_or_received_neither_err() {
+        const INPUT: &str = "
+let name = Test;
+let states = [First];
+First -> First;
+";
+        let result = parse_str::<Protocol>(INPUT).unwrap().msg_sent_or_received();
+
+        assert_err!(result, Protocol::NO_MSG_SENT_OR_RECEIVED_ERR);
+    }
 }