ソースを参照

Refactored btproto to make it easier to evolve.

Matthew Carr 1 年間 前
コミット
43a8fed38d

+ 163 - 0
crates/btproto/src/case_convert.rs

@@ -0,0 +1,163 @@
+//! Code for converting between different casing disciplines.
+
+use proc_macro2::Ident;
+
+pub(crate) trait CaseConvert {
+    /// Converts a name in snake_case to PascalCase.
+    fn snake_to_pascal(&self) -> String;
+    /// Converts a name in PascalCase to snake_case.
+    fn pascal_to_snake(&self) -> String;
+}
+
+impl CaseConvert for String {
+    fn snake_to_pascal(&self) -> String {
+        let mut pascal = String::with_capacity(self.len());
+        let mut prev_underscore = true;
+        for c in self.chars() {
+            if '_' == c {
+                prev_underscore = true;
+            } else {
+                if prev_underscore {
+                    pascal.extend(c.to_uppercase());
+                } else {
+                    pascal.push(c);
+                }
+                prev_underscore = false;
+            }
+        }
+        pascal
+    }
+
+    fn pascal_to_snake(&self) -> String {
+        let mut snake = String::with_capacity(self.len());
+        let mut prev_lower = false;
+        for c in self.chars() {
+            if c.is_uppercase() {
+                if prev_lower {
+                    snake.push('_');
+                }
+                snake.extend(c.to_lowercase());
+                prev_lower = false;
+            } else {
+                prev_lower = true;
+                snake.push(c);
+            }
+        }
+        snake
+    }
+}
+
+impl CaseConvert for Ident {
+    fn snake_to_pascal(&self) -> String {
+        self.to_string().snake_to_pascal()
+    }
+
+    fn pascal_to_snake(&self) -> String {
+        self.to_string().pascal_to_snake()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn string_snake_to_pascal_multiple_segments() {
+        const EXPECTED: &str = "FirstSecondThird";
+        let input = String::from("first_second_third");
+
+        let actual = input.snake_to_pascal();
+
+        assert_eq!(EXPECTED, actual);
+    }
+
+    #[test]
+    fn string_snake_to_pascal_single_segment() {
+        const EXPECTED: &str = "First";
+        let input = String::from("first");
+
+        let actual = input.snake_to_pascal();
+
+        assert_eq!(EXPECTED, actual);
+    }
+
+    #[test]
+    fn string_snake_to_pascal_empty_string() {
+        const EXPECTED: &str = "";
+        let input = String::from(EXPECTED);
+
+        let actual = input.snake_to_pascal();
+
+        assert_eq!(EXPECTED, actual);
+    }
+
+    #[test]
+    fn string_snake_to_pascal_leading_underscore() {
+        const EXPECTED: &str = "First";
+        let input = String::from("_first");
+
+        let actual = input.snake_to_pascal();
+
+        assert_eq!(EXPECTED, actual);
+    }
+
+    #[test]
+    fn string_snake_to_pascal_leading_underscores() {
+        const EXPECTED: &str = "First";
+        let input = String::from("__first");
+
+        let actual = input.snake_to_pascal();
+
+        assert_eq!(EXPECTED, actual);
+    }
+
+    #[test]
+    fn string_snake_to_pascal_multiple_underscores() {
+        const EXPECTED: &str = "FirstSecondThird";
+        let input = String::from("first__second___third");
+
+        let actual = input.snake_to_pascal();
+
+        assert_eq!(EXPECTED, actual);
+    }
+
+    #[test]
+    fn string_pascal_to_snake_multiple_segments() {
+        const EXPECTED: &str = "first_second_third";
+        let input = String::from("FirstSecondThird");
+
+        let actual = input.pascal_to_snake();
+
+        assert_eq!(EXPECTED, actual);
+    }
+
+    #[test]
+    fn string_pascal_to_snake_single_segment() {
+        let input = String::from("First");
+        const EXPECTED: &str = "first";
+
+        let actual = input.pascal_to_snake();
+
+        assert_eq!(EXPECTED, actual);
+    }
+
+    #[test]
+    fn string_pascal_to_snake_empty_string() {
+        const EXPECTED: &str = "";
+        let input = String::from(EXPECTED);
+
+        let actual = input.pascal_to_snake();
+
+        assert_eq!(EXPECTED, actual);
+    }
+
+    #[test]
+    fn string_pascal_to_snake_consecutive_uppercase() {
+        const EXPECTED: &str = "kernel_mc";
+        let input = String::from("KernelMC");
+
+        let actual = input.pascal_to_snake();
+
+        assert_eq!(EXPECTED, actual);
+    }
+}

+ 39 - 0
crates/btproto/src/error.rs

@@ -50,3 +50,42 @@ fn fold_errs(accum: Option<syn::Error>, curr: syn::Error) -> Option<syn::Error>
         Some(curr)
     }
 }
+
+#[cfg(test)]
+/// Panics if the given value does not transform into `Ok`.
+pub(crate) fn assert_ok<T, E: Into<syn::Result<T>>>(maybe_err: E) {
+    let result: syn::Result<T> = maybe_err.into();
+    assert!(result.is_ok(), "{}", result.err().unwrap());
+}
+
+#[cfg(test)]
+/// Panics if the given value does not transform into `Err` with an error value containing the
+/// given message.
+pub(crate) fn assert_err<T, E: Into<syn::Result<T>>>(maybe_err: E, expected_msg: &str) {
+    let result: syn::Result<T> = maybe_err.into();
+    assert!(result.is_err());
+    assert_eq!(expected_msg, result.err().unwrap().to_string());
+}
+
+/// User-visible compile time error messages.
+pub(crate) mod msgs {
+    /// Indicates that a duplicate transition has been defined.
+    ///
+    /// This means it is either handling the same message type in the same state sending the
+    /// same message type in the same state as another transition.
+    pub(crate) const DUPLICATE_TRANSITION: &str = "Duplicate transition.";
+    pub(crate) const NO_MSG_SENT_OR_RECEIVED_ERR: &str =
+        "A transition must send or receive a message.";
+    pub(crate) const UNDECLARED_STATE_ERR: &str = "State was not declared.";
+    pub(crate) const UNUSED_STATE_ERR: &str = "State was declared but never used.";
+    pub(crate) const UNMATCHED_SENDER_ERR: &str = "No receiver found for message type.";
+    pub(crate) const UNMATCHED_RECEIVER_ERR: &str = "No sender found for message type.";
+    pub(crate) const UNDELIVERABLE_ERR: &str =
+        "Receiver must either be a service, an owned state, or an out state, or the message must be a reply.";
+    pub(crate) const INVALID_REPLY_ERR: &str =
+        "Replies can only be used in transitions which handle messages.";
+    pub(crate) const MULTIPLE_REPLIES_ERR: &str =
+        "Only a single reply can be sent in response to any message.";
+    pub(crate) const CLIENT_RECEIVED_NON_REPLY_ERR: &str =
+        "A client actor cannot receive a message which is not a reply.";
+}

+ 21 - 260
crates/btproto/src/generation.rs

@@ -1,32 +1,20 @@
-use std::collections::{HashMap, HashSet};
-
-use crate::parsing::{Message, Transition};
-use btrun::End;
-
-use super::Protocol;
-use proc_macro2::{Ident, TokenStream};
+use proc_macro2::TokenStream;
 use quote::{format_ident, quote, ToTokens};
 
