use std::{
    collections::{HashMap, HashSet, LinkedList},
    hash::Hash,
    rc::Rc,
};

use btrun::model::End;
use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};

use crate::{
    case_convert::CaseConvert,
    error,
    parsing::{ActorDef, Dest, GetSpan, Message, Protocol, State, Transition},
    parsing::{DestinationState, MessageReplyPart},
};

/// The analysed form of a parsed [Protocol].
///
/// Besides the per-actor models and the lookup tables, this struct caches the identifiers that
/// code generation reuses, so they are created exactly once in [ProtocolModel::new].
pub(crate) struct ProtocolModel {
    def: Protocol,
    msg_lookup: MsgLookup,
    actor_lookup: ActorLookup,
    use_statements: TokenStream,
    actors: HashMap<Rc<Ident>, ActorModel>,
    msg_enum_name: Ident,
    msg_enum_kinds_name: Ident,
    end_ident: Ident,
    end_struct: TokenStream,
    actor_id_param: Ident,
    runtime_param: Ident,
    msg_ident: Ident,
    reply_ident: Ident,
    from_ident: Ident,
    actor_name_ident: Ident,
    init_state_var: Ident,
    init_state_type_param: Ident,
    state_type_param: Ident,
    new_state_type_param: Ident,
    new_state_method: Ident,
    state_var: Ident,
}

