use std::{ collections::{HashMap, HashSet}, hash::Hash, rc::Rc, }; use btrun::End; use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; use crate::{ case_convert::CaseConvert, error, parsing::MessageReplyPart, parsing::{ActorDef, Dest, GetSpan, Message, Protocol, State, Transition}, }; pub(crate) struct ProtocolModel { def: Protocol, msg_lookup: MsgLookup, actors: HashMap, ActorModel>, } impl ProtocolModel { pub(crate) fn new(def: Protocol) -> syn::Result { let actor_lookup = ActorLookup::new(def.actor_defs.iter().map(|x| x.as_ref())); let msg_lookup = MsgLookup::new(def.transitions.iter().map(|x| x.as_ref())); let mut actors = HashMap::new(); for actor_def in def.actor_defs.iter() { let actor_name = &actor_def.actor; let actor_states = actor_lookup.actor_states(actor_name.as_ref()); let transitions_by_state = actor_states.iter().map(|state_name| { let transitions = def .transitions .iter() .filter(|transition| { state_name.as_ref() == transition.in_state.state_trait.as_ref() }) .cloned(); (state_name.clone(), transitions) }); let actor = ActorModel::new(actor_def.clone(), &msg_lookup, transitions_by_state)?; actors.insert(actor_name.clone(), actor); } Ok(Self { def, msg_lookup, actors, }) } pub(crate) fn def(&self) -> &Protocol { &self.def } pub(crate) fn msg_lookup(&self) -> &MsgLookup { &self.msg_lookup } pub(crate) fn actors_iter(&self) -> impl Iterator { self.actors.values() } pub(crate) fn states_iter(&self) -> impl Iterator { self.actors_iter().flat_map(|actor| actor.states().values()) } #[cfg(test)] pub(crate) fn methods_iter(&self) -> impl Iterator { self.states_iter() .flat_map(|state| state.methods().values()) } #[cfg(test)] pub(crate) fn outputs_iter(&self) -> impl Iterator { self.methods_iter() .flat_map(|method| method.outputs().iter()) } } pub(crate) struct ActorModel { #[allow(dead_code)] def: Rc, is_client: bool, states: HashMap, StateModel>, } impl ActorModel { fn new(def: Rc, messages: &MsgLookup, state_iter: S) -> syn::Result where S: IntoIterator, T)>, T: IntoIterator>, { let transitions: HashMap<_, Vec<_>> = state_iter .into_iter() .map(|(name, transitions)| (name, transitions.into_iter().collect())) .collect(); let is_client = transitions .values() .flatten() .any(|transition| transition.is_client()); let mut states = HashMap::new(); for (name, transitions) in transitions.into_iter() { let state = StateModel::new(name.clone(), messages, transitions, is_client)?; if let Some(prev) = states.insert(name, state) { panic!( "States are not being grouped by actor correctly. Duplicate state name: '{}'", prev.name ); } } Ok(Self { def, is_client, states, }) } pub(crate) fn is_client(&self) -> bool { self.is_client } pub(crate) fn states(&self) -> &HashMap, StateModel> { &self.states } } pub(crate) struct StateModel { name: Rc, methods: HashMap, MethodModel>, } impl StateModel { fn new( name: Rc, messages: &MsgLookup, transitions: T, part_of_client: bool, ) -> syn::Result where T: IntoIterator>, { let mut methods = HashMap::new(); for transition in transitions.into_iter() { let transition_span = transition.span(); let method = MethodModel::new(transition, messages, part_of_client)?; if methods.insert(method.name.clone(), method).is_some() { return Err(syn::Error::new( transition_span, error::msgs::DUPLICATE_TRANSITION, )); } } Ok(Self { name, methods }) } pub(crate) fn name(&self) -> &Ident { self.name.as_ref() } pub(crate) fn methods(&self) -> &HashMap, MethodModel> { &self.methods } } impl GetSpan for StateModel { fn span(&self) -> Span { self.name.span() } } #[cfg_attr(test, derive(Debug))] pub(crate) struct MethodModel { def: Rc, name: Rc, inputs: Vec, outputs: Vec, future: Ident, } impl MethodModel { fn new(def: Rc, messages: &MsgLookup, part_of_client: bool) -> syn::Result { let name = Rc::new(Self::new_name(def.as_ref())?); let type_prefix = name.snake_to_pascal(); Ok(Self { name, inputs: Self::new_inputs(def.as_ref(), messages, part_of_client), outputs: Self::new_outputs(def.as_ref(), &type_prefix, messages, part_of_client), future: format_ident!("{type_prefix}Fut"), def, }) } fn new_name(def: &Transition) -> syn::Result { let name = if let Some(msg) = def.in_msg() { format_ident!("handle_{}", msg.variant().pascal_to_snake()) } else { let mut dests = def.out_msgs.as_ref().iter(); let mut msg_names = String::new(); if let Some(dest) = dests.next() { msg_names.push_str(dest.msg.variant().pascal_to_snake().as_str()); } else { return Err(syn::Error::new( def.span(), error::msgs::NO_MSG_SENT_OR_RECEIVED, )); } for dest in dests { msg_names.push('_'); msg_names.push_str(dest.msg.variant().pascal_to_snake().as_str()); } format_ident!("send_{msg_names}") }; Ok(name) } fn new_inputs(def: &Transition, messages: &MsgLookup, part_of_client: bool) -> Vec { let mut inputs = Vec::new(); if let Some(in_msg) = def.in_msg() { let msg_info = messages.lookup(in_msg); inputs.push(InputModel::new( msg_info.msg_name().clone(), msg_info.msg_type.clone(), )) } if part_of_client { for out_msg in def.out_msgs.as_ref().iter() { let msg_info = messages.lookup(&out_msg.msg); inputs.push(InputModel::new( msg_info.msg_name().clone(), msg_info.msg_type.clone(), )) } } inputs } fn new_outputs( def: &Transition, type_prefix: &str, messages: &MsgLookup, part_of_client: bool, ) -> Vec { let mut outputs = Vec::new(); for state in def.out_states.as_ref().iter() { outputs.push(OutputModel::new( OutputKind::State { def: state.clone() }, type_prefix, )); } for dest in def.out_msgs.as_ref().iter() { let msg_info = messages.lookup(&dest.msg); outputs.push(OutputModel::new( OutputKind::Msg { def: dest.clone(), msg_type: msg_info.msg_type.clone(), is_call: msg_info.is_call(), part_of_client, }, type_prefix, )) } outputs } pub(crate) fn def(&self) -> &Transition { self.def.as_ref() } pub(crate) fn name(&self) -> &Rc { &self.name } pub(crate) fn inputs(&self) -> &Vec { &self.inputs } pub(crate) fn outputs(&self) -> &Vec { &self.outputs } pub(crate) fn future(&self) -> &Ident { &self.future } } impl GetSpan for MethodModel { fn span(&self) -> Span { self.def.span() } } #[cfg_attr(test, derive(Debug))] pub(crate) struct InputModel { name: Ident, arg_type: Rc, } impl InputModel { fn new(type_name: Rc, arg_type: Rc) -> Self { let name = format_ident!("{}_arg", type_name.to_string().pascal_to_snake()); Self { name, arg_type } } } impl ToTokens for InputModel { fn to_tokens(&self, tokens: &mut TokenStream) { let name = &self.name; let arg_type = self.arg_type.as_ref(); tokens.extend(quote! { #name : #arg_type }) } } #[cfg_attr(test, derive(Debug))] pub(crate) struct OutputModel { type_name: Option, decl: Option, #[allow(dead_code)] kind: OutputKind, } impl OutputModel { fn new(kind: OutputKind, type_prefix: &str) -> Self { let (decl, type_name) = match &kind { OutputKind::State { def, .. } => { let state_trait = def.state_trait.as_ref(); if state_trait == End::ident() { let end_ident = format_ident!("{}", End::ident()); (None, Some(quote! { ::btrun::#end_ident })) } else { let type_name = format_ident!("{type_prefix}{}", state_trait); ( Some(quote! { type #type_name: #state_trait; }), Some(quote! { Self::#type_name }), ) } } OutputKind::Msg { msg_type, part_of_client, is_call, .. } => { let type_name = if *part_of_client { if *is_call { Some(quote! { <#msg_type as ::btrun::CallMsg>::Reply }) } else { None } } else { Some(quote! { #msg_type }) }; (None, type_name) } }; Self { type_name, decl, kind, } } pub(crate) fn type_name(&self) -> Option<&TokenStream> { self.type_name.as_ref() } pub(crate) fn decl(&self) -> Option<&TokenStream> { self.decl.as_ref() } } #[cfg_attr(test, derive(Debug))] pub(crate) enum OutputKind { State { def: Rc, }, Msg { #[allow(dead_code)] def: Rc, msg_type: Rc, is_call: bool, part_of_client: bool, }, } pub(crate) struct ActorLookup { actor_states: HashMap, HashSet>>, } impl ActorLookup { fn new<'a>(actor_defs: impl IntoIterator) -> Self { let mut actor_states = HashMap::new(); for actor_def in actor_defs.into_iter() { let mut states = HashSet::new(); for state in actor_def.states.as_ref().iter() { states.insert(state.clone()); } actor_states.insert(actor_def.actor.clone(), states); } Self { actor_states } } /// Returns the set of states associated with the given actor. /// /// This method **panics** if you call it with a non-existent actor name. pub(crate) fn actor_states(&self, actor_name: &Ident) -> &HashSet> { self.actor_states.get(actor_name).unwrap_or_else(|| { panic!("Unknown actor. This indicates there is a bug in the btproto crate.") }) } } pub(crate) struct MsgLookup { messages: HashMap, MsgInfo>, } impl MsgLookup { fn new<'a>(transitions: impl IntoIterator) -> Self { let mut messages = HashMap::new(); for transition in transitions.into_iter() { let in_state = &transition.in_state.state_trait; if let Some(in_msg) = transition.in_msg() { let msg_name = &in_msg.msg_type; let msg_info = messages .entry(msg_name.clone()) .or_insert_with(|| MsgInfo::empty(in_msg.clone())); msg_info.record_receiver(in_msg, in_state.clone()); } for dest in transition.out_msgs.as_ref().iter() { let msg = &dest.msg; let msg_info = messages .entry(msg.msg_type.clone()) .or_insert_with(|| MsgInfo::empty(msg.clone())); msg_info.record_sender(&dest.msg, in_state.clone()); } } Self { messages } } pub(crate) fn lookup(&self, msg: &Message) -> &MsgInfo { self.messages .get(msg.msg_type.as_ref()) // Since a message is either sent or received, and we've added all such messages in // the new method, this unwrap should not panic. .unwrap_or_else(|| { panic!("Failed to find message info. There is a bug in MessageLookup::new.") }) .info_for(msg) } pub(crate) fn msg_iter(&self) -> impl Iterator { self.messages .values() .flat_map(|msg_info| [Some(msg_info), msg_info.reply()]) .flatten() } } impl AsRef, MsgInfo>> for MsgLookup { fn as_ref(&self) -> &HashMap, MsgInfo> { &self.messages } } #[cfg_attr(test, derive(Debug))] pub(crate) struct MsgInfo { def: Rc, msg_name: Rc, msg_type: Rc, is_reply: bool, senders: HashSet>, receivers: HashSet>, reply: Option>, } impl MsgInfo { fn empty(def: Rc) -> Self { let msg_name = def.msg_type.as_ref(); Self { msg_name: def.msg_type.clone(), msg_type: Rc::new(quote! { #msg_name }), is_reply: false, senders: HashSet::new(), receivers: HashSet::new(), reply: None, def, } } fn info_for(&self, msg: &Message) -> &Self { if msg.is_reply() { // If this message is a reply, then we should have seen it in MsgLookup::new and // initialized its reply pointer. So this unwrap shouldn't panic. self.reply.as_ref().unwrap_or_else(|| { panic!( "A reply message was not properly recorded. There is a bug in MessageLookup::new." ) }) } else { self } } fn info_for_mut(&mut self, msg: &Rc) -> &mut Self { if msg.is_reply() { self.reply.get_or_insert_with(|| { let mut reply = MsgInfo::empty(msg.clone()); reply.is_reply = true; reply.msg_name = Rc::new(format_ident!( "{}{}", msg.msg_type.as_ref(), MessageReplyPart::REPLY_IDENT )); let parent = self.msg_name.as_ref(); reply.msg_type = Rc::new(quote! { <#parent as ::btrun::CallMsg>::Reply }); Box::new(reply) }) } else { self } } fn record_receiver(&mut self, msg: &Rc, receiver: Rc) { let target = self.info_for_mut(msg); target.receivers.insert(receiver); } fn record_sender(&mut self, msg: &Rc, sender: Rc) { let target = self.info_for_mut(msg); target.senders.insert(sender); } pub(crate) fn def(&self) -> &Rc { &self.def } /// The unique name of this message. If it is a reply, it will end in /// `MessageReplyPart::REPLY_IDENT`. pub(crate) fn msg_name(&self) -> &Rc { &self.msg_name } /// The type of this message. If this message is not a reply, this is just `msg_name`. If it is /// a reply, it is `<#msg_name as ::btrun::CallMsg>::Reply`. pub(crate) fn msg_type(&self) -> &Rc { &self.msg_type } pub(crate) fn reply(&self) -> Option<&MsgInfo> { self.reply.as_ref().map(|ptr| ptr.as_ref()) } pub(crate) fn is_reply(&self) -> bool { self.is_reply } /// Returns true iff this message is a call. pub(crate) fn is_call(&self) -> bool { self.reply.is_some() } } impl PartialEq for MsgInfo { fn eq(&self, other: &Self) -> bool { self.msg_name.as_ref() == other.msg_name.as_ref() } } impl Eq for MsgInfo {} impl Hash for MsgInfo { fn hash(&self, state: &mut H) { self.msg_name.hash(state) } } #[cfg(test)] mod tests { use crate::{ error::{assert_err, assert_ok}, parsing::{DestinationState, NameDef}, }; use super::*; #[test] fn protocol_model_new_minimal_ok() { let input = Protocol::minimal(); let result = ProtocolModel::new(input); assert_ok(result); } #[test] fn protocol_model_new_dup_recv_transition_err() { let input = Protocol::new( NameDef::new("Undeclared"), [ActorDef::new("actor", ["Init", "Next"])], [ Transition::new( State::new("Init", []), Some(Message::new("Query", false, [])), [State::new("Next", [])], [], ), Transition::new( State::new("Init", []), Some(Message::new("Query", false, [])), [State::new("Init", [])], [], ), ], ); let result = ProtocolModel::new(input); assert_err(result, error::msgs::DUPLICATE_TRANSITION); } #[test] fn protocol_model_new_dup_send_transition_err() { let input = Protocol::new( NameDef::new("Undeclared"), [ ActorDef::new("server", ["Init", "Next"]), ActorDef::new("client", ["Client"]), ], [ Transition::new( State::new("Client", []), None, [State::new("Client", [])], [Dest::new( DestinationState::Individual(State::new("Init", [])), Message::new("Query", false, []), )], ), Transition::new( State::new("Client", []), None, [State::new("Client", [])], [Dest::new( DestinationState::Individual(State::new("Next", [])), Message::new("Query", false, []), )], ), ], ); let result = ProtocolModel::new(input); assert_err(result, error::msgs::DUPLICATE_TRANSITION); } #[test] fn msg_sent_or_received_msg_received_ok() { let input = Protocol::new( NameDef::new("Test"), [ActorDef::new("actor", ["Init"])], [Transition::new( State::new("Init", []), Some(Message::new("Activate", false, [])), [State::new("End", [])], [], )], ); let result = ProtocolModel::new(input); assert_ok(result); } #[test] fn msg_sent_or_received_msg_sent_ok() { let input = Protocol::new( NameDef::new("Test"), [ActorDef::new("actor", ["First", "Second"])], [Transition::new( State::new("First", []), None, [State::new("First", [])], [Dest::new( DestinationState::Individual(State::new("Second", [])), Message::new("Msg", false, []), )], )], ); let result = ProtocolModel::new(input); assert_ok(result); } #[test] fn msg_sent_or_received_neither_err() { let input = Protocol::new( NameDef::new("Test"), [ActorDef::new("actor", ["First"])], [Transition::new( State::new("First", []), None, [State::new("First", [])], [], )], ); let result = ProtocolModel::new(input); assert_err(result, error::msgs::NO_MSG_SENT_OR_RECEIVED); } #[test] fn reply_is_marked_in_output() { const MSG: &str = "Ping"; let input = Protocol::new( NameDef::new("ReplyTest"), [ ActorDef::new("server", ["Listening"]), ActorDef::new("client", ["Client", "Waiting"]), ], [ Transition::new( State::new("Client", []), None, [State::new("Waiting", [])], [Dest::new( DestinationState::Service(State::new("Listening", [])), Message::new(MSG, false, []), )], ), Transition::new( State::new("Listening", []), Some(Message::new(MSG, false, [])), [State::new("Listening", [])], [Dest::new( DestinationState::Individual(State::new("Client", [])), Message::new(MSG, true, []), )], ), Transition::new( State::new("Waiting", []), Some(Message::new(MSG, true, [])), [State::new(End::ident(), [])], [], ), ], ); let actual = ProtocolModel::new(input).unwrap(); let outputs: Vec<_> = actual .outputs_iter() .map(|output| { if let OutputKind::Msg { is_call, .. } = output.kind { Some(is_call) } else { None } }) .filter(|x| x.is_some()) .map(|x| x.unwrap()) .collect(); assert_eq!(2, outputs.len()); assert_eq!(1, outputs.iter().filter(|is_reply| **is_reply).count()); assert_eq!(1, outputs.iter().filter(|is_reply| !*is_reply).count()); } }