-impl ToTokens for Protocol {
+use crate::model::{MethodModel, ProtocolModel};
+
+impl ToTokens for ProtocolModel {
     fn to_tokens(&self, tokens: &mut TokenStream) {
         tokens.extend(self.generate_message_enum());
         tokens.extend(self.generate_state_traits());
     }
 }
 
-impl Protocol {
+impl ProtocolModel {
     fn generate_message_enum(&self) -> TokenStream {
-        let mut msgs: HashSet<&Message> = HashSet::new();
-        for transition in self.transitions.iter() {
-            // We only need to insert received messages because every sent message has a
-            // corresponding receiver thanks to the validator.
-            if let Some(msg) = transition.in_msg() {
-                msgs.insert(msg);
-            }
-        }
-        let variants = msgs.iter().map(|msg| msg.ident());
-        let msg_types = msgs.iter().map(|msg| msg.type_tokens());
-        let enum_name = format_ident!("{}Msgs", self.name_def.name);
+        let variants = self.msg_lookup().msg_iter().map(|msg| msg.msg_name());
+        let msg_types = self.msg_lookup().msg_iter().map(|msg| msg.msg_type());
+        let enum_name = format_ident!("{}Msgs", self.def().name_def.name);
         quote! {
             pub enum #enum_name {
                 #( #variants(#msg_types) ),*
@@ -35,19 +23,15 @@ impl Protocol {
     }
 
     fn generate_state_traits(&self) -> TokenStream {
-        let mut traits: HashMap<&Ident, Vec<&Transition>> = HashMap::new();
-        for transition in self.transitions.iter() {
-            let vec = traits
-                .entry(&transition.in_state.state_trait)
-                .or_insert_with(Vec::new);
-            vec.push(transition);
-        }
+        let traits = self
+            .states_iter()
+            .map(|state| (state.name(), state.methods().values()));
         let mut tokens = TokenStream::new();
-        for (trait_ident, transitions) in traits {
-            let transition_tokens = transitions.iter().map(|x| x.generate_tokens());
+        for (trait_ident, methods) in traits {
+            let method_tokens = methods.map(|x| x.generate_tokens());
             quote! {
                 pub trait #trait_ident {
-                    #( #transition_tokens )*
+                    #( #method_tokens )*
                 }
             }
             .to_tokens(&mut tokens);
@@ -56,241 +40,18 @@ impl Protocol {
     }
 }
 
-impl Message {
-    /// Returns the tokens which represent the type of this message.
-    fn type_tokens(&self) -> TokenStream {
-        // We generate a fully-qualified path to the Activate message so that it doesn't need to be
-        // imported for every protocol definition.
-        let msg_type = {
-            let msg_type = &self.msg_type;
-            quote! { #msg_type }
-        };
-        if self.is_reply() {
-            quote! {
-                <#msg_type as ::btrun::CallMsg>::Reply
-            }
-        } else {
-            quote! {
-                #msg_type
-            }
-        }
-    }
-}
-
-impl Transition {
+impl MethodModel {
     /// Generates the tokens for the code which implements this transition.
     fn generate_tokens(&self) -> TokenStream {
-        let (msg_arg, method_ident) = if let Some(msg) = self.in_msg() {
-            let msg_type = msg.msg_type.to_token_stream();
-            let method_ident = format_ident!("handle_{}", msg.ident().pascal_to_snake());
-            let msg_arg = quote! { , msg: #msg_type };
-            (msg_arg, method_ident)
-        } else {
-            let msg_arg = quote! {};
-            let msg_names = self
-                .out_msgs
-                .as_ref()
-                .iter()
-                .fold(Option::<String>::None, |accum, curr| {
-                    let msg_name = curr.msg.ident().pascal_to_snake();
-                    if let Some(mut accum) = accum {
-                        accum.push('_');
-                        accum.push_str(&msg_name);
-                        Some(accum)
-                    } else {
-                        Some(msg_name)
-                    }
-                })
-                // Since no message is being handled, the validator ensures that at least one
-                // message is being sent. Hence this unwrap will not panic.
-                .unwrap();
-            let method_ident = format_ident!("send_{}", msg_names);
-            (msg_arg, method_ident)
-        };
-        let method_type_prefix = method_ident.snake_to_pascal();
-        let output_pairs: Vec<_> = self
-            .out_states
-            .as_ref()
-            .iter()
-            .map(|state| {
-                let state_trait = &state.state_trait;
-                if state_trait == End::ident() {
-                    (quote! {}, quote! { ::btrun::End })
-                } else {
-                    let assoc_type = format_ident!("{}{}", method_type_prefix, state_trait);
-                    let output_decl = quote! { type #assoc_type: #state_trait; };
-                    let output_type = quote! { Self::#assoc_type };
-                    (output_decl, output_type)
-                }
-            })
-            .collect();
-        let output_decls = output_pairs.iter().map(|(decl, _)| decl);
-        let output_types = output_pairs.iter().map(|(_, output_type)| output_type);
-        let future_name = format_ident!("{}Fut", method_type_prefix);
+        let method_ident = self.name().as_ref();
+        let msg_args = self.inputs().iter();
+        let output_decls = self.outputs().iter().flat_map(|output| output.decl());
+        let output_types = self.outputs().iter().flat_map(|output| output.type_name());
+        let future_name = self.future();
         quote! {
             #( #output_decls )*
             type #future_name: ::std::future::Future<Output = Result<( #( #output_types ),* )>>;
-            fn #method_ident(self #msg_arg) -> Self::#future_name;
-        }
-    }
-}
-
-trait CaseConvert {
-    /// Converts a name in snake_case to PascalCase.
-    fn snake_to_pascal(&self) -> String;
-    /// Converts a name in PascalCase to snake_case.
-    fn pascal_to_snake(&self) -> String;
-}
-
-impl CaseConvert for String {
-    fn snake_to_pascal(&self) -> String {
-        let mut pascal = String::with_capacity(self.len());
-        let mut prev_underscore = true;
-        for c in self.chars() {
-            if '_' == c {
-                prev_underscore = true;
-            } else {
-                if prev_underscore {
-                    pascal.extend(c.to_uppercase());
-                } else {
-                    pascal.push(c);
-                }
-                prev_underscore = false;
-            }
-        }
-        pascal
-    }
-
-    fn pascal_to_snake(&self) -> String {
-        let mut snake = String::with_capacity(self.len());
-        let mut prev_lower = false;
-        for c in self.chars() {
-            if c.is_uppercase() {
-                if prev_lower {
-                    snake.push('_');
-                }
-                snake.extend(c.to_lowercase());
-                prev_lower = false;
-            } else {
-                prev_lower = true;
-                snake.push(c);
-            }
+            fn #method_ident(self #( , #msg_args )*) -> Self::#future_name;
         }
-        snake
-    }
-}
-
-impl CaseConvert for Ident {
-    fn snake_to_pascal(&self) -> String {
-        self.to_string().snake_to_pascal()
-    }
-
-    fn pascal_to_snake(&self) -> String {
-        self.to_string().pascal_to_snake()
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn string_snake_to_pascal_multiple_segments() {
-        const EXPECTED: &str = "FirstSecondThird";
-        let input = String::from("first_second_third");
-
-        let actual = input.snake_to_pascal();
-
-        assert_eq!(EXPECTED, actual);
-    }
-
-    #[test]
-    fn string_snake_to_pascal_single_segment() {
-        const EXPECTED: &str = "First";
-        let input = String::from("first");
-
-        let actual = input.snake_to_pascal();
-
-        assert_eq!(EXPECTED, actual);
-    }
-
-    #[test]
-    fn string_snake_to_pascal_empty_string() {
-        const EXPECTED: &str = "";
-        let input = String::from(EXPECTED);
-
-        let actual = input.snake_to_pascal();
-
-        assert_eq!(EXPECTED, actual);
-    }
-
-    #[test]
-    fn string_snake_to_pascal_leading_underscore() {
-        const EXPECTED: &str = "First";
-        let input = String::from("_first");
-
-        let actual = input.snake_to_pascal();
-
-        assert_eq!(EXPECTED, actual);
-    }
-
-    #[test]
-    fn string_snake_to_pascal_leading_underscores() {
-        const EXPECTED: &str = "First";
-        let input = String::from("__first");
-
-        let actual = input.snake_to_pascal();
-
-        assert_eq!(EXPECTED, actual);
-    }
-
-    #[test]
-    fn string_snake_to_pascal_multiple_underscores() {
-        const EXPECTED: &str = "FirstSecondThird";
-        let input = String::from("first__second___third");
-
-        let actual = input.snake_to_pascal();
-
-        assert_eq!(EXPECTED, actual);
-    }
-
-    #[test]
-    fn string_pascal_to_snake_multiple_segments() {
-        const EXPECTED: &str = "first_second_third";
-        let input = String::from("FirstSecondThird");
-
-        let actual = input.pascal_to_snake();
-
-        assert_eq!(EXPECTED, actual);
-    }
-
-    #[test]
-    fn string_pascal_to_snake_single_segment() {
-        let input = String::from("First");
-        const EXPECTED: &str = "first";
-
-        let actual = input.pascal_to_snake();
-
-        assert_eq!(EXPECTED, actual);
-    }
-
-    #[test]
-    fn string_pascal_to_snake_empty_string() {
-        const EXPECTED: &str = "";
-        let input = String::from(EXPECTED);
-
-        let actual = input.pascal_to_snake();
-
-        assert_eq!(EXPECTED, actual);
-    }
-
-    #[test]
-    fn string_pascal_to_snake_consecutive_uppercase() {
-        const EXPECTED: &str = "kernel_mc";
-        let input = String::from("KernelMC");
-
-        let actual = input.pascal_to_snake();
-
-        assert_eq!(EXPECTED, actual);
     }
 }

+ 8 - 5
crates/btproto/src/lib.rs

@@ -3,13 +3,15 @@ use proc_macro::TokenStream;
 use quote::ToTokens;
 use syn::parse_macro_input;
 
-mod parsing;
-use parsing::Protocol;
-
+mod case_convert;
 mod error;
 mod generation;
+mod model;
+mod parsing;
 mod validation;
 
+use crate::{model::ProtocolModel, parsing::Protocol};
+
 macro_rules! unwrap_or_compile_err {
     ($result:expr) => {
         match $result {
@@ -54,6 +56,7 @@ macro_rules! unwrap_or_compile_err {
 #[proc_macro]
 pub fn protocol(input: TokenStream) -> TokenStream {
     let input = parse_macro_input!(input as Protocol);
-    unwrap_or_compile_err!(input.validate());
-    input.to_token_stream().into()
+    let model = unwrap_or_compile_err!(ProtocolModel::new(input));
+    unwrap_or_compile_err!(model.validate());
+    model.to_token_stream().into()
 }

+ 761 - 0
crates/btproto/src/model.rs

@@ -0,0 +1,761 @@
+use std::{
+    collections::{HashMap, HashSet},
+    hash::Hash,
+    rc::Rc,
+};
+
+use btrun::End;
+use proc_macro2::{Ident, Span, TokenStream};
+use quote::{format_ident, quote, ToTokens};
+
+use crate::{
+    case_convert::CaseConvert,
+    error,
+    parsing::MessageReplyPart,
+    parsing::{ActorDef, Dest, GetSpan, Message, Protocol, State, Transition},
+};
+
+pub(crate) struct ProtocolModel {
+    def: Protocol,
+    msg_lookup: MsgLookup,
+    actors: HashMap<Rc<Ident>, ActorModel>,
+}
+
+impl ProtocolModel {
+    pub(crate) fn new(def: Protocol) -> syn::Result<Self> {
+        let actor_lookup = ActorLookup::new(def.actor_defs.iter().map(|x| x.as_ref()));
+        let msg_lookup = MsgLookup::new(def.transitions.iter().map(|x| x.as_ref()));
+        let mut actors = HashMap::new();
+        for actor_def in def.actor_defs.iter() {
+            let actor_name = &actor_def.actor;
+            let actor_states = actor_lookup.actor_states(actor_name.as_ref());
+            let transitions_by_state = actor_states.iter().map(|state_name| {
+                let transitions = def
+                    .transitions
+                    .iter()
+                    .filter(|transition| {
+                        state_name.as_ref() == transition.in_state.state_trait.as_ref()
+                    })
+                    .cloned();
+                (state_name.clone(), transitions)
+            });
+            let actor = ActorModel::new(actor_def.clone(), &msg_lookup, transitions_by_state)?;
+            actors.insert(actor_name.clone(), actor);
+        }
+        Ok(Self {
+            def,
+            msg_lookup,
+            actors,
+        })
+    }
+
+    pub(crate) fn def(&self) -> &Protocol {
+        &self.def
+    }
+
+    pub(crate) fn msg_lookup(&self) -> &MsgLookup {
+        &self.msg_lookup
+    }
+
+    pub(crate) fn actors(&self) -> &HashMap<Rc<Ident>, ActorModel> {
+        &self.actors
+    }
+
+    pub(crate) fn actors_iter(&self) -> impl Iterator<Item = &ActorModel> {
+        self.actors.values()
+    }
+
+    pub(crate) fn states_iter(&self) -> impl Iterator<Item = &StateModel> {
+        self.actors_iter().flat_map(|actor| actor.states().values())
+    }
+
+    #[cfg(test)]
+    pub(crate) fn methods_iter(&self) -> impl Iterator<Item = &MethodModel> {
+        self.states_iter()
+            .flat_map(|state| state.methods().values())
+    }
+
+    #[cfg(test)]
+    pub(crate) fn outputs_iter(&self) -> impl Iterator<Item = &OutputModel> {
+        self.methods_iter()
+            .flat_map(|method| method.outputs().iter())
+    }
+}
+
+pub(crate) struct ActorModel {
+    #[allow(dead_code)]
+    def: Rc<ActorDef>,
+    is_client: bool,
+    states: HashMap<Rc<Ident>, StateModel>,
+}
+
+impl ActorModel {
+    fn new<S, T>(def: Rc<ActorDef>, messages: &MsgLookup, state_iter: S) -> syn::Result<Self>
+    where
+        S: IntoIterator<Item = (Rc<Ident>, T)>,
+        T: IntoIterator<Item = Rc<Transition>>,
+    {
+        let transitions: HashMap<_, Vec<_>> = state_iter
+            .into_iter()
+            .map(|(name, transitions)| (name, transitions.into_iter().collect()))
+            .collect();
+        let is_client = transitions
+            .values()
+            .flatten()
+            .any(|transition| transition.is_client());
+        let mut states = HashMap::new();
+        for (name, transitions) in transitions.into_iter() {
+            let state = StateModel::new(name.clone(), messages, transitions, is_client)?;
+            if let Some(prev) = states.insert(name, state) {
+                panic!(
+                    "States are not being grouped by actor correctly. Duplicate state name: '{}'",
+                    prev.name
+                );
+            }
+        }
+        Ok(Self {
+            def,
+            is_client,
+            states,
+        })
+    }
+
+    pub(crate) fn is_client(&self) -> bool {
+        self.is_client
+    }
+
+    pub(crate) fn states(&self) -> &HashMap<Rc<Ident>, StateModel> {
+        &self.states
+    }
+}
+
+pub(crate) struct StateModel {
+    name: Rc<Ident>,
+    methods: HashMap<Rc<Ident>, MethodModel>,
+}
+
+impl StateModel {
+    fn new<T>(
+        name: Rc<Ident>,
+        messages: &MsgLookup,
+        transitions: T,
+        part_of_client: bool,
+    ) -> syn::Result<Self>
+    where
+        T: IntoIterator<Item = Rc<Transition>>,
+    {
+        let mut methods = HashMap::new();
+        for transition in transitions.into_iter() {
+            let transition_span = transition.span();
+            let method = MethodModel::new(transition, messages, part_of_client)?;
+            if methods.insert(method.name.clone(), method).is_some() {
+                return Err(syn::Error::new(
+                    transition_span,
+                    error::msgs::DUPLICATE_TRANSITION,
+                ));
+            }
+        }
+        Ok(Self { name, methods })
+    }
+
+    pub(crate) fn name(&self) -> &Ident {
+        self.name.as_ref()
+    }
+
+    pub(crate) fn methods(&self) -> &HashMap<Rc<Ident>, MethodModel> {
+        &self.methods
+    }
+}
+
+impl GetSpan for StateModel {
+    fn span(&self) -> Span {
+        self.name.span()
+    }
+}
+
+#[cfg_attr(test, derive(Debug))]
+pub(crate) struct MethodModel {
+    def: Rc<Transition>,
+    name: Rc<Ident>,
+    inputs: Vec<InputModel>,
+    outputs: Vec<OutputModel>,
+    future: Ident,
+}
+
+impl MethodModel {
+    fn new(def: Rc<Transition>, messages: &MsgLookup, part_of_client: bool) -> syn::Result<Self> {
+        let name = Rc::new(Self::new_name(def.as_ref())?);
+        let type_prefix = name.snake_to_pascal();
+        Ok(Self {
+            name,
+            inputs: Self::new_inputs(def.as_ref(), messages, part_of_client),
+            outputs: Self::new_outputs(def.as_ref(), &type_prefix, messages, part_of_client),
+            future: format_ident!("{type_prefix}Fut"),
+            def,
+        })
+    }
+
+    fn new_name(def: &Transition) -> syn::Result<Ident> {
+        let name = if let Some(msg) = def.in_msg() {
+            format_ident!("handle_{}", msg.variant().pascal_to_snake())
+        } else {
+            let mut dests = def.out_msgs.as_ref().iter();
+            let mut msg_names = String::new();
+            if let Some(dest) = dests.next() {
+                msg_names.push_str(dest.msg.variant().pascal_to_snake().as_str());
+            } else {
+                return Err(syn::Error::new(
+                    def.span(),
+                    error::msgs::NO_MSG_SENT_OR_RECEIVED_ERR,
+                ));
+            }
+            for dest in dests {
+                msg_names.push('_');
+                msg_names.push_str(dest.msg.variant().pascal_to_snake().as_str());
+            }
+            format_ident!("send_{msg_names}")
+        };
+        Ok(name)
+    }
+
+    fn new_inputs(def: &Transition, messages: &MsgLookup, part_of_client: bool) -> Vec<InputModel> {
+        let mut inputs = Vec::new();
+        if let Some(in_msg) = def.in_msg() {
+            let msg_info = messages.lookup(in_msg);
+            inputs.push(InputModel::new(
+                msg_info.msg_name().clone(),
+                msg_info.msg_type.clone(),
+            ))
+        }
+        if part_of_client {
+            for out_msg in def.out_msgs.as_ref().iter() {
+                let msg_info = messages.lookup(&out_msg.msg);
+                inputs.push(InputModel::new(
+                    msg_info.msg_name().clone(),
+                    msg_info.msg_type.clone(),
+                ))
+            }
+        }
+        inputs
+    }
+
+    fn new_outputs(
+        def: &Transition,
+        type_prefix: &str,
+        messages: &MsgLookup,
+        part_of_client: bool,
+    ) -> Vec<OutputModel> {
+        let mut outputs = Vec::new();
+        for state in def.out_states.as_ref().iter() {
+            outputs.push(OutputModel::new(
+                OutputKind::State { def: state.clone() },
+                type_prefix,
+            ));
+        }
+        for dest in def.out_msgs.as_ref().iter() {
+            let msg_info = messages.lookup(&dest.msg);
+            outputs.push(OutputModel::new(
+                OutputKind::Msg {
+                    def: dest.clone(),
+                    msg_type: msg_info.msg_type.clone(),
+                    is_call: msg_info.is_call(),
+                    part_of_client,
+                },
+                type_prefix,
+            ))
+        }
+        outputs
+    }
+
+    pub(crate) fn def(&self) -> &Transition {
+        self.def.as_ref()
+    }
+
+    pub(crate) fn name(&self) -> &Rc<Ident> {
+        &self.name
+    }
+
+    pub(crate) fn inputs(&self) -> &Vec<InputModel> {
+        &self.inputs
+    }
+
+    pub(crate) fn outputs(&self) -> &Vec<OutputModel> {
+        &self.outputs
+    }
+
+    pub(crate) fn future(&self) -> &Ident {
+        &self.future
+    }
+}
+
+impl GetSpan for MethodModel {
+    fn span(&self) -> Span {
+        self.def.span()
+    }
+}
+
+#[cfg_attr(test, derive(Debug))]
+pub(crate) struct InputModel {
+    name: Ident,
+    arg_type: Rc<TokenStream>,
+}
+
+impl InputModel {
+    fn new(type_name: Rc<Ident>, arg_type: Rc<TokenStream>) -> Self {
+        let name = format_ident!("{}_arg", type_name.to_string().pascal_to_snake());
+        Self { name, arg_type }
+    }
+}
+
+impl ToTokens for InputModel {
+    fn to_tokens(&self, tokens: &mut TokenStream) {
+        let name = &self.name;
+        let arg_type = self.arg_type.as_ref();
+        tokens.extend(quote! { #name : #arg_type })
+    }
+}
+
+#[cfg_attr(test, derive(Debug))]
+pub(crate) struct OutputModel {
+    type_name: Option<TokenStream>,
+    decl: Option<TokenStream>,
+    #[allow(dead_code)]
+    kind: OutputKind,
+}
+
+impl OutputModel {
+    fn new(kind: OutputKind, type_prefix: &str) -> Self {
+        let (decl, type_name) = match &kind {
+            OutputKind::State { def, .. } => {
+                let state_trait = def.state_trait.as_ref();
+                if state_trait == End::ident() {
+                    let end_ident = format_ident!("{}", End::ident());
+                    (None, Some(quote! { ::btrun::#end_ident }))
+                } else {
+                    let type_name = format_ident!("{type_prefix}{}", state_trait);
+                    (
+                        Some(quote! { type  #type_name: #state_trait; }),
+                        Some(quote! { Self::#type_name }),
+                    )
+                }
+            }
+            OutputKind::Msg {
+                msg_type,
+                part_of_client,
+                is_call,
+                ..
+            } => {
+                let type_name = if *part_of_client {
+                    if *is_call {
+                        Some(quote! {
+                            <#msg_type as ::btrun::CallMsg>::Reply
+                        })
+                    } else {
+                        None
+                    }
+                } else {
+                    Some(quote! { #msg_type })
+                };
+                (None, type_name)
+            }
+        };
+        Self {
+            type_name,
+            decl,
+            kind,
+        }
+    }
+
+    pub(crate) fn type_name(&self) -> Option<&TokenStream> {
+        self.type_name.as_ref()
+    }
+
+    pub(crate) fn decl(&self) -> Option<&TokenStream> {
+        self.decl.as_ref()
+    }
+}
+
+#[cfg_attr(test, derive(Debug))]
+pub(crate) enum OutputKind {
+    State {
+        def: Rc<State>,
+    },
+    Msg {
+        #[allow(dead_code)]
+        def: Rc<Dest>,
+        msg_type: Rc<TokenStream>,
+        is_call: bool,
+        part_of_client: bool,
+    },
+}
+
+pub(crate) struct ActorLookup {
+    actor_states: HashMap<Rc<Ident>, HashSet<Rc<Ident>>>,
+}
+
+impl ActorLookup {
+    fn new<'a>(actor_defs: impl IntoIterator<Item = &'a ActorDef>) -> Self {
+        let mut actor_states = HashMap::new();
+        for actor_def in actor_defs.into_iter() {
+            let mut states = HashSet::new();
+            for state in actor_def.states.as_ref().iter() {
+                states.insert(state.clone());
+            }
+            actor_states.insert(actor_def.actor.clone(), states);
+        }
+        Self { actor_states }
+    }
+
+    /// Returns the set of states associated with the given actor.
+    ///
+    /// This method **panics** if you call it with a non-existent actor name.
+    pub(crate) fn actor_states(&self, actor_name: &Ident) -> &HashSet<Rc<Ident>> {
+        self.actor_states.get(actor_name).unwrap_or_else(|| {
+            panic!("Unknown actor. This indicates there is a bug in the btproto crate.")
+        })
+    }
+}
+
+pub(crate) struct MsgLookup {
+    messages: HashMap<Rc<Ident>, MsgInfo>,
+}
+
+impl MsgLookup {
+    fn new<'a>(transitions: impl IntoIterator<Item = &'a Transition>) -> Self {
+        let mut messages = HashMap::new();
+        for transition in transitions.into_iter() {
+            let in_state = &transition.in_state.state_trait;
+            if let Some(in_msg) = transition.in_msg() {
+                let msg_name = &in_msg.msg_type;
+                let msg_info = messages
+                    .entry(msg_name.clone())
+                    .or_insert_with(|| MsgInfo::empty(msg_name.clone()));
+                msg_info.record_receiver(in_msg, in_state.clone());
+            }
+            for out_msg in transition.out_msgs.as_ref().iter() {
+                let msg_name = &out_msg.msg.msg_type;
+                let msg_info = messages
+                    .entry(msg_name.clone())
+                    .or_insert_with(|| MsgInfo::empty(msg_name.clone()));
+                msg_info.record_sender(&out_msg.msg, in_state.clone());
+            }
+        }
+        Self { messages }
+    }
+
+    pub(crate) fn lookup(&self, msg: &Message) -> &MsgInfo {
+        self.messages
+            .get(msg.msg_type.as_ref())
+            // Since a message is either sent or received, and we've added all such messages in
+            // the new method, this unwrap should not panic.
+            .unwrap_or_else(|| {
+                panic!("Failed to find message info. There is a bug in MessageLookup::new.")
+            })
+            .info_for(msg)
+    }
+
+    pub(crate) fn msg_iter(&self) -> impl Iterator<Item = &MsgInfo> {
+        self.messages
+            .values()
+            .flat_map(|msg_info| [Some(msg_info), msg_info.reply()])
+            .flatten()
+    }
+}
+
+impl AsRef<HashMap<Rc<Ident>, MsgInfo>> for MsgLookup {
+    fn as_ref(&self) -> &HashMap<Rc<Ident>, MsgInfo> {
+        &self.messages
+    }
+}
+
+pub(crate) struct MsgInfo {
+    /// The unique name of this message. If it is a reply, it will end in
+    /// `MessageReplyPart::REPLY_IDENT`.
+    msg_name: Rc<Ident>,
+    /// The type of this message. If this message is not a reply, this is just `msg_name`. If it is
+    /// a reply, it is `<#msg_name as ::btrun::CallMsg>::Reply`.
+    msg_type: Rc<TokenStream>,
+    is_reply: bool,
+    senders: HashSet<Rc<Ident>>,
+    receivers: HashSet<Rc<Ident>>,
+    reply: Option<Box<MsgInfo>>,
+}
+
+impl MsgInfo {
+    fn empty(msg_name: Rc<Ident>) -> Self {
+        Self {
+            msg_name: msg_name.clone(),
+            msg_type: Rc::new(quote! { #msg_name }),
+            is_reply: false,
+            senders: HashSet::new(),
+            receivers: HashSet::new(),
+            reply: None,
+        }
+    }
+
+    fn is_call(&self) -> bool {
+        self.reply.is_some()
+    }
+
+    fn info_for(&self, msg: &Message) -> &Self {
+        if msg.is_reply() {
+            // If this message is a reply, then we should have seen it in MsgLookup::new and
+            // initialized its reply pointer. So this unwrap shouldn't panic.
+            self.reply.as_ref().unwrap_or_else(|| {
+                panic!(
+                "A reply message was not properly recorded. There is a bug in MessageLookup::new."
+            )
+            })
+        } else {
+            self
+        }
+    }
+
+    fn info_for_mut(&mut self, msg: &Message) -> &mut Self {
+        if msg.is_reply() {
+            self.reply.get_or_insert_with(|| {
+                let mut reply = MsgInfo::empty(msg.msg_type.clone());
+                reply.is_reply = true;
+                reply.msg_name = Rc::new(format_ident!(
+                    "{}{}",
+                    msg.msg_type.as_ref(),
+                    MessageReplyPart::REPLY_IDENT
+                ));
+                let parent = self.msg_name.as_ref();
+                reply.msg_type = Rc::new(quote! { <#parent as ::btrun::CallMsg>::Reply });
+                Box::new(reply)
+            })
+        } else {
+            self
+        }
+    }
+
+    fn record_receiver(&mut self, msg: &Message, receiver: Rc<Ident>) {
+        let target = self.info_for_mut(msg);
+        target.receivers.insert(receiver);
+    }
+
+    fn record_sender(&mut self, msg: &Message, sender: Rc<Ident>) {
+        let target = self.info_for_mut(msg);
+        target.senders.insert(sender);
+    }
+
+    pub(crate) fn msg_name(&self) -> &Rc<Ident> {
+        &self.msg_name
+    }
+
+    pub(crate) fn msg_type(&self) -> &Rc<TokenStream> {
+        &self.msg_type
+    }
+
+    pub(crate) fn reply(&self) -> Option<&MsgInfo> {
+        self.reply.as_ref().map(|ptr| ptr.as_ref())
+    }
+}
+
+impl PartialEq for MsgInfo {
+    fn eq(&self, other: &Self) -> bool {
+        self.msg_name.as_ref() == other.msg_name.as_ref()
+    }
+}
+
+impl Eq for MsgInfo {}
+
+impl Hash for MsgInfo {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.msg_name.hash(state)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{
+        error::{assert_err, assert_ok},
+        parsing::{DestinationState, NameDef},
+    };
+
+    use super::*;
+
+    #[test]
+    fn protocol_model_new_minimal_ok() {
+        let input = Protocol::minimal();
+
+        let result = ProtocolModel::new(input);
+
+        assert_ok(result);
+    }
+
+    #[test]
+    fn protocol_model_new_dup_recv_transition_err() {
+        let input = Protocol::new(
+            NameDef::new("Undeclared"),
+            [ActorDef::new("actor", ["Init", "Next"])],
+            [
+                Transition::new(
+                    State::new("Init", []),
+                    Some(Message::new("Query", false, [])),
+                    [State::new("Next", [])],
+                    [],
+                ),
+                Transition::new(
+                    State::new("Init", []),
+                    Some(Message::new("Query", false, [])),
+                    [State::new("Init", [])],
+                    [],
+                ),
+            ],
+        );
+
+        let result = ProtocolModel::new(input);
+
+        assert_err(result, error::msgs::DUPLICATE_TRANSITION);
+    }
+
+    #[test]
+    fn protocol_model_new_dup_send_transition_err() {
+        let input = Protocol::new(
+            NameDef::new("Undeclared"),
+            [
+                ActorDef::new("server", ["Init", "Next"]),
+                ActorDef::new("client", ["Client"]),
+            ],
+            [
+                Transition::new(
+                    State::new("Client", []),
+                    None,
+                    [State::new("Client", [])],
+                    [Dest::new(
+                        DestinationState::Individual(State::new("Init", [])),
+                        Message::new("Query", false, []),
+                    )],
+                ),
+                Transition::new(
+                    State::new("Client", []),
+                    None,
+                    [State::new("Client", [])],
+                    [Dest::new(
+                        DestinationState::Individual(State::new("Next", [])),
+                        Message::new("Query", false, []),
+                    )],
+                ),
+            ],
+        );
+
+        let result = ProtocolModel::new(input);
+
+        assert_err(result, error::msgs::DUPLICATE_TRANSITION);
+    }
+
+    #[test]
+    fn msg_sent_or_received_msg_received_ok() {
+        let input = Protocol::new(
+            NameDef::new("Test"),
+            [ActorDef::new("actor", ["Init"])],
+            [Transition::new(
+                State::new("Init", []),
+                Some(Message::new("Activate", false, [])),
+                [State::new("End", [])],
+                [],
+            )],
+        );
+
+        let result = ProtocolModel::new(input);
+
+        assert_ok(result);
+    }
+
+    #[test]
+    fn msg_sent_or_received_msg_sent_ok() {
+        let input = Protocol::new(
+            NameDef::new("Test"),
+            [ActorDef::new("actor", ["First", "Second"])],
+            [Transition::new(
+                State::new("First", []),
+                None,
+                [State::new("First", [])],
+                [Dest::new(
+                    DestinationState::Individual(State::new("Second", [])),
+                    Message::new("Msg", false, []),
+                )],
+            )],
+        );
+
+        let result = ProtocolModel::new(input);
+
+        assert_ok(result);
+    }
+
+    #[test]
+    fn msg_sent_or_received_neither_err() {
+        let input = Protocol::new(
+            NameDef::new("Test"),
+            [ActorDef::new("actor", ["First"])],
+            [Transition::new(
+                State::new("First", []),
+                None,
+                [State::new("First", [])],
+                [],
+            )],
+        );
+
+        let result = ProtocolModel::new(input);
+
+        assert_err(result, error::msgs::NO_MSG_SENT_OR_RECEIVED_ERR);
+    }
+
+    #[test]
+    fn reply_is_marked_in_output() {
+        const MSG: &str = "Ping";
+        let input = Protocol::new(
+            NameDef::new("ReplyTest"),
+            [
+                ActorDef::new("server", ["Listening"]),
+                ActorDef::new("client", ["Client", "Waiting"]),
+            ],
+            [
+                Transition::new(
+                    State::new("Client", []),
+                    None,
+                    [State::new("Waiting", [])],
+                    [Dest::new(
+                        DestinationState::Service(State::new("Listening", [])),
+                        Message::new(MSG, false, []),
+                    )],
+                ),
+                Transition::new(
+                    State::new("Listening", []),
+                    Some(Message::new(MSG, false, [])),
+                    [State::new("Listening", [])],
+                    [Dest::new(
+                        DestinationState::Individual(State::new("Client", [])),
+                        Message::new(MSG, true, []),
+                    )],
+                ),
+                Transition::new(
+                    State::new("Waiting", []),
+                    Some(Message::new(MSG, true, [])),
+                    [State::new(End::ident(), [])],
+                    [],
+                ),
+            ],
+        );
+
+        let actual = ProtocolModel::new(input).unwrap();
+
+        let outputs: Vec<_> = actual
+            .outputs_iter()
+            .map(|output| {
+                if let OutputKind::Msg { is_call, .. } = output.kind {
+                    Some(is_call)
+                } else {
+                    None
+                }
+            })
+            .filter(|x| x.is_some())
+            .map(|x| x.unwrap())
+            .collect();
+        assert_eq!(2, outputs.len());
+        assert_eq!(1, outputs.iter().filter(|is_reply| **is_reply).count());
+        assert_eq!(1, outputs.iter().filter(|is_reply| !*is_reply).count());
+    }
+}

+ 92 - 38
crates/btproto/src/parsing.rs

@@ -2,6 +2,7 @@
 
 use proc_macro2::Span;
 use quote::format_ident;
+use std::{ops::Deref, rc::Rc};
 use syn::{
     bracketed, parenthesized,
     parse::{Parse, ParseStream},
@@ -18,8 +19,8 @@ pub(crate) struct Protocol {
     name_def_semi: Token![;],
     #[allow(dead_code)]
     version_def: Option<(VersionDef, Token![;])>,
-    pub(crate) actor_defs: Punctuated<ActorDef, Token![;]>,
-    pub(crate) transitions: Punctuated<Transition, Token![;]>,
+    pub(crate) actor_defs: Punctuated<Rc<ActorDef>, Token![;]>,
+    pub(crate) transitions: Punctuated<Rc<Transition>, Token![;]>,
 }
 
 #[cfg(test)]
@@ -29,9 +30,11 @@ impl Protocol {
         actor_def: impl IntoIterator<Item = ActorDef>,
         transitions: impl IntoIterator<Item = Transition>,
     ) -> Self {
-        let mut actor_def: Punctuated<ActorDef, Token![;]> = actor_def.into_iter().collect();
+        let mut actor_def: Punctuated<Rc<ActorDef>, Token![;]> =
+            actor_def.into_iter().map(Rc::new).collect();
         actor_def.push_punct(Token![;](Span::call_site()));
-        let mut transitions: Punctuated<Transition, Token![;]> = transitions.into_iter().collect();
+        let mut transitions: Punctuated<Rc<Transition>, Token![;]> =
+            transitions.into_iter().map(Rc::new).collect();
         transitions.push_punct(Token![;](Span::call_site()));
         Self {
             name_def,
@@ -52,6 +55,35 @@ impl Protocol {
         protocol.version_def = Some((version_def, Token![;](Span::call_site())));
         protocol
     }
+
+    /// Creates minimal [Protocol] value.
+    pub(crate) fn minimal() -> Protocol {
+        const STATE_NAME: &str = "Init";
+        Protocol::new(
+            NameDef::new("Test"),
+            [
+                ActorDef::new("server", [STATE_NAME]),
+                ActorDef::new("client", ["Client"]),
+            ],
+            [
+                Transition::new(
+                    State::new("Client", []),
+                    None,
+                    [State::new("End", [])],
+                    [Dest::new(
+                        DestinationState::Service(State::new(STATE_NAME, [])),
+                        Message::new("Msg", false, []),
+                    )],
+                ),
+                Transition::new(
+                    State::new(STATE_NAME, []),
+                    Some(Message::new("Msg", false, [])),
+                    [State::new("End", [])],
+                    [],
+                ),
+            ],
+        )
+    }
 }
 
 impl Parse for Protocol {
@@ -69,7 +101,8 @@ impl Parse for Protocol {
             None
         };
         let actor_defs = Punctuated::parse_list(input, |input| !input.peek(Token![let]))?;
-        let transitions = input.parse_terminated(Transition::parse, Token![;])?;
+        let transitions =
+            input.parse_terminated(|input| Ok(Rc::new(Transition::parse(input)?)), Token![;])?;
         Ok(Protocol {
             name_def,
             name_def_semi,
@@ -182,7 +215,7 @@ impl GetSpan for VersionDef {
 #[cfg_attr(test, derive(Debug, PartialEq))]
 pub(crate) struct ActorDef {
     let_token: Token![let],
-    pub(crate) actor: Ident,
+    pub(crate) actor: Rc<Ident>,
     eq_token: Token![=],
     pub(crate) states: IdentArray,
 }
@@ -192,7 +225,7 @@ impl ActorDef {
     pub(crate) fn new(actor: &str, state_names: impl IntoIterator<Item = &'static str>) -> Self {
         Self {
             let_token: Token![let](Span::call_site()),
-            actor: new_ident(actor),
+            actor: new_ident(actor).into(),
             eq_token: Token![=](Span::call_site()),
             states: IdentArray::new(state_names),
         }
@@ -204,7 +237,7 @@ impl Parse for ActorDef {
     fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
         Ok(Self {
             let_token: input.parse()?,
-            actor: input.parse()?,
+            actor: Rc::new(input.parse()?),
             eq_token: input.parse()?,
             states: input.parse()?,
         })
@@ -225,7 +258,7 @@ impl GetSpan for ActorDef {
 #[derive(Hash, PartialEq, Eq)]
 pub(crate) struct IdentArray {
     bracket: Bracket,
-    idents: Punctuated<Ident, Token![,]>,
+    idents: Punctuated<Rc<Ident>, Token![,]>,
 }
 
 impl IdentArray {
@@ -244,7 +277,11 @@ impl IdentArray {
     pub(crate) fn new(state_names: impl IntoIterator<Item = &'static str>) -> Self {
         Self {
             bracket: Bracket::default(),
-            idents: state_names.into_iter().map(new_ident).collect(),
+            idents: state_names
+                .into_iter()
+                .map(new_ident)
+                .map(Rc::new)
+                .collect(),
         }
     }
 }
@@ -264,7 +301,8 @@ impl Parse for IdentArray {
     fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
         let content;
         let bracket = bracketed!(content in input);
-        let idents = content.parse_terminated(Ident::parse, Token![,])?;
+        let idents =
+            content.parse_terminated(|input| Ok(Rc::new(Ident::parse(input)?)), Token![,])?;
         if idents.is_empty() {
             return Err(syn::Error::new(bracket.span.open(), Self::EMPTY_ERR));
         }
@@ -272,8 +310,8 @@ impl Parse for IdentArray {
     }
 }
 
-impl AsRef<Punctuated<Ident, Token![,]>> for IdentArray {
-    fn as_ref(&self) -> &Punctuated<Ident, Token![,]> {
+impl AsRef<Punctuated<Rc<Ident>, Token![,]>> for IdentArray {
+    fn as_ref(&self) -> &Punctuated<Rc<Ident>, Token![,]> {
         &self.idents
     }
 }
@@ -292,6 +330,10 @@ impl Transition {
     pub(crate) fn in_msg(&self) -> Option<&Message> {
         self.in_msg.as_ref().map(|(_, msg)| msg)
     }
+
+    pub(crate) fn is_client(&self) -> bool {
+        self.in_msg.is_none()
+    }
 }
 
 #[cfg(test)]
@@ -302,7 +344,7 @@ impl Transition {
         out_states: impl IntoIterator<Item = State>,
         out_msgs: impl IntoIterator<Item = Dest>,
     ) -> Self {
-        let out_msgs = DestList(out_msgs.into_iter().collect());
+        let out_msgs = DestList(out_msgs.into_iter().map(Rc::new).collect());
         let redirect = if out_msgs.as_ref().is_empty() {
             None
         } else {
@@ -312,7 +354,7 @@ impl Transition {
             in_state,
             in_msg: in_msg.map(|msg| (Token![?](Span::call_site()), msg)),
             arrow: Token![->](Span::call_site()),
-            out_states: StatesList(out_states.into_iter().collect()),
+            out_states: StatesList(out_states.into_iter().map(Rc::new).collect()),
             redirect,
             out_msgs,
         }
@@ -361,7 +403,7 @@ impl GetSpan for Transition {
 #[cfg_attr(test, derive(Debug))]
 #[derive(Hash, PartialEq, Eq)]
 pub(crate) struct State {
-    pub(crate) state_trait: Ident,
+    pub(crate) state_trait: Rc<Ident>,
     pub(crate) owned_states: IdentArray,
 }
 
@@ -372,7 +414,7 @@ impl State {
         owned_states: impl IntoIterator<Item = &'static str>,
     ) -> Self {
         Self {
-            state_trait: new_ident(state_trait),
+            state_trait: new_ident(state_trait).into(),
             owned_states: IdentArray::new(owned_states),
         }
     }
@@ -385,7 +427,7 @@ impl State {
 impl Parse for State {
     /// state : Ident ident_array? ;
     fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
-        let state_trait = Ident::parse(input)?;
+        let state_trait = Ident::parse(input)?.into();
         let owned_states = if input.peek(Bracket) {
             IdentArray::parse(input)?
         } else {
@@ -405,7 +447,7 @@ impl GetSpan for State {
 }
 
 #[cfg_attr(test, derive(Debug, PartialEq))]
-pub(crate) struct StatesList(Punctuated<State, Token![,]>);
+pub(crate) struct StatesList(Punctuated<Rc<State>, Token![,]>);
 
 impl StatesList {
     const EMPTY_ERR: &str = "at lest one out state is required";
@@ -430,14 +472,14 @@ impl GetSpan for StatesList {
     }
 }
 
-impl AsRef<Punctuated<State, Token![,]>> for StatesList {
-    fn as_ref(&self) -> &Punctuated<State, Token![,]> {
+impl AsRef<Punctuated<Rc<State>, Token![,]>> for StatesList {
+    fn as_ref(&self) -> &Punctuated<Rc<State>, Token![,]> {
         &self.0
     }
 }
 
 #[cfg_attr(test, derive(Debug, PartialEq))]
-pub(crate) struct DestList(Punctuated<Dest, Token![,]>);
+pub(crate) struct DestList(Punctuated<Rc<Dest>, Token![,]>);
 
 impl DestList {
     const TRAILING_COMMA_ERR: &str = "No trailing comma is allowed in a destination list.";
@@ -467,8 +509,8 @@ impl GetSpan for DestList {
     }
 }
 
-impl AsRef<Punctuated<Dest, Token![,]>> for DestList {
-    fn as_ref(&self) -> &Punctuated<Dest, Token![,]> {
+impl AsRef<Punctuated<Rc<Dest>, Token![,]>> for DestList {
+    fn as_ref(&self) -> &Punctuated<Rc<Dest>, Token![,]> {
         &self.0
     }
 }
@@ -541,7 +583,7 @@ impl Parse for DestinationState {
                 return Err(syn::Error::new(extra_dest.span(), Self::MULTI_STATE_ERR));
             }
             Ok(DestinationState::Service(State {
-                state_trait: dest_state,
+                state_trait: dest_state.into(),
                 owned_states: IdentArray::empty(),
             }))
         } else {
@@ -563,7 +605,7 @@ impl GetSpan for DestinationState {
 #[cfg_attr(test, derive(Debug))]
 #[derive(Hash, PartialEq, Eq)]
 pub(crate) struct Message {
-    pub(crate) msg_type: Ident,
+    pub(crate) msg_type: Rc<Ident>,
     reply_part: Option<MessageReplyPart>,
     pub(crate) owned_states: IdentArray,
     ident: Option<Ident>,
@@ -571,7 +613,7 @@ pub(crate) struct Message {
 
 impl Message {
     /// Returns the identifier to use when naming types and variants after this message.
-    pub(crate) fn ident(&self) -> &Ident {
+    pub(crate) fn variant(&self) -> &Ident {
         if let Some(ident) = &self.ident {
             ident
         } else {
@@ -603,7 +645,7 @@ impl Message {
             (None, None)
         };
         Self {
-            msg_type: new_ident(msg_type),
+            msg_type: Rc::new(new_ident(msg_type)),
             reply_part,
             owned_states: IdentArray::new(owned_states),
             ident: ident_field,
@@ -614,7 +656,7 @@ impl Message {
 impl Parse for Message {
     /// message : Ident ( "::" "Reply" )? ident_array? ;
     fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
-        let msg_type = Ident::parse(input)?;
+        let msg_type = Rc::new(Ident::parse(input)?);
         let (reply_part, ident) = if input.peek(Token![::]) {
             let reply_part = input.parse()?;
             (
@@ -649,14 +691,14 @@ impl GetSpan for Message {
 
 #[cfg_attr(test, derive(Debug))]
 #[derive(Hash, PartialEq, Eq)]
-struct MessageReplyPart {
+pub(crate) struct MessageReplyPart {
     colons: Token![::],
     reply: Ident,
 }
 
 impl MessageReplyPart {
     const REPLY_ERR: &str = "expected 'Reply'";
-    const REPLY_IDENT: &str = "Reply";
+    pub(crate) const REPLY_IDENT: &str = "Reply";
 }
 
 impl Parse for MessageReplyPart {
@@ -703,6 +745,12 @@ impl<'a, T: GetSpan> GetSpan for &'a T {
     }
 }
 
+impl<T: GetSpan> GetSpan for Rc<T> {
+    fn span(&self) -> Span {
+        self.deref().span()
+    }
+}
+
 /// Returns the span of a punctuated sequence of values which implement [GetSpan].
 fn punctuated_span<T: GetSpan, P>(arg: &Punctuated<T, P>) -> Span {
     let mut iter = arg.into_iter();
@@ -752,7 +800,7 @@ trait ParsePunctuatedList: Sized {
     ) -> syn::Result<Self>;
 }
 
-impl<T: Parse, U: Parse> ParsePunctuatedList for Punctuated<T, U> {
+impl<T: Parse, U: Parse> ParsePunctuatedList for Punctuated<Rc<T>, U> {
     fn parse_list(
         input: ParseStream,
         should_break: impl Fn(ParseStream) -> bool,
@@ -762,7 +810,7 @@ impl<T: Parse, U: Parse> ParsePunctuatedList for Punctuated<T, U> {
             if should_break(input) {
                 break;
             }
-            output.push_value(input.parse()?);
+            output.push_value(Rc::new(input.parse()?));
             if let Ok(punct) = input.parse() {
                 output.push_punct(punct);
             }
@@ -927,8 +975,8 @@ Init?Activate -> End;"
         let actual = IdentArray::new(EXPECTED);
 
         assert_eq!(EXPECTED.len(), actual.idents.len());
-        assert_eq!(actual.idents[0], EXPECTED[0]);
-        assert_eq!(actual.idents[1], EXPECTED[1]);
+        assert_eq!(actual.idents[0].as_ref(), EXPECTED[0]);
+        assert_eq!(actual.idents[1].as_ref(), EXPECTED[1]);
     }
 
     #[test]
@@ -1206,14 +1254,20 @@ Init?Activate -> End;"
 
         let actual = Message::new(EXPECTED_MSG_TYPE, EXPECTED_IS_REPLY, EXPECTED_OWNED_STATES);
 
-        assert_eq!(actual.msg_type, EXPECTED_MSG_TYPE);
+        assert_eq!(actual.msg_type.as_ref(), EXPECTED_MSG_TYPE);
         assert_eq!(actual.is_reply(), EXPECTED_IS_REPLY);
         assert_eq!(
             actual.owned_states.idents.len(),
             EXPECTED_OWNED_STATES.len()
         );
-        assert_eq!(actual.owned_states.idents[0], EXPECTED_OWNED_STATES[0]);
-        assert_eq!(actual.owned_states.idents[1], EXPECTED_OWNED_STATES[1]);
+        assert_eq!(
+            actual.owned_states.idents[0].as_ref(),
+            EXPECTED_OWNED_STATES[0]
+        );
+        assert_eq!(
+            actual.owned_states.idents[1].as_ref(),
+            EXPECTED_OWNED_STATES[1]
+        );
     }
 
     #[test]

+ 160 - 299
crates/btproto/src/validation.rs

@@ -1,97 +1,65 @@
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;
 
 use proc_macro2::{Ident, Span};
 
 use btrun::End;
 
 use crate::{
-    error::MaybeErr,
+    error::{self, MaybeErr},
+    model::ProtocolModel,
     parsing::{DestinationState, GetSpan, Message, State},
-    Protocol,
 };
 
-impl Protocol {
+impl ProtocolModel {
+    #[allow(dead_code)]
     pub(crate) fn validate(&self) -> syn::Result<()> {
-        let validator = ProtocolValidator::new(self);
-        validator
-            .all_states_declared_and_used()
-            .combine(validator.receivers_and_senders_matched())
-            .combine(validator.no_undeliverable_msgs())
-            .combine(validator.replies_expected())
-            .combine(validator.msg_sent_or_received())
-            .combine(validator.clients_only_receive_replies())
+        self.all_states_declared_and_used()
+            .combine(self.receivers_and_senders_matched())
+            .combine(self.no_undeliverable_msgs())
+            .combine(self.replies_expected())
+            .combine(self.clients_only_receive_replies())
             .into()
     }
-}
-
-struct ProtocolValidator<'a> {
-    protocol: &'a Protocol,
-    actors_by_states: HashMap<&'a Ident, &'a Ident>,
-    client_actors: HashSet<&'a Ident>,
-}
-
-impl<'a> ProtocolValidator<'a> {
-    fn new(protocol: &'a Protocol) -> Self {
-        let mut actors_by_states: HashMap<&Ident, &Ident> = HashMap::new();
-        for actor_def in protocol.actor_defs.iter() {
-            for state in actor_def.states.as_ref().iter() {
-                actors_by_states.insert(state, &actor_def.actor);
-            }
-        }
-        let client_actors: HashSet<&Ident> = protocol
-            .transitions
-            .iter()
-            .filter(|transition| transition.in_msg().is_none())
-            .map(|transition| actors_by_states.get(&transition.in_state.state_trait))
-            .filter(|option| option.is_some())
-            .map(|option| *option.unwrap())
-            .collect();
-        Self {
-            protocol,
-            actors_by_states,
-            client_actors,
-        }
-    }
-
-    /// Returns the [Ident] for the actor that the given state is a part of.
-    fn actor(&self, state: &Ident) -> Option<&Ident> {
-        self.actors_by_states.get(state).copied()
-    }
 
-    fn is_client(&self, actor: &Ident) -> bool {
-        self.client_actors.contains(actor)
-    }
-
-    fn is_client_state(&self, state: &State) -> bool {
-        self.actor(&state.state_trait)
-            .map(|actor| self.is_client(actor))
-            .unwrap_or(false)
-    }
-
-    const UNDECLARED_STATE_ERR: &str = "State was not declared.";
-    const UNUSED_STATE_ERR: &str = "State was declared but never used.";
-
-    /// Verifies that every state which is used has been declared, except for the End state.
+    /// Verifies that every state which is declared is actually used.
     fn all_states_declared_and_used(&self) -> MaybeErr {
         let end = Ident::new(End::ident(), Span::call_site());
         let mut declared: HashSet<&Ident> = HashSet::new();
         declared.insert(&end);
-        for actor_def in self.protocol.actor_defs.iter() {
+        for actor_def in self.def().actor_defs.iter() {
             for state in actor_def.states.as_ref().iter() {
                 declared.insert(state);
             }
         }
         let mut used: HashSet<&Ident> = HashSet::with_capacity(declared.len());
-        for transition in self.protocol.transitions.iter() {
+        for transition in self.def().transitions.iter() {
             let in_state = &transition.in_state;
             used.insert(&in_state.state_trait);
-            used.extend(in_state.owned_states.as_ref().iter());
+            used.extend(
+                in_state
+                    .owned_states
+                    .as_ref()
+                    .iter()
+                    .map(|ident| ident.as_ref()),
+            );
             if let Some(in_msg) = transition.in_msg() {
-                used.extend(in_msg.owned_states.as_ref().iter());
+                used.extend(
+                    in_msg
+                        .owned_states
+                        .as_ref()
+                        .iter()
+                        .map(|ident| ident.as_ref()),
+                );
             }
             for out_states in transition.out_states.as_ref().iter() {
                 used.insert(&out_states.state_trait);
-                used.extend(out_states.owned_states.as_ref().iter());
+                used.extend(
+                    out_states
+                        .owned_states
+                        .as_ref()
+                        .iter()
+                        .map(|ident| ident.as_ref()),
+                );
             }
             // We don't have to check the states referred to in out_msgs because the
             // receivers_and_senders_matched method ensures that each of these exists in a receiver
@@ -99,25 +67,23 @@ impl<'a> ProtocolValidator<'a> {
         }
         let undeclared: MaybeErr = used
             .difference(&declared)
-            .map(|ident| syn::Error::new(ident.span(), Self::UNDECLARED_STATE_ERR))
+            .map(|ident| syn::Error::new(ident.span(), error::msgs::UNDECLARED_STATE_ERR))
             .collect();
         let unused: MaybeErr = declared
             .difference(&used)
             .filter(|ident| **ident != End::ident())
-            .map(|ident| syn::Error::new(ident.span(), Self::UNUSED_STATE_ERR))
+            .map(|ident| syn::Error::new(ident.span(), error::msgs::UNUSED_STATE_ERR))
             .collect();
         undeclared.combine(unused)
     }
 
-    const UNMATCHED_SENDER_ERR: &str = "No receiver found for message type.";
-    const UNMATCHED_RECEIVER_ERR: &str = "No sender found for message type.";
-
     /// Ensures that the recipient state for every sent message has a receiving transition
-    /// defined, and every receiver has a sender.
+    /// defined, and every receiver has a sender. Note that each message isn't required to have a
+    /// unique sender or a unique receiver, just that at least one of each much be defined.
     fn receivers_and_senders_matched(&self) -> MaybeErr {
         let mut senders: HashSet<(&State, &Message)> = HashSet::new();
         let mut receivers: HashSet<(&State, &Message)> = HashSet::new();
-        for transition in self.protocol.transitions.iter() {
+        for transition in self.def().transitions.iter() {
             if let Some(msg) = transition.in_msg() {
                 receivers.insert((&transition.in_state, msg));
             }
@@ -131,24 +97,23 @@ impl<'a> ProtocolValidator<'a> {
         }
         let extra_senders: MaybeErr = senders
             .difference(&receivers)
-            .map(|pair| syn::Error::new(pair.1.msg_type.span(), Self::UNMATCHED_SENDER_ERR))
+            .map(|pair| syn::Error::new(pair.1.msg_type.span(), error::msgs::UNMATCHED_SENDER_ERR))
             .collect();
         let extra_receivers: MaybeErr = receivers
             .difference(&senders)
-            .map(|pair| syn::Error::new(pair.1.msg_type.span(), Self::UNMATCHED_RECEIVER_ERR))
+            .map(|pair| {
+                syn::Error::new(pair.1.msg_type.span(), error::msgs::UNMATCHED_RECEIVER_ERR)
+            })
             .collect();
         extra_senders.combine(extra_receivers)
     }
 
-    const UNDELIVERABLE_ERR: &str =
-        "Receiver must either be a service, an owned state, or an out state, or the message must be a reply.";
-
     /// Checks that messages are only sent to destinations which are either services, states
     /// which are owned by the sender, listed in the output states, or that the message is a
     /// reply.
     fn no_undeliverable_msgs(&self) -> MaybeErr {
         let mut err = MaybeErr::none();
-        for transition in self.protocol.transitions.iter() {
+        for transition in self.def().transitions.iter() {
             let mut allowed_states: Option<HashSet<&Ident>> = None;
             for dest in transition.out_msgs.as_ref().iter() {
                 if dest.msg.is_reply() {
@@ -162,15 +127,22 @@ impl<'a> ProtocolValidator<'a> {
                                 .out_states
                                 .as_ref()
                                 .iter()
-                                .map(|state| &state.state_trait)
-                                .chain(transition.in_state.owned_states.as_ref().iter())
+                                .map(|state| state.state_trait.as_ref())
+                                .chain(
+                                    transition
+                                        .in_state
+                                        .owned_states
+                                        .as_ref()
+                                        .iter()
+                                        .map(|ident| ident.as_ref()),
+                                )
                                 .collect()
                         });
-                        if !allowed.contains(&dest_state.state_trait) {
+                        if !allowed.contains(dest_state.state_trait.as_ref()) {
                             err = err.combine(
                                 syn::Error::new(
                                     dest_state.state_trait.span(),
-                                    Self::UNDELIVERABLE_ERR,
+                                    error::msgs::UNDELIVERABLE_ERR,
                                 )
                                 .into(),
                             );
@@ -182,15 +154,10 @@ impl<'a> ProtocolValidator<'a> {
         err
     }
 
-    const INVALID_REPLY_ERR: &str =
-        "Replies can only be used in transitions which handle messages.";
-    const MULTIPLE_REPLIES_ERR: &str =
-        "Only a single reply can be sent in response to any message.";
-
     /// Verifies that exactly one reply is sent in response to a previously sent message.
     fn replies_expected(&self) -> MaybeErr {
         let mut err = MaybeErr::none();
-        for transition in self.protocol.transitions.iter() {
+        for transition in self.def().transitions.iter() {
             let replies: Vec<_> = transition
                 .out_msgs
                 .as_ref()
@@ -206,7 +173,10 @@ impl<'a> ProtocolValidator<'a> {
                     replies
                         .iter()
                         .map(|reply| {
-                            syn::Error::new(reply.msg_type.span(), Self::MULTIPLE_REPLIES_ERR)
+                            syn::Error::new(
+                                reply.msg_type.span(),
+                                error::msgs::MULTIPLE_REPLIES_ERR,
+                            )
                         })
                         .collect(),
                 );
@@ -216,7 +186,7 @@ impl<'a> ProtocolValidator<'a> {
                     replies
                         .iter()
                         .map(|reply| {
-                            syn::Error::new(reply.msg_type.span(), Self::INVALID_REPLY_ERR)
+                            syn::Error::new(reply.msg_type.span(), error::msgs::INVALID_REPLY_ERR)
                         })
                         .collect(),
                 );
@@ -225,41 +195,30 @@ impl<'a> ProtocolValidator<'a> {
         err
     }
 
-    const NO_MSG_SENT_OR_RECEIVED_ERR: &str = "A transition must send or receive a message.";
-
-    /// Verifies that either a message is received, or sent by a transition. The rational behind
-    /// this is that if no message is sent or received, then the state transition is unobserved by
-    /// other actors, and so should not be in a protocol.
-    fn msg_sent_or_received(&self) -> MaybeErr {
-        self.protocol
-            .transitions
-            .iter()
-            .filter(|transition| {
-                transition.in_msg().is_none() && transition.out_msgs.as_ref().is_empty()
-            })
-            .map(|transition| syn::Error::new(transition.span(), Self::NO_MSG_SENT_OR_RECEIVED_ERR))
-            .collect()
-    }
-
-    const CLIENT_RECEIVED_NON_REPLY_ERR: &str =
-        "A client actor cannot receive a message which is not a reply.";
-
     /// A client is any actor with a state that sends at least one message when not handling an
-    /// incoming message. Such actors are not allowed to receive any message which are not replies.
+    /// incoming message. Such actors are not allowed to receive any messages which are not replies.
     fn clients_only_receive_replies(&self) -> MaybeErr {
-        self.protocol
-            .transitions
-            .iter()
-            .filter(|transition| {
-                if let Some(msg) = transition.in_msg() {
-                    if !msg.is_reply() {
-                        return self.is_client_state(&transition.in_state);
-                    }
+        self.actors()
+            .values()
+            .filter(|actor| actor.is_client())
+            .flat_map(|actor| {
+                actor
+                    .states()
+                    .values()
+                    .flat_map(|state| state.methods().values())
+            })
+            .filter(|method| {
+                if let Some(msg) = method.def().in_msg() {
+                    !msg.is_reply()
+                } else {
+                    false
                 }
-                false
             })
             .map(|transition| {
-                syn::Error::new(transition.span(), Self::CLIENT_RECEIVED_NON_REPLY_ERR)
+                syn::Error::new(
+                    transition.span(),
+                    error::msgs::CLIENT_RECEIVED_NON_REPLY_ERR,
+                )
             })
             .collect()
     }
@@ -268,64 +227,24 @@ impl<'a> ProtocolValidator<'a> {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::parsing::{ActorDef, Dest, NameDef, Transition};
-
-    macro_rules! assert_ok {
-        ($maybe_err:expr) => {
-            let result: syn::Result<()> = $maybe_err.into();
-            assert!(result.is_ok(), "{}", result.err().unwrap());
-        };
-    }
-
-    macro_rules! assert_err {
-        ($maybe_err:expr, $expected_msg:expr) => {
-            let result: syn::Result<()> = $maybe_err.into();
-            assert!(result.is_err());
-            assert_eq!($expected_msg, result.err().unwrap().to_string());
-        };
-    }
-
-    /// Creates minimal [Protocol] value.
-    fn min_protocol() -> Protocol {
-        const STATE_NAME: &str = "Init";
-        Protocol::new(
-            NameDef::new("Test"),
-            [
-                ActorDef::new("server", [STATE_NAME]),
-                ActorDef::new("client", ["Client"]),
-            ],
-            [
-                Transition::new(
-                    State::new("Client", []),
-                    None,
-                    [State::new("End", [])],
-                    [Dest::new(
-                        DestinationState::Service(State::new(STATE_NAME, [])),
-                        Message::new("Msg", false, []),
-                    )],
-                ),
-                Transition::new(
-                    State::new(STATE_NAME, []),
-                    Some(Message::new("Msg", false, [])),
-                    [State::new("End", [])],
-                    [],
-                ),
-            ],
-        )
-    }
+    use crate::{
+        error::{assert_err, assert_ok},
+        parsing::{ActorDef, Dest, NameDef, Protocol, Transition},
+    };
 
     #[test]
     fn all_states_declared_and_used_ok() {
-        let protocol = min_protocol();
-        let result = ProtocolValidator::new(&protocol).all_states_declared_and_used();
+        let input = ProtocolModel::new(Protocol::minimal()).unwrap();
 
-        assert_ok!(result);
+        let result = input.all_states_declared_and_used();
+
+        assert_ok(result);
     }
 
     #[test]
     fn all_states_declared_and_used_end_not_used_ok() {
         const STATE_NAME: &str = "Init";
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("Test"),
             [ActorDef::new("actor", [STATE_NAME])],
             [Transition::new(
@@ -334,17 +253,17 @@ mod tests {
                 [State::new(STATE_NAME, [])],
                 [],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.all_states_declared_and_used();
 
-        assert_ok!(result);
+        assert_ok(result);
     }
 
     #[test]
     fn all_states_declared_and_used_undeclared_err() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("Undeclared"),
             [ActorDef::new("actor", ["Init"])],
             [Transition::new(
@@ -353,17 +272,17 @@ mod tests {
                 [State::new("Next", [])],
                 [],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.all_states_declared_and_used();
 
-        assert_err!(result, ProtocolValidator::UNDECLARED_STATE_ERR);
+        assert_err(result, error::msgs::UNDECLARED_STATE_ERR);
     }
 
     #[test]
     fn all_states_declared_and_used_undeclared_out_state_owned_err() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("Undeclared"),
             [ActorDef::new("actor", ["Init", "Next"])],
             [Transition::new(
@@ -372,17 +291,17 @@ mod tests {
                 [State::new("Init", []), State::new("Next", ["Undeclared"])],
                 [],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.all_states_declared_and_used();
 
-        assert_err!(result, ProtocolValidator::UNDECLARED_STATE_ERR);
+        assert_err(result, error::msgs::UNDECLARED_STATE_ERR);
     }
 
     #[test]
     fn all_states_declared_and_used_undeclared_in_state_owned_err() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("Undeclared"),
             [ActorDef::new("actor", ["Init", "Next"])],
             [Transition::new(
@@ -391,17 +310,17 @@ mod tests {
                 [State::new("Next", [])],
                 [],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.all_states_declared_and_used();
 
-        assert_err!(result, ProtocolValidator::UNDECLARED_STATE_ERR);
+        assert_err(result, error::msgs::UNDECLARED_STATE_ERR);
     }
 
     #[test]
     fn all_states_declared_and_used_unused_err() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("Unused"),
             [ActorDef::new("actor", ["Init", "Extra"])],
             [Transition::new(
@@ -410,25 +329,26 @@ mod tests {
                 [State::new("End", [])],
                 [],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.all_states_declared_and_used();
 
-        assert_err!(result, ProtocolValidator::UNUSED_STATE_ERR);
+        assert_err(result, error::msgs::UNUSED_STATE_ERR);
     }
 
     #[test]
     fn receivers_and_senders_matched_ok() {
-        let protocol = min_protocol();
-        let result = ProtocolValidator::new(&protocol).receivers_and_senders_matched();
+        let input = ProtocolModel::new(Protocol::minimal()).unwrap();
+
+        let result = input.receivers_and_senders_matched();
 
-        assert_ok!(result);
+        assert_ok(result);
     }
 
     #[test]
     fn receivers_and_senders_matched_unmatched_sender_err() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("Unbalanced"),
             [ActorDef::new("actor", ["Init"])],
             [Transition::new(
@@ -440,17 +360,17 @@ mod tests {
                     Message::new("Msg", false, []),
                 )],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.receivers_and_senders_matched();
 
-        assert_err!(result, ProtocolValidator::UNMATCHED_SENDER_ERR);
+        assert_err(result, error::msgs::UNMATCHED_SENDER_ERR);
     }
 
     #[test]
     fn receivers_and_senders_matched_unmatched_receiver_err() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("Unbalanced"),
             [ActorDef::new("actor", ["Init"])],
             [Transition::new(
@@ -459,25 +379,26 @@ mod tests {
                 [State::new("Init", [])],
                 [],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.receivers_and_senders_matched();
 
-        assert_err!(result, ProtocolValidator::UNMATCHED_RECEIVER_ERR);
+        assert_err(result, error::msgs::UNMATCHED_RECEIVER_ERR);
     }
 
     #[test]
     fn no_undeliverable_msgs_ok() {
-        let protocol = min_protocol();
-        let result = ProtocolValidator::new(&protocol).no_undeliverable_msgs();
+        let input = ProtocolModel::new(Protocol::minimal()).unwrap();
+
+        let result = input.no_undeliverable_msgs();
 
-        assert_ok!(result);
+        assert_ok(result);
     }
 
     #[test]
     fn no_undeliverable_msgs_reply_ok() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("Undeliverable"),
             [ActorDef::new("actor", ["Listening", "Client"])],
             [Transition::new(
@@ -489,17 +410,17 @@ mod tests {
                     Message::new("Msg", true, []),
                 )],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.no_undeliverable_msgs();
 
-        assert_ok!(result);
+        assert_ok(result);
     }
 
     #[test]
     fn no_undeliverable_msgs_service_ok() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("Undeliverable"),
             [ActorDef::new("actor", ["Client", "Server"])],
             [Transition::new(
@@ -511,17 +432,17 @@ mod tests {
                     Message::new("Msg", false, []),
                 )],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.no_undeliverable_msgs();
 
-        assert_ok!(result);
+        assert_ok(result);
     }
 
     #[test]
     fn no_undeliverable_msgs_owned_ok() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("Undeliverable"),
             [ActorDef::new("actor", ["FileClient", "FileHandle"])],
             [Transition::new(
@@ -533,17 +454,17 @@ mod tests {
                     Message::new("FileOp", false, []),
                 )],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.no_undeliverable_msgs();
 
-        assert_ok!(result);
+        assert_ok(result);
     }
 
     #[test]
     fn no_undeliverable_msgs_err() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("Undeliverable"),
             [ActorDef::new("actor", ["Client", "Server"])],
             [Transition::new(
@@ -555,17 +476,17 @@ mod tests {
                     Message::new("Msg", false, []),
                 )],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.no_undeliverable_msgs();
 
-        assert_err!(result, ProtocolValidator::UNDELIVERABLE_ERR);
+        assert_err(result, error::msgs::UNDELIVERABLE_ERR);
     }
 
     #[test]
     fn replies_expected_ok() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("ValidReplies"),
             [ActorDef::new("actor", ["Client", "Server"])],
             [Transition::new(
@@ -577,17 +498,17 @@ mod tests {
                     Message::new("Msg", true, []),
                 )],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.replies_expected();
 
-        assert_ok!(result);
+        assert_ok(result);
     }
 
     #[test]
     fn replies_expected_invalid_reply_err() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("ValidReplies"),
             [ActorDef::new("actor", ["Client", "Server"])],
             [Transition::new(
@@ -599,17 +520,17 @@ mod tests {
                     Message::new("Msg", true, []),
                 )],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.replies_expected();
 
-        assert_err!(result, ProtocolValidator::INVALID_REPLY_ERR);
+        assert_err(result, error::msgs::INVALID_REPLY_ERR);
     }
 
     #[test]
     fn replies_expected_multiple_replies_err() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("ValidReplies"),
             [ActorDef::new("actor", ["Client", "OtherClient", "Server"])],
             [Transition::new(
@@ -627,77 +548,17 @@ mod tests {
                     ),
                 ],
             )],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.replies_expected();
 
-        assert_err!(result, ProtocolValidator::MULTIPLE_REPLIES_ERR);
-    }
-
-    #[test]
-    fn msg_sent_or_received_msg_received_ok() {
-        let protocol = Protocol::new(
-            NameDef::new("Test"),
-            [ActorDef::new("actor", ["Init"])],
-            [Transition::new(
-                State::new("Init", []),
-                Some(Message::new("Activate", false, [])),
-                [State::new("End", [])],
-                [],
-            )],
-        );
-        let input = ProtocolValidator::new(&protocol);
-
-        let result = input.msg_sent_or_received();
-
-        assert_ok!(result);
-    }
-
-    #[test]
-    fn msg_sent_or_received_msg_sent_ok() {
-        let protocol = Protocol::new(
-            NameDef::new("Test"),
-            [ActorDef::new("actor", ["First", "Second"])],
-            [Transition::new(
-                State::new("First", []),
-                None,
-                [State::new("First", [])],
-                [Dest::new(
-                    DestinationState::Individual(State::new("Second", [])),
-                    Message::new("Msg", false, []),
-                )],
-            )],
-        );
-        let input = ProtocolValidator::new(&protocol);
-
-        let result = input.msg_sent_or_received();
-
-        assert_ok!(result);
-    }
-
-    #[test]
-    fn msg_sent_or_received_neither_err() {
-        let protocol = Protocol::new(
-            NameDef::new("Test"),
-            [ActorDef::new("actor", ["First"])],
-            [Transition::new(
-                State::new("First", []),
-                None,
-                [State::new("First", [])],
-                [],
-            )],
-        );
-        let input = ProtocolValidator::new(&protocol);
-
-        let result = input.msg_sent_or_received();
-
-        assert_err!(result, ProtocolValidator::NO_MSG_SENT_OR_RECEIVED_ERR);
+        assert_err(result, error::msgs::MULTIPLE_REPLIES_ERR);
     }
 
     #[test]
     fn clients_only_receive_replies_ok() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("ClientReplies"),
             [
                 ActorDef::new("client", ["Client", "Waiting"]),
@@ -735,17 +596,17 @@ mod tests {
                     [],
                 ),
             ],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.clients_only_receive_replies();
 
-        assert_ok!(result);
+        assert_ok(result);
     }
 
     #[test]
     fn clients_only_receive_replies_err() {
-        let protocol = Protocol::new(
+        let input = ProtocolModel::new(Protocol::new(
             NameDef::new("ClientReplies"),
             [
                 ActorDef::new("client", ["Client", "Waiting"]),
@@ -789,11 +650,11 @@ mod tests {
                     [],
                 ),
             ],
-        );
-        let input = ProtocolValidator::new(&protocol);
+        ))
+        .unwrap();
 
         let result = input.clients_only_receive_replies();
 
-        assert_err!(result, ProtocolValidator::CLIENT_RECEIVED_NON_REPLY_ERR);
+        assert_err(result, error::msgs::CLIENT_RECEIVED_NON_REPLY_ERR);
     }
 }

+ 16 - 6
crates/btproto/tests/protocol_tests.rs

@@ -45,6 +45,15 @@ fn minimal_syntax() {
             ready(Ok(End))
         }
     }
+
+    struct ClientState;
+
+    impl Client for ClientState {
+        type SendMsgFut = Ready<Result<End>>;
+        fn send_msg(self, _msg: Msg) -> Self::SendMsgFut {
+            ready(Ok(End))
+        }
+    }
 }
 
 #[test]
@@ -69,9 +78,9 @@ fn reply() {
 
     impl Listening for ListeningState {
         type HandlePingListening = Self;
-        type HandlePingFut = Ready<Result<Self>>;
+        type HandlePingFut = Ready<Result<(Self, <Ping as CallMsg>::Reply)>>;
         fn handle_ping(self, _msg: Ping) -> Self::HandlePingFut {
-            ready(Ok(self))
+            ready(Ok((self, ())))
         }
     }
 
@@ -79,17 +88,18 @@ fn reply() {
 
     impl Client for ClientState {
         type SendPingWaiting = WaitingState;
-        type SendPingFut = Ready<Result<WaitingState>>;
-        fn send_ping(self) -> Self::SendPingFut {
-            ready(Ok(WaitingState))
+        type SendPingFut = Ready<Result<(WaitingState, <Ping as CallMsg>::Reply)>>;
+        fn send_ping(self, _ping: Ping) -> Self::SendPingFut {
+            ready(Ok((WaitingState, ())))
         }
     }
 
     struct WaitingState;
 
+    // TODO: This state should not be generated, as it is never observed.
     impl Waiting for WaitingState {
         type HandlePingReplyFut = Ready<Result<End>>;
-        fn handle_ping_reply(self, _msg: Ping) -> Self::HandlePingReplyFut {
+        fn handle_ping_reply(self, _msg: ()) -> Self::HandlePingReplyFut {
             ready(Ok(End))
         }
     }