impl ProtocolModel {
    /// Builds the model for the protocol definition `def`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [ActorLookup::new] or [ActorModel::new], and returns
    /// `error::msgs::CLIENT_USED_IN_SERVICE` when an actor is classified as both a client and a
    /// service provider.
    pub(crate) fn new(def: Protocol) -> syn::Result<Self> {
        /// Marks `actor_name`, and every actor transitively spawned by it, as a client.
        ///
        /// An actor which is already marked is skipped. Without that check a spawn cycle
        /// reachable from a root actor (e.g. A spawns B and B spawns A) would recurse forever.
        fn mark_all_progeny(
            actor_lookup: &ActorLookup,
            is_client: &mut HashMap<Rc<Ident>, bool>,
            actor_name: &Ident,
        ) {
            let marked = is_client
                .get_mut(actor_name)
                .expect("every declared actor has an entry in the is_client map");
            if *marked {
                return;
            }
            *marked = true;
            for child in actor_lookup.children(actor_name) {
                mark_all_progeny(actor_lookup, is_client, child.as_ref());
            }
        }

        let get_transitions = || def.transitions.iter().map(|x| x.as_ref());
        let actor_lookup =
            ActorLookup::new(def.actor_defs.iter().map(|x| x.as_ref()), get_transitions())?;

        // Every actor starts out as a non-client.
        let mut is_client: HashMap<Rc<Ident>, bool> = def
            .actor_defs
            .iter()
            .map(|actor_def| (actor_def.actor.clone(), false))
            .collect();

        // For every actor which is not spawned by another actor, mark it as a client if its
        // initial state receives no messages, then mark all of the actors it spawns as clients.
        for actor_def in def.actor_defs.iter() {
            if !actor_lookup.parents(actor_def.actor.as_ref()).is_empty() {
                continue;
            }
            let init_state = actor_def.states.as_ref().first().unwrap();
            let init_state_receives_no_msgs = def
                .transitions
                .iter()
                .filter(|transition| {
                    transition.in_state.state_trait.as_ref() == init_state.as_ref()
                })
                .all(|transition| transition.in_msg().is_none());
            if init_state_receives_no_msgs {
                mark_all_progeny(&actor_lookup, &mut is_client, actor_def.actor.as_ref());
            }
        }

        let msg_lookup = MsgLookup::new(get_transitions());

        let mut actors = HashMap::new();
        for actor_def in def.actor_defs.iter() {
            let actor_name = &actor_def.actor;
            let actor_states = actor_lookup.actor_states(actor_name.as_ref());
            // Pair each state of this actor with the transitions which start in that state.
            let transitions_by_state = actor_states.iter().map(|state_name| {
                let transitions = def
                    .transitions
                    .iter()
                    .filter(|transition| {
                        state_name.as_ref() == transition.in_state.state_trait.as_ref()
                    })
                    .cloned();
                (state_name.clone(), transitions)
            });
            let is_client = *is_client.get(&actor_def.actor).unwrap();
            let is_service = actor_lookup.service_providers.contains(&actor_def.actor);
            let kind = ActorKind::new(is_client, is_service).ok_or_else(|| {
                syn::Error::new(actor_def.actor.span(), error::msgs::CLIENT_USED_IN_SERVICE)
            })?;
            let actor =
                ActorModel::new(actor_def.clone(), &msg_lookup, kind, transitions_by_state)?;
            actors.insert(actor_name.clone(), actor);
        }

        let end_ident = format_ident!("{}", End::ident());
        Ok(Self {
            // The two enum names borrow `def`, so they are computed before `def` is moved in.
            msg_enum_name: Self::make_msg_enum_name(&def),
            msg_enum_kinds_name: Self::make_msg_enum_kinds_name(&def),
            use_statements: Self::make_use_statements(),
            def,
            msg_lookup,
            actor_lookup,
            actors,
            end_struct: quote! { ::btrun::model::#end_ident },
            end_ident,
            actor_id_param: format_ident!("actor_id"),
            runtime_param: format_ident!("runtime"),
            msg_ident: format_ident!("msg"),
            reply_ident: format_ident!("reply"),
            from_ident: format_ident!("from"),
            actor_name_ident: format_ident!("actor_name"),
            init_state_var: format_ident!("init"),
            init_state_type_param: format_ident!("Init"),
            state_type_param: format_ident!("State"),
            new_state_type_param: format_ident!("NewState"),
            new_state_method: format_ident!("new_state"),
            state_var: format_ident!("state"),
        })
    }

    /// Builds the message enum identifier: the protocol name followed by `Msgs`.
    fn make_msg_enum_name(def: &Protocol) -> Ident {
        format_ident!("{}Msgs", def.name_def.name)
    }

    /// Builds the message kinds enum identifier: the protocol name followed by `MsgKinds`.
    fn make_msg_enum_kinds_name(def: &Protocol) -> Ident {
        format_ident!("{}MsgKinds", def.name_def.name)
    }

    /// Builds the block of `use` statements returned by [ProtocolModel::use_statements].
    fn make_use_statements() -> TokenStream {
        quote! {
            use ::btlib::bterr;
            use ::btrun::{
                log, Mailbox,
                model::{
                    Envelope, ControlMsg, Named, TransResult, ActorError, ActorErrorPayload,
                    TransKind, Mutex,
                }
            };
            use ::std::sync::Arc;
        }
    }

    /// The protocol definition this model was built from.
    pub(crate) fn def(&self) -> &Protocol {
        &self.def
    }

    /// The lookup table describing every message sent or received in this protocol.
    pub(crate) fn msg_lookup(&self) -> &MsgLookup {
        &self.msg_lookup
    }

    /// The lookup table relating actors and states to one another.
    pub(crate) fn actor_lookup(&self) -> &ActorLookup {
        &self.actor_lookup
    }

    /// The actor models, keyed by actor name.
    pub(crate) fn actors(&self) -> &HashMap<Rc<Ident>, ActorModel> {
        &self.actors
    }

    /// Iterates over every actor model, in no particular order.
    pub(crate) fn actors_iter(&self) -> impl Iterator<Item = &ActorModel> {
        self.actors.values()
    }

    /// Iterates over the states of every actor.
    pub(crate) fn states_iter(&self) -> impl Iterator<Item = &StateModel> {
        self.actors_iter().flat_map(|actor| actor.states().values())
    }

    /// Iterates over the methods of every state of every actor.
    #[cfg(test)]
    pub(crate) fn methods_iter(&self) -> impl Iterator<Item = &MethodModel> {
        self.states_iter()
            .flat_map(|state| state.methods().values())
    }

    /// Iterates over the outputs of every method of every state of every actor.
    #[cfg(test)]
    pub(crate) fn outputs_iter(&self) -> impl Iterator<Item = &ValueModel> {
        self.methods_iter()
            .flat_map(|method| method.output_values().iter())
    }

    /// Returns the tokens for the use statements which bring the types used inside spawn functions
    /// into scope.
    pub(crate) fn use_statements(&self) -> &TokenStream {
        &self.use_statements
    }

    /// The name of the message enum used by this protocol.
pub(crate) fn msg_enum_ident(&self) -> &Ident {
        &self.msg_enum_name
    }

    /// The name of the message kinds enum used by this protocol.
    pub(crate) fn msg_enum_kinds_ident(&self) -> &Ident {
        &self.msg_enum_kinds_name
    }

    /// The [Ident] containing [End::ident].
    pub(crate) fn end_ident(&self) -> &Ident {
        &self.end_ident
    }

    /// Returns a token stream which references the [End] struct using a fully qualified path.
    #[allow(dead_code)]
    pub(crate) fn end_struct(&self) -> &TokenStream {
        &self.end_struct
    }

    /// The name of the [btrun::ActorId] parameter in an actor closure.
    pub(crate) fn actor_id_param(&self) -> &Ident {
        &self.actor_id_param
    }

    /// The name of the `&'static `[Runtime] parameter in an actor closure.
    pub(crate) fn runtime_param(&self) -> &Ident {
        &self.runtime_param
    }

    /// The name of the variable used to hold the `from` field of a [btrun::Envelope].
    // This is a plain getter that happens to start with `from_`, not a conversion, so the clippy
    // naming lint is silenced here.
    #[allow(clippy::wrong_self_convention)]
    pub(crate) fn from_ident(&self) -> &Ident {
        &self.from_ident
    }

    /// The name of the variable used to hold the `reply` field from a [btrun::Envelope].
    pub(crate) fn reply_ident(&self) -> &Ident {
        &self.reply_ident
    }

    /// The name of the variable used to hold the `msg` field from a [btrun::Envelope].
    pub(crate) fn msg_ident(&self) -> &Ident {
        &self.msg_ident
    }

    /// The name of the local variable in an actor closure used to hold the actor's name.
    pub(crate) fn actor_name_ident(&self) -> &Ident {
        &self.actor_name_ident
    }

    /// The identifier for the variable holding the initial state of an actor.
    pub(crate) fn init_state_var(&self) -> &Ident {
        &self.init_state_var
    }

    /// The identifier for the `Init` type parameter used in client enums, handles, and spawn
    /// functions.
    pub(crate) fn init_type_param(&self) -> &Ident {
        &self.init_state_type_param
    }

    /// The type parameter for client handles which indicates the handle's current state.
pub(crate) fn state_type_param(&self) -> &Ident {
        &self.state_type_param
    }

    /// The type parameter for client handles which indicates the new state the handle is
    /// transitioning to.
    pub(crate) fn new_state_type_param(&self) -> &Ident {
        &self.new_state_type_param
    }

    /// The name of the method used to transition client handles to a new type state.
    pub(crate) fn new_state_method(&self) -> &Ident {
        &self.new_state_method
    }

    /// The identifier of the variable used to hold the current state of an actor, and the field
    /// of a client handle which holds the current state.
    pub(crate) fn state_var(&self) -> &Ident {
        &self.state_var
    }

    /// Returns the name of the field used to hold the shared state in a client handle. This is
    /// the same identifier as [ProtocolModel::state_var].
    pub(crate) fn state_field(&self) -> &Ident {
        &self.state_var
    }

    /// Returns an iterator over the [Ident]s for each of the states in `actor`, followed by the
    /// end state identifier.
    pub(crate) fn state_idents<'a>(
        &'a self,
        actor: &'a ActorModel,
    ) -> impl Iterator<Item = &'a Ident> {
        actor
            .states()
            .values()
            .map(|state| state.name())
            .chain(std::iter::once(&self.end_ident))
    }

    /// Returns the model of the named actor.
    ///
    /// **Panics** if there is no actor with the given name.
    fn get_actor<'a>(&'a self, actor_name: &Ident) -> &'a ActorModel {
        self.actors
            .get(actor_name)
            .unwrap_or_else(|| panic!("Invalid actor name: '{actor_name}'"))
    }

    /// Returns the model of the named state.
    ///
    /// **Panics** if the state is unknown to the actor lookup or to the actor which owns it.
    pub(crate) fn get_state<'a>(&'a self, state_name: &Ident) -> &'a StateModel {
        let actor_name = self.actor_lookup().actor_with_state(state_name);
        let actor = self.get_actor(actor_name);
        actor
            .states()
            .get(state_name)
            .unwrap_or_else(|| panic!("Actor {actor_name} doesn't contain state {state_name}."))
    }

    /// Returns a struct with the type params and their associated constraints for the given actor's
    /// spawn function.
    ///
    /// Each state reachable from the actor's initial state contributes exactly one type parameter
    /// and one constraint, even when several transitions lead to it.
    pub(crate) fn type_param_info_for<'a>(&'a self, actor_name: &Ident) -> TypeParamInfo<'a> {
        let mut type_params = Vec::<&Ident>::new();
        let mut constraints = Vec::<TokenStream>::new();
        // We do a breadth-first traversal over the associated types referenced by this actor,
        // starting with its initial state. A state is recorded in `visited` at the moment it is
        // enqueued; recording it only when dequeued would let a state reachable by two edges be
        // queued twice, producing duplicate type parameters.
        let mut visited = HashSet::<&Ident>::new();
        let mut queue = LinkedList::<&Ident>::new();
        let init_name = self.get_actor(actor_name).init_state().name();
        visited.insert(init_name);
        queue.push_front(init_name);
        while let Some(state_name) = queue.pop_back() {
            let state = self.get_state(state_name);
            let mut any = false;
            // This iterator is lazy: the queue is only extended while `quote!` consumes it below.
            let eqs = state
                .out_states_and_assoc_types()
                .map(|(out_state, assoc_type)| {
                    any = true;
                    // `insert` returns true only the first time a state is seen.
                    if visited.insert(out_state) {
                        queue.push_front(out_state);
                    }
                    let type_param = self.get_state(out_state).type_param();
                    quote! { #assoc_type = #type_param }
                });
            let constraint_list = quote! { #( #eqs ),* };
            // Angle brackets are only emitted when there is at least one associated type.
            let constraint = if any {
                quote! { < #constraint_list > }
            } else {
                constraint_list
            };
            let type_param = state.type_param();
            type_params.push(type_param);
            let state_name = state.name();
            let constraint = quote! { #type_param: 'static + #state_name #constraint };
            constraints.push(constraint);
        }
        TypeParamInfo {
            constraints,
            type_params,
        }
    }

    /// Returns an iterator of the outputs holding the states which the given state transitions
    /// to. Transitions into the end state are excluded.
    ///
    /// The iterator **panics** if the first output of a method is not a state.
    pub(crate) fn next_states<'a>(
        &'a self,
        state: &'a StateModel,
    ) -> impl Iterator<Item = &'a ValueModel> {
        state
            .methods()
            .values()
            // The first output of a method is the state this state transitions to.
            .flat_map(|method| method.output_values().first().zip(Some(method.name())))
            .filter(|(output, method_name)| {
                let state_trait = output.kind().state_trait()
                    .unwrap_or_else(|| panic!("The first output of method {method_name} in state {} was not the correct kind.", state.name()));
                state_trait.as_ref() != End::ident()
            })
            .map(|(output, _)| output)
    }
}

/// The generic parameters of an actor's spawn function, as produced by
/// [ProtocolModel::type_param_info_for]. The two vectors are parallel.
pub(crate) struct TypeParamInfo<'a> {
    pub(crate) type_params: Vec<&'a Ident>,
    pub(crate) constraints: Vec<TokenStream>,
}

/// A categorization of different actor types based on their messaging behavior.
#[derive(Clone, Copy, Debug)]
pub(crate) enum ActorKind {
    /// A client is an actor which is not spawned by another actor (i.e. has no parents) and
    /// whose initial state has no transition which receives a message, or which is spawned by a
    /// client.
    Client,
    /// Any actor with a state that appears wrapped in `service()` is in this category.
    Service,
    /// Any actor not in either of the other two categories falls into this one.
    Worker,
}

impl ActorKind {
    /// Maps the two classification flags to a kind. Returns [None] when both flags are set,
    /// because an actor cannot be both a client and a service.
    fn new(is_client: bool, is_service: bool) -> Option<Self> {
        match (is_client, is_service) {
            (true, false) => Some(Self::Client),
            (false, true) => Some(Self::Service),
            (false, false) => Some(Self::Worker),
            (true, true) => None,
        }
    }

    /// Returns true iff this is [ActorKind::Client].
    pub(crate) fn is_client(&self) -> bool {
        matches!(self, Self::Client)
    }
}

/// The model of a single actor: its kind, its states, and the identifiers generated for it.
pub(crate) struct ActorModel {
    #[allow(dead_code)]
    def: Rc<ActorDef>,
    kind: ActorKind,
    state_enum_ident: Ident,
    states: HashMap<Rc<Ident>, StateModel>,
    spawn_function_ident: Option<Ident>,
    handle_struct_ident: Option<Ident>,
}

impl ActorModel {
    /// Builds the model of the actor defined by `def`.
    ///
    /// `state_iter` yields each state name of the actor together with the transitions which
    /// start in that state.
    ///
    /// # Errors
    ///
    /// Propagates any error from [StateModel::new].
    ///
    /// # Panics
    ///
    /// Panics if `state_iter` yields the same state name twice.
    fn new<S, T>(
        def: Rc<ActorDef>,
        messages: &MsgLookup,
        kind: ActorKind,
        state_iter: S,
    ) -> syn::Result<Self>
    where
        S: IntoIterator<Item = (Rc<Ident>, T)>,
        T: IntoIterator<Item = Rc<Transition>>,
    {
        let mut states = HashMap::new();
        let is_client = matches!(kind, ActorKind::Client);
        // The input is consumed directly. Collecting it into an intermediate map first would
        // silently discard duplicate names and make the check below unreachable.
        for (name, transitions) in state_iter {
            let state = StateModel::new(name.clone(), messages, transitions, is_client)?;
            if let Some(prev) = states.insert(name, state) {
                panic!(
                    "States are not being grouped by actor correctly. Duplicate state name: '{}'",
                    prev.name
                );
            }
        }
        let actor_name = def.actor.as_ref();
        Ok(Self {
            state_enum_ident: Self::make_state_enum_ident(actor_name),
            spawn_function_ident: Self::make_spawn_function_ident(kind, actor_name),
            handle_struct_ident: Self::make_handle_struct_ident(kind, actor_name),
            def,
            kind,
            states,
        })
    }

    /// Builds the state enum identifier: the actor name in pascal case followed by `State`.
    fn make_state_enum_ident(actor_name: &Ident) -> Ident {
        format_ident!("{}State", actor_name.to_string().snake_to_pascal())
    }

    /// Builds the spawn function identifier, `spawn_` followed by the actor name.
    fn make_spawn_function_ident(kind: ActorKind, actor_name: &Ident) -> Option<Ident> {
        match kind {
            // Services are registered, not spawned, so they have no spawn function.
            ActorKind::Service => None,
            ActorKind::Client | ActorKind::Worker => Some(format_ident!("spawn_{actor_name}")),
        }
    }

    /// Builds the handle struct identifier, which only client actors have: the actor name in
    /// pascal case followed by `Handle`.
    fn make_handle_struct_ident(kind: ActorKind, actor_name: &Ident) -> Option<Ident> {
        match kind {
            ActorKind::Client => Some(format_ident!("{}Handle", actor_name.snake_to_pascal())),
            ActorKind::Service | ActorKind::Worker => None,
        }
    }

    /// The parsed definition of this actor.
    pub(crate) fn def(&self) -> &ActorDef {
        &self.def
    }

    /// The name of this actor.
    pub(crate) fn name(&self) -> &Ident {
        &self.def.actor
    }

    /// The kind of this actor.
    pub(crate) fn kind(&self) -> ActorKind {
        self.kind
    }

    /// The states of this actor, keyed by state name.
    pub(crate) fn states(&self) -> &HashMap<Rc<Ident>, StateModel> {
        &self.states
    }

    /// The first state listed in this actor's definition.
    pub(crate) fn init_state(&self) -> &StateModel {
        // It's a syntax error to have an IdentArray with no states in it, so this unwrap
        // shouldn't panic.
        let init = self.def.states.as_ref().first().unwrap();
        self.states.get(init).unwrap()
    }

    /// The name of the enum holding this actor's states.
    pub(crate) fn state_enum_ident(&self) -> &Ident {
        &self.state_enum_ident
    }

    /// The name of this actor's spawn function, or [None] for a service.
    pub(crate) fn spawn_function_ident(&self) -> Option<&Ident> {
        self.spawn_function_ident.as_ref()
    }

    /// The name of this actor's handle struct, or [None] if it is not a client.
    pub(crate) fn handle_struct_ident(&self) -> Option<&Ident> {
        self.handle_struct_ident.as_ref()
    }
}

/// The model of a single state: its name, one method per transition out of it, and its generic
/// type parameter.
pub(crate) struct StateModel {
    name: Rc<Ident>,
    methods: HashMap<Rc<Ident>, MethodModel>,
    type_param: Ident,
}

impl StateModel {
    /// Builds the model of the state `name` from the transitions which start in it.
    ///
    /// # Errors
    ///
    /// Returns `error::msgs::DUPLICATE_TRANSITION` when two transitions produce the same method
    /// name, and propagates any error from [MethodModel::new].
    fn new<T>(
        name: Rc<Ident>,
        messages: &MsgLookup,
        transitions: T,
        part_of_client: bool,
    ) -> syn::Result<Self>
    where
        T: IntoIterator<Item = Rc<Transition>>,
    {
        let mut methods = HashMap::new();
        for transition in transitions.into_iter() {
            // The span is taken up front because the transition is moved into the method.
            let transition_span = transition.span();
            let method = MethodModel::new(transition, messages, part_of_client)?;
            if methods.insert(method.name.clone(), method).is_some() {
                return Err(syn::Error::new(
                    transition_span,
                    error::msgs::DUPLICATE_TRANSITION,
                ));
            }
        }
        Ok(Self {
            type_param: Self::make_generic_type_param(name.as_ref()),
            name,
            methods,
        })
    }

    /// Builds the generic type parameter for a state: `T` followed by the state name.
    fn make_generic_type_param(name: &Ident) -> Ident {
        format_ident!("T{name}")
    }

    /// The name of this state.
    pub(crate) fn name(&self) -> &Ident {
        self.name.as_ref()
    }

    /// The methods of this state, keyed by method name.
    pub(crate) fn methods(&self) -> &HashMap<Rc<Ident>, MethodModel> {
        &self.methods
    }

    /// The generic type parameter which stands for this state.
    pub(crate) fn type_param(&self) -> &Ident {
        &self.type_param
    }

    /// Iterates over `(state trait, associated type)` pairs for every method output which is a
    /// state and has an associated type. Outputs without an associated type are skipped.
    fn out_states_and_assoc_types(&self) -> impl Iterator<Item = (&Ident, &Ident)> {
        self.methods().values().flat_map(|method| {
            method.output_values().iter().flat_map(|output| {
                output
                    .kind()
                    .state_trait()
                    .map(|ptr| ptr.as_ref())
                    .zip(output.assoc_type())
            })
        })
    }
}

impl GetSpan for StateModel {
    fn span(&self) -> Span {
        self.name.span()
    }
}

/// The model of one trait method, generated from a single transition.
#[cfg_attr(test, derive(Debug))]
pub(crate) struct MethodModel {
    def: Rc<Transition>,
    name: Rc<Ident>,
    handle_name: Option<Ident>,
    inputs: Vec<ValueModel>,
    outputs: Vec<ValueModel>,
    future: Ident,
}

impl MethodModel {
    /// Builds the method for the transition `def`.
    ///
    /// # Errors
    ///
    /// Returns `error::msgs::NO_MSG_SENT_OR_RECEIVED` when the transition neither receives nor
    /// sends a message.
    fn new(def: Rc<Transition>, messages: &MsgLookup, part_of_client: bool) -> syn::Result<Self> {
        let (name, client_handle_name) = Self::new_name(def.as_ref())?;
        let name = Rc::new(name);
        // The pascal-case method name prefixes every type generated for this method.
        let type_prefix = name.snake_to_pascal();
        Ok(Self {
            name,
            handle_name: client_handle_name,
            inputs: Self::make_inputs(def.as_ref(), &type_prefix, messages, part_of_client),
            outputs: Self::make_outputs(def.as_ref(), &type_prefix, messages, part_of_client),
            future: format_ident!("{type_prefix}Fut"),
            def,
        })
    }

    /// Computes the method name and, for sending transitions, the client handle method name.
    ///
    /// A transition which receives a message is named `handle_<msg>` and has no handle name. One
    /// which only sends is named `on_send_<msgs>` with the handle name `send_<msgs>`, where
    /// `<msgs>` is the snake-case names of the sent messages joined by `_`.
    fn new_name(def: &Transition) -> syn::Result<(Ident, Option<Ident>)> {
        let pair = if let Some(msg) = def.in_msg() {
            let name = format_ident!("handle_{}", msg.variant().pascal_to_snake());
            (name, None)
        } else {
            let mut dests = def.out_msgs.as_ref().iter();
            let mut msg_names = String::new();
            match dests.next() {
                Some(dest) => msg_names.push_str(dest.msg.variant().pascal_to_snake().as_str()),
                None => {
                    return Err(syn::Error::new(
                        def.span(),
                        error::msgs::NO_MSG_SENT_OR_RECEIVED,
                    ));
                }
            }
            for dest in dests {
                msg_names.push('_');
                msg_names.push_str(dest.msg.variant().pascal_to_snake().as_str());
            }
            let name = format_ident!("on_send_{msg_names}");
            let handle_name = format_ident!("send_{msg_names}");
            (name, Some(handle_name))
        };
        Ok(pair)
    }

    /// Builds the inputs of the method: the received message, if any, followed (for client
    /// actors only) by one value per sent message.
    fn make_inputs(
        def: &Transition,
        type_prefix: &str,
        messages: &MsgLookup,
        part_of_client: bool,
    ) -> Vec<ValueModel> {
        let mut inputs = Vec::new();
        if let Some(in_msg) = def.in_msg() {
            let kind = ValueKind::Msg {
                def: in_msg.clone(),
            };
            inputs.push(ValueModel::new(kind, type_prefix));
        }
        if part_of_client {
            for dest in def.out_msgs.as_ref().iter() {
                let kind = ValueKind::new_dest(dest.clone(), messages, part_of_client);
                inputs.push(ValueModel::new(kind, type_prefix));
            }
        }
        inputs
    }

    /// Builds the outputs of the method: one value per output state, followed by either a return
    /// value (client methods which receive no message) or one value per sent message (methods of
    /// non-client actors).
    fn make_outputs(
        def: &Transition,
        type_prefix: &str,
        messages: &MsgLookup,
        part_of_client: bool,
    ) -> Vec<ValueModel> {
        let mut outputs = Vec::new();
        for state in def.out_states.as_ref().iter() {
            let kind = ValueKind::State { def: state.clone() };
            outputs.push(ValueModel::new(kind, type_prefix));
        }
        if part_of_client {
            if def.in_msg().is_none() {
                outputs.push(ValueModel::new(ValueKind::Return, type_prefix));
            }
        } else {
            for dest in def.out_msgs.as_ref().iter() {
                let kind = ValueKind::new_dest(dest.clone(), messages, part_of_client);
                outputs.push(ValueModel::new(kind, type_prefix));
            }
        }
        outputs
    }

    /// Returns the message this method is handling, or [None] if this method is not handling a
    /// message.
    pub(crate) fn msg_received_input(&self) -> Option<&Message> {
        self.def.in_msg().map(|rc| rc.as_ref())
    }

    /// The transition this method was generated from.
    pub(crate) fn def(&self) -> &Transition {
        self.def.as_ref()
    }

    /// The name of this method.
    pub(crate) fn name(&self) -> &Rc<Ident> {
        &self.name
    }

    /// The inputs of this method.
    pub(crate) fn inputs(&self) -> &Vec<ValueModel> {
        &self.inputs
    }

    /// The outputs of this method; output states come first.
    pub(crate) fn output_values(&self) -> &Vec<ValueModel> {
        &self.outputs
    }

    /// The identifier built from the pascal-case method name followed by `Fut`.
    pub(crate) fn future(&self) -> &Ident {
        &self.future
    }

    /// The name of this method in a client handle. This is [None] for methods which receive a
    /// message.
    pub(crate) fn handle_name(&self) -> Option<&Ident> {
        self.handle_name.as_ref()
    }

    /// Returns the output for the state this method transitions into.
    ///
    /// **Panics** if this method has no outputs.
    pub(crate) fn next_state(&self) -> &ValueModel {
        self.output_values().first().unwrap()
    }

    /// Returns an iterator over the output variables returned by this method.
    pub(crate) fn output_vars(&self) -> impl Iterator<Item = &Ident> {
        self.output_values()
            .iter()
            .flat_map(|output| output.output_var_name())
    }
}

impl GetSpan for MethodModel {
    fn span(&self) -> Span {
        self.def.span()
    }
}

/// A single input or output of a method, together with the tokens needed to declare and name it.
#[cfg_attr(test, derive(Debug))]
pub(crate) struct ValueModel {
    kind: ValueKind,
    type_name: Option<TokenStream>,
    assoc_type: Option<Ident>,
    decl: Option<TokenStream>,
    var_name: Ident,
}

impl ValueModel {
    /// Builds a value of the given kind. `type_prefix` prefixes any associated type which is
    /// declared for the value.
    fn new(kind: ValueKind, type_prefix: &str) -> Self {
        let (decl, type_name, assoc_type) = match &kind {
            ValueKind::Msg { def, .. } => {
                let decl = None;
                let msg_type = def.msg_type.as_ref();
                let type_name = Some(quote! { #msg_type });
                let assoc_type = None;
                (decl, type_name, assoc_type)
            }
            ValueKind::State { def, .. } => {
                let state_trait = def.state_trait.as_ref();
                if state_trait == End::ident() {
                    // The end state is a concrete struct, so no associated type is declared.
                    let decl = None;
                    let end_ident = format_ident!("{}", End::ident());
                    let type_name = Some(quote! { ::btrun::model::#end_ident });
                    let assoc_type = None;
                    (decl, type_name, assoc_type)
                } else {
                    let assoc_type_ident = format_ident!("{type_prefix}{}", state_trait);
                    let decl = Some(quote! { type #assoc_type_ident: #state_trait; });
                    let type_name = Some(quote! { Self::#assoc_type_ident });
                    let assoc_type = Some(assoc_type_ident);
                    (decl, type_name, assoc_type)
                }
            }
            ValueKind::Dest {
                msg_type,
                reply_type,
                part_of_client,
                ..
            } => {
                let decl = None;
                // In a client only a call's reply has a type; a plain send has none.
                let type_name = if *part_of_client {
                    if reply_type.is_some() {
                        Some(quote! { <#msg_type as ::btrun::model::CallMsg>::Reply })
                    } else {
                        None
                    }
                } else {
                    Some(quote! { #msg_type })
                };
                let assoc_type = None;
                (decl, type_name, assoc_type)
            }
            ValueKind::Return { .. } => {
                let assoc_type_ident = format_ident!("{type_prefix}Return");
                let decl = Some(quote! { type #assoc_type_ident; });
                let type_name = Some(quote! { Self::#assoc_type_ident });
                let assoc_type = Some(assoc_type_ident);
                (decl, type_name, assoc_type)
            }
        };
        Self {
            var_name: kind.var_name(),
            type_name,
            decl,
            kind,
            assoc_type,
        }
    }

    /// The span of the definition behind this value.
    fn span(&self) -> Span {
        self.kind.span()
    }

    /// The code for specifying the type of this value, if it has one.
    pub(crate) fn type_name(&self) -> Option<&TokenStream> {
        self.type_name.as_ref()
    }

    /// The associated type declaration for this value, if one is needed.
    pub(crate) fn decl(&self) -> Option<&TokenStream> {
        self.decl.as_ref()
    }

    /// The kind of this value.
    pub(crate) fn kind(&self) -> &ValueKind {
        &self.kind
    }

    /// Returns the associated type for this value. `None` is returned for messages, the end
    /// state, and message destinations.
    pub(crate) fn assoc_type(&self) -> Option<&Ident> {
        self.assoc_type.as_ref()
    }

    /// The name of the variable which holds this value.
    pub(crate) fn var_name(&self) -> &Ident {
        &self.var_name
    }

    /// If this output actually appears in the tuple returned by its method, then its variable name
    /// is returned. Otherwise, `None` is returned.
    ///
    /// An output of a transition may not actually be an output of the corresponding trait method
    /// in the case of a method in a client actor.
pub(crate) fn output_var_name(&self) -> Option<&Ident> {
        self.type_name.as_ref().map(|_| &self.var_name)
    }

    /// Returns the tokens for this input when it appears in a client handle method. Only message
    /// destinations produce a parameter; every other kind yields no tokens.
    pub(crate) fn as_handle_param(&self) -> TokenStream {
        let name = &self.var_name;
        match &self.kind {
            ValueKind::Dest { msg_type, .. } => {
                quote_spanned! {self.span()=> mut #name: #msg_type }
            }
            _ => quote_spanned! {self.span() => },
        }
    }

    /// Returns the tokens for passing this value in a method call: the variable name for
    /// messages and destinations, and no tokens otherwise.
    #[allow(dead_code)]
    pub(crate) fn as_method_call(&self) -> TokenStream {
        match &self.kind {
            ValueKind::Msg { .. } | ValueKind::Dest { .. } => {
                let var_name = &self.var_name;
                quote_spanned! {self.span()=> #var_name }
            }
            _ => quote_spanned! {self.span()=> },
        }
    }

    /// Returns the tokens declaring this value as a parameter of a trait method.
    pub(crate) fn in_method_decl(&self) -> TokenStream {
        let var_name = &self.var_name;
        match &self.kind {
            ValueKind::Msg { def, .. } => {
                let msg_type = def.msg_type.as_ref();
                quote! { #var_name: #msg_type }
            }
            ValueKind::State { .. } => quote_spanned! {self.span()=> },
            // Dest values only ever occur in the inputs of clients. In client handles, only call
            // replies are passed in.
            ValueKind::Dest { reply_type, .. } => match reply_type {
                Some(reply_type) => quote! { #var_name: #reply_type },
                None => quote! {},
            },
            ValueKind::Return => quote! {},
        }
    }
}

/// The different roles a [ValueModel] can play in a method signature.
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ValueKind {
    /// Represents a value which is passing in a message, and so always occurs in an input position.
    Msg { def: Rc<Message> },
    /// Represents a value which is passing out a state, and so always occurs in an output position.
    State { def: Rc<State> },
    /// Represents a value which is sending a message to another actor, and so is in the input
    /// position for client actors, but in the output position for server actors.
    Dest {
        def: Rc<Dest>,
        msg_type: Rc<TokenStream>,
        reply_type: Option<Rc<TokenStream>>,
        part_of_client: bool,
    },
    /// Represents the return value of a client handle.
    Return,
}

impl ValueKind {
    /// Derives the variable name for a value of this kind: the snake-case form of the message
    /// type, state trait, or destination state trait followed by `_var`, or `return_var` for a
    /// return value.
    fn var_name(&self) -> Ident {
        let ident = match self {
            Self::Msg { def, .. } => def.msg_type.as_ref(),
            Self::State { def, .. } => def.state_trait.as_ref(),
            Self::Dest { def, .. } => def.state.state_ref().state_trait.as_ref(),
            Self::Return { .. } => return format_ident!("return_var"),
        };
        format_ident!("{}_var", ident.pascal_to_snake())
    }

    /// Creates a [ValueKind::Dest] for `def`, taking the message and reply types from
    /// `msg_lookup`.
    fn new_dest(def: Rc<Dest>, msg_lookup: &MsgLookup, part_of_client: bool) -> Self {
        let msg_info = msg_lookup.lookup(&def.msg);
        let msg_type = msg_info.msg_type().clone();
        let reply_type = msg_info.reply().map(|reply| reply.msg_type().clone());
        Self::Dest {
            def,
            msg_type,
            reply_type,
            part_of_client,
        }
    }

    /// The span of the underlying definition; a return value uses the call site.
    fn span(&self) -> Span {
        match self {
            Self::Msg { def, .. } => def.span(),
            Self::State { def, .. } => def.span(),
            Self::Dest { def, .. } => def.span(),
            Self::Return { .. } => Span::call_site(),
        }
    }

    /// Returns the state trait if this is a [ValueKind::State], and [None] otherwise.
    pub(crate) fn state_trait(&self) -> Option<&Rc<Ident>> {
        match self {
            Self::State { def, .. } => Some(&def.state_trait),
            _ => None,
        }
    }
}

/// A type used to query information about actors, states, and their relationships.
pub(crate) struct ActorLookup {
    /// A map from an actor name to the set of states which are part of that actor.
    actor_states: HashMap<Rc<Ident>, HashSet<Rc<Ident>>>,
    #[allow(dead_code)]
    /// A map from a state name to the actor name which that state is a part of.
    actors_by_state: HashMap<Rc<Ident>, Rc<Ident>>,
    /// A map from an actor name to the set of actor names which spawn it.
    parents: HashMap<Rc<Ident>, HashSet<Rc<Ident>>>,
    /// A map from an actor name to the set of actors names which it spawns.
    children: HashMap<Rc<Ident>, HashSet<Rc<Ident>>>,
    /// A map from the initial state of an actor to the actor.
    actors_by_init_state: HashMap<Rc<Ident>, Rc<Ident>>,
    /// The set of actors which are service providers in this protocol.
    service_providers: HashSet<Rc<Ident>>,
}

impl ActorLookup {
    /// Builds the lookup tables from the actor definitions and transitions of a protocol.
    ///
    /// # Errors
    ///
    /// Returns `error::msgs::UNDECLARED_STATE` when a transition starts in, spawns an actor in,
    /// or addresses a service in a state which no actor declares.
    fn new<'a, A, T>(actor_defs: A, transitions: T) -> syn::Result<Self>
    where
        A: IntoIterator<Item = &'a ActorDef>,
        T: IntoIterator<Item = &'a Transition>,
    {
        // First we gather all the information we can by iterating over the actor definitions.
        let mut actor_states = HashMap::new();
        let mut actors_by_state = HashMap::new();
        let mut parents = HashMap::new();
        let mut children = HashMap::new();
        let mut actors_by_init_state = HashMap::new();
        for actor_def in actor_defs {
            let actor_name = &actor_def.actor;
            let mut states = HashSet::new();
            for (index, state) in actor_def.states.as_ref().iter().enumerate() {
                // The first state listed is the actor's initial state.
                if index == 0 {
                    actors_by_init_state.insert(state.clone(), actor_name.clone());
                }
                states.insert(state.clone());
                actors_by_state.insert(state.clone(), actor_name.clone());
            }
            actor_states.insert(actor_name.clone(), states);
            parents.insert(actor_name.clone(), HashSet::new());
            children.insert(actor_name.clone(), HashSet::new());
        }

        // Then, we gather information by iterating over the transitions.
        let mut service_providers = HashSet::new();
        for transition in transitions {
            let in_state = &transition.in_state.state_trait;
            let parent = actors_by_state
                .get(in_state)
                .ok_or_else(|| syn::Error::new(in_state.span(), error::msgs::UNDECLARED_STATE))?;
            // The first output state is skipped because the current actor is transitioning to
            // it, it's not creating a new actor.
            for out_state in transition.out_states.as_ref().iter().skip(1) {
                let out_state = &out_state.state_trait;
                let child = actors_by_state.get(out_state).ok_or_else(|| {
                    syn::Error::new(out_state.span(), error::msgs::UNDECLARED_STATE)
                })?;
                parents
                    .entry(child.clone())
                    .or_insert_with(HashSet::new)
                    .insert(parent.clone());
                children
                    .entry(parent.clone())
                    .or_insert_with(HashSet::new)
                    .insert(child.clone());
            }
            for dest in transition.out_msgs.as_ref().iter() {
                if let DestinationState::Service(service) = &dest.state {
                    let dest_state = &service.state_trait;
                    let actor_name = actors_by_state.get(dest_state).ok_or_else(|| {
                        syn::Error::new(dest_state.span(), error::msgs::UNDECLARED_STATE)
                    })?;
                    service_providers.insert(actor_name.clone());
                }
            }
        }

        Ok(Self {
            actor_states,
            actors_by_state,
            parents,
            children,
            actors_by_init_state,
            service_providers,
        })
    }

    const UNKNOWN_ACTOR_ERR: &str =
        "Unknown actor. This indicates there is a bug in the btproto crate.";

    /// Returns the set of states associated with the given actor.
    ///
    /// This method **panics** if you call it with a non-existent actor name.
    pub(crate) fn actor_states(&self, actor_name: &Ident) -> &HashSet<Rc<Ident>> {
        self.actor_states
            .get(actor_name)
            .unwrap_or_else(|| panic!("actor_states: {}", Self::UNKNOWN_ACTOR_ERR))
    }

    /// Returns the name of the actor containing the given state.
    ///
    /// **Panics**, if called with an unknown state name.
    pub(crate) fn actor_with_state(&self, state_name: &Ident) -> &Rc<Ident> {
        self.actors_by_state
            .get(state_name)
            .unwrap_or_else(|| panic!("can't find state {state_name}"))
    }

    /// Returns the set of actors which spawn the given actor.
    ///
    /// **Panics** if called with a non-existent actor name.
    pub(crate) fn parents(&self, actor_name: &Ident) -> &HashSet<Rc<Ident>> {
        self.parents
            .get(actor_name)
            .unwrap_or_else(|| panic!("parents: {}", Self::UNKNOWN_ACTOR_ERR))
    }

    /// Returns the set of actors which the given actor spawns.
    ///
    /// **Panics** if called with a non-existent actor name.
    pub(crate) fn children(&self, actor_name: &Ident) -> &HashSet<Rc<Ident>> {
        self.children
            .get(actor_name)
            .unwrap_or_else(|| panic!("children: {}", Self::UNKNOWN_ACTOR_ERR))
    }

    /// Returns the name of the actor for which the given state name is the initial state. If the
    /// given name is not the initial state of any actor, [None] is returned.
pub(crate) fn actor_with_init_state(&self, init_state: &Ident) -> Option<&Rc<Ident>> {
        self.actors_by_init_state.get(init_state)
    }
}

/// A table of every message sent or received in a protocol, keyed by message type name.
pub(crate) struct MsgLookup {
    messages: HashMap<Rc<Ident>, MsgInfo>,
}

impl MsgLookup {
    /// Scans `transitions`, recording for each message the states which receive it and the
    /// states which send it.
    fn new<'a>(transitions: impl IntoIterator<Item = &'a Transition>) -> Self {
        let mut messages = HashMap::new();
        for transition in transitions.into_iter() {
            let state = &transition.in_state.state_trait;
            // The message received by this transition, if any.
            if let Some(received) = transition.in_msg() {
                messages
                    .entry(received.msg_type.clone())
                    .or_insert_with(|| MsgInfo::empty(received.clone()))
                    .record_receiver(received, state.clone());
            }
            // Every message sent by this transition.
            for dest in transition.out_msgs.as_ref().iter() {
                let sent = &dest.msg;
                messages
                    .entry(sent.msg_type.clone())
                    .or_insert_with(|| MsgInfo::empty(sent.clone()))
                    .record_sender(sent, state.clone());
            }
        }
        Self { messages }
    }

    /// Returns the info recorded for `msg`. For a reply, this is the info of the reply itself
    /// rather than that of the message it answers.
    ///
    /// **Panics** if `msg` was not seen by [MsgLookup::new].
    pub(crate) fn lookup(&self, msg: &Message) -> &MsgInfo {
        // Since a message is either sent or received, and we've added all such messages in
        // the new method, this lookup should not fail.
        let Some(info) = self.messages.get(msg.msg_type.as_ref()) else {
            panic!("Failed to find message info. There is a bug in MessageLookup::new.")
        };
        info.info_for(msg)
    }

    /// Iterates over every message, each one followed immediately by its reply when it has one.
    pub(crate) fn msg_iter(&self) -> impl Iterator<Item = &MsgInfo> {
        self.messages
            .values()
            .flat_map(|msg_info| std::iter::once(msg_info).chain(msg_info.reply()))
    }
}

impl AsRef<HashMap<Rc<Ident>, MsgInfo>> for MsgLookup {
    fn as_ref(&self) -> &HashMap<Rc<Ident>, MsgInfo> {
        &self.messages
    }
}

/// What is known about a single message: its name and type, the states which send and receive
/// it, and the info of its reply if it has one.
#[cfg_attr(test, derive(Debug))]
pub(crate) struct MsgInfo {
    def: Rc<Message>,
    msg_name: Rc<Ident>,
    msg_type: Rc<TokenStream>,
    is_reply: bool,
    senders: HashSet<Rc<Ident>>,
    receivers: HashSet<Rc<Ident>>,
    reply: Option<Box<MsgInfo>>,
}

impl MsgInfo {
    /// Creates the info for a non-reply message with no senders, receivers, or reply recorded.
    fn empty(def: Rc<Message>) -> Self {
        let msg_name = def.msg_type.clone();
        let type_ident = msg_name.as_ref();
        let msg_type = Rc::new(quote! { #type_ident });
        Self {
            msg_name,
            msg_type,
            is_reply: false,
            senders: HashSet::new(),
            receivers: HashSet::new(),
            reply: None,
            def,
        }
    }

    /// Selects the info which `msg` refers to: the nested reply info when `msg` is a reply,
    /// otherwise `self`.
    ///
    /// **Panics** if `msg` is a reply but no reply was recorded.
    fn info_for(&self, msg: &Message) -> &Self {
        if !msg.is_reply() {
            return self;
        }
        // If this message is a reply, then we should have seen it in MsgLookup::new and
        // initialized its reply pointer. So this match shouldn't reach the panic.
        match self.reply.as_deref() {
            Some(reply) => reply,
            None => panic!(
                "A reply message was not properly recorded. There is a bug in MessageLookup::new."
            ),
        }
    }

    /// The mutable counterpart of [MsgInfo::info_for]. The reply info is created the first time
    /// it is requested.
    fn info_for_mut(&mut self, msg: &Rc<Message>) -> &mut Self {
        if !msg.is_reply() {
            return self;
        }
        self.reply.get_or_insert_with(|| {
            let mut reply = MsgInfo::empty(msg.clone());
            reply.is_reply = true;
            reply.msg_name = Rc::new(format_ident!(
                "{}{}",
                msg.msg_type.as_ref(),
                MessageReplyPart::REPLY_IDENT
            ));
            // The reply's type is expressed through the message it answers.
            let parent = self.msg_name.as_ref();
            reply.msg_type = Rc::new(quote! { <#parent as ::btrun::model::CallMsg>::Reply });
            Box::new(reply)
        })
    }

    /// Records `receiver` as a state which receives `msg`.
    fn record_receiver(&mut self, msg: &Rc<Message>, receiver: Rc<Ident>) {
        self.info_for_mut(msg).receivers.insert(receiver);
    }

    /// Records `sender` as a state which sends `msg`.
    fn record_sender(&mut self, msg: &Rc<Message>, sender: Rc<Ident>) {
        self.info_for_mut(msg).senders.insert(sender);
    }

    /// The message definition this info was created from.
    pub(crate) fn def(&self) -> &Rc<Message> {
        &self.def
    }

    /// The unique name of this message. If it is a reply, it will end in
    /// `MessageReplyPart::REPLY_IDENT`.
    pub(crate) fn msg_name(&self) -> &Rc<Ident> {
        &self.msg_name
    }

    /// The type of this message. If this message is not a reply, this is just `msg_name`. If it is
    /// a reply, it is `<#msg_name as ::btrun::model::CallMsg>::Reply`.
    pub(crate) fn msg_type(&self) -> &Rc<TokenStream> {
        &self.msg_type
    }

    /// The info of the reply to this message, if one was recorded.
    pub(crate) fn reply(&self) -> Option<&MsgInfo> {
        self.reply.as_deref()
    }

    /// Returns true iff this message is a reply.
    pub(crate) fn is_reply(&self) -> bool {
        self.is_reply
    }

    /// Returns true iff this message is a call.
pub(crate) fn is_call(&self) -> bool { self.reply.is_some() } } impl PartialEq for MsgInfo { fn eq(&self, other: &Self) -> bool { self.msg_name.as_ref() == other.msg_name.as_ref() } } impl Eq for MsgInfo {} impl Hash for MsgInfo { fn hash(&self, state: &mut H) { self.msg_name.hash(state) } } #[cfg(test)] mod tests { use crate::{ error::{assert_err, assert_ok}, parsing::{DestinationState, NameDef}, }; use super::*; #[test] fn protocol_model_new_minimal_ok() { let input = Protocol::minimal(); let result = ProtocolModel::new(input); assert_ok(result); } #[test] fn protocol_model_new_dup_recv_transition_err() { let input = Protocol::new( NameDef::new("Undeclared"), [ActorDef::new("actor", ["Init", "Next"])], [ Transition::new( State::new("Init", []), Some(Message::new("Query", false, [])), [State::new("Next", [])], [], ), Transition::new( State::new("Init", []), Some(Message::new("Query", false, [])), [State::new("Init", [])], [], ), ], ); let result = ProtocolModel::new(input); assert_err(result, error::msgs::DUPLICATE_TRANSITION); } #[test] fn protocol_model_new_dup_send_transition_err() { let input = Protocol::new( NameDef::new("Undeclared"), [ ActorDef::new("server", ["Init", "Next"]), ActorDef::new("client", ["Client"]), ], [ Transition::new( State::new("Client", []), None, [State::new("Client", [])], [Dest::new( DestinationState::Individual(State::new("Init", [])), Message::new("Query", false, []), )], ), Transition::new( State::new("Client", []), None, [State::new("Client", [])], [Dest::new( DestinationState::Individual(State::new("Next", [])), Message::new("Query", false, []), )], ), ], ); let result = ProtocolModel::new(input); assert_err(result, error::msgs::DUPLICATE_TRANSITION); } #[test] fn msg_sent_or_received_msg_received_ok() { let input = Protocol::new( NameDef::new("Test"), [ActorDef::new("actor", ["Init"])], [Transition::new( State::new("Init", []), Some(Message::new("Activate", false, [])), [State::new("End", [])], [], )], ); let result = 
ProtocolModel::new(input); assert_ok(result); } #[test] fn msg_sent_or_received_msg_sent_ok() { let input = Protocol::new( NameDef::new("Test"), [ActorDef::new("actor", ["First", "Second"])], [Transition::new( State::new("First", []), None, [State::new("First", [])], [Dest::new( DestinationState::Individual(State::new("Second", [])), Message::new("Msg", false, []), )], )], ); let result = ProtocolModel::new(input); assert_ok(result); } #[test] fn msg_sent_or_received_neither_err() { let input = Protocol::new( NameDef::new("Test"), [ActorDef::new("actor", ["First"])], [Transition::new( State::new("First", []), None, [State::new("First", [])], [], )], ); let result = ProtocolModel::new(input); assert_err(result, error::msgs::NO_MSG_SENT_OR_RECEIVED); } #[test] fn reply_is_marked_in_output() { const MSG: &str = "Ping"; let msg_ident = format_ident!("{MSG}"); let expected = quote! { < #msg_ident as :: btrun :: model :: CallMsg > :: Reply }; let input = Protocol::new( NameDef::new("ReplyTest"), [ ActorDef::new("server", ["Listening"]), ActorDef::new("client", ["Client", "Waiting"]), ], [ Transition::new( State::new("Client", []), None, [State::new("Waiting", [])], [Dest::new( DestinationState::Service(State::new("Listening", [])), Message::new(MSG, false, []), )], ), Transition::new( State::new("Listening", []), Some(Message::new(MSG, false, [])), [State::new("Listening", [])], [Dest::new( DestinationState::Individual(State::new("Client", [])), Message::new(MSG, true, []), )], ), ], ); let actual = ProtocolModel::new(input).unwrap(); let mut msg_types = actual.outputs_iter().flat_map(|output| { if let ValueKind::Dest { msg_type, .. 
} = &output.kind { Some(msg_type) } else { None } }); let msg_type = msg_types.next().unwrap(); assert!(msg_types.next().is_none()); assert_eq!(expected.to_string(), msg_type.to_string()); } fn simple_client_server_proto() -> Protocol { Protocol::new( NameDef::new("IsClientTest"), [ ActorDef::new("server", ["Server"]), ActorDef::new("client", ["Client"]), ], [ Transition::new( State::new("Client", []), None, [State::new("End", [])], [Dest::new( DestinationState::Service(State::new("Server", [])), Message::new("Msg", false, []), )], ), Transition::new( State::new("Server", []), Some(Message::new("Msg", false, [])), [State::new("End", [])], [], ), ], ) } #[test] fn is_client_false_for_server() { let input = simple_client_server_proto(); let actual = ProtocolModel::new(input).unwrap(); let server = actual.actors.get(&format_ident!("server")).unwrap(); assert!(!server.kind().is_client()); } #[test] fn is_client_true_for_client() { let input = simple_client_server_proto(); let actual = ProtocolModel::new(input).unwrap(); let client = actual.actors.get(&format_ident!("client")).unwrap(); assert!(client.kind().is_client()); } #[test] fn is_client_false_for_worker() { let input = Protocol::new( NameDef::new("IsClientTest"), [ ActorDef::new("server", ["Listening"]), ActorDef::new("worker", ["Working"]), ActorDef::new("client", ["Unregistered", "Registered"]), ], [ Transition::new( State::new("Unregistered", []), None, [State::new("Registered", [])], [Dest::new( DestinationState::Service(State::new("Listening", [])), Message::new("Register", false, ["Registered"]), )], ), Transition::new( State::new("Listening", []), Some(Message::new("Register", false, ["Registered"])), [ State::new("Listening", []), State::new("Working", ["Registered"]), ], [], ), Transition::new( State::new("Working", ["Registered"]), None, [State::new("End", [])], [Dest::new( DestinationState::Individual(State::new("Registered", [])), Message::new("Completed", false, []), )], ), Transition::new( 
State::new("Registered", []), Some(Message::new("Completed", false, [])), [State::new("End", [])], [], ), ], ); let actual = ProtocolModel::new(input).unwrap(); let worker = actual.actors.get(&format_ident!("worker")).unwrap(); assert!(!worker.kind().is_client()); } }