12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770 |
- use std::{
- collections::{HashMap, HashSet},
- hash::Hash,
- rc::Rc,
- };
- use btrun::End;
- use proc_macro2::{Ident, Span, TokenStream};
- use quote::{format_ident, quote, ToTokens};
- use crate::{
- case_convert::CaseConvert,
- error,
- parsing::MessageReplyPart,
- parsing::{ActorDef, Dest, GetSpan, Message, Protocol, State, Transition},
- };
- pub(crate) struct ProtocolModel {
- def: Protocol,
- msg_lookup: MsgLookup,
- actors: HashMap<Rc<Ident>, ActorModel>,
- }
- impl ProtocolModel {
- pub(crate) fn new(def: Protocol) -> syn::Result<Self> {
- let actor_lookup = ActorLookup::new(def.actor_defs.iter().map(|x| x.as_ref()));
- let msg_lookup = MsgLookup::new(def.transitions.iter().map(|x| x.as_ref()));
- let mut actors = HashMap::new();
- for actor_def in def.actor_defs.iter() {
- let actor_name = &actor_def.actor;
- let actor_states = actor_lookup.actor_states(actor_name.as_ref());
- let transitions_by_state = actor_states.iter().map(|state_name| {
- let transitions = def
- .transitions
- .iter()
- .filter(|transition| {
- state_name.as_ref() == transition.in_state.state_trait.as_ref()
- })
- .cloned();
- (state_name.clone(), transitions)
- });
- let actor = ActorModel::new(actor_def.clone(), &msg_lookup, transitions_by_state)?;
- actors.insert(actor_name.clone(), actor);
- }
- Ok(Self {
- def,
- msg_lookup,
- actors,
- })
- }
- pub(crate) fn def(&self) -> &Protocol {
- &self.def
- }
- pub(crate) fn msg_lookup(&self) -> &MsgLookup {
- &self.msg_lookup
- }
- pub(crate) fn actors_iter(&self) -> impl Iterator<Item = &ActorModel> {
- self.actors.values()
- }
- pub(crate) fn states_iter(&self) -> impl Iterator<Item = &StateModel> {
- self.actors_iter().flat_map(|actor| actor.states().values())
- }
- #[cfg(test)]
- pub(crate) fn methods_iter(&self) -> impl Iterator<Item = &MethodModel> {
- self.states_iter()
- .flat_map(|state| state.methods().values())
- }
- #[cfg(test)]
- pub(crate) fn outputs_iter(&self) -> impl Iterator<Item = &OutputModel> {
- self.methods_iter()
- .flat_map(|method| method.outputs().iter())
- }
- }
- pub(crate) struct ActorModel {
- #[allow(dead_code)]
- def: Rc<ActorDef>,
- is_client: bool,
- states: HashMap<Rc<Ident>, StateModel>,
- }
- impl ActorModel {
- fn new<S, T>(def: Rc<ActorDef>, messages: &MsgLookup, state_iter: S) -> syn::Result<Self>
- where
- S: IntoIterator<Item = (Rc<Ident>, T)>,
- T: IntoIterator<Item = Rc<Transition>>,
- {
- let transitions: HashMap<_, Vec<_>> = state_iter
- .into_iter()
- .map(|(name, transitions)| (name, transitions.into_iter().collect()))
- .collect();
- let is_client = transitions
- .values()
- .flatten()
- .any(|transition| transition.is_client());
- let mut states = HashMap::new();
- for (name, transitions) in transitions.into_iter() {
- let state = StateModel::new(name.clone(), messages, transitions, is_client)?;
- if let Some(prev) = states.insert(name, state) {
- panic!(
- "States are not being grouped by actor correctly. Duplicate state name: '{}'",
- prev.name
- );
- }
- }
- Ok(Self {
- def,
- is_client,
- states,
- })
- }
- pub(crate) fn is_client(&self) -> bool {
- self.is_client
- }
- pub(crate) fn states(&self) -> &HashMap<Rc<Ident>, StateModel> {
- &self.states
- }
- }
- pub(crate) struct StateModel {
- name: Rc<Ident>,
- methods: HashMap<Rc<Ident>, MethodModel>,
- }
- impl StateModel {
- fn new<T>(
- name: Rc<Ident>,
- messages: &MsgLookup,
- transitions: T,
- part_of_client: bool,
- ) -> syn::Result<Self>
- where
- T: IntoIterator<Item = Rc<Transition>>,
- {
- let mut methods = HashMap::new();
- for transition in transitions.into_iter() {
- let transition_span = transition.span();
- let method = MethodModel::new(transition, messages, part_of_client)?;
- if methods.insert(method.name.clone(), method).is_some() {
- return Err(syn::Error::new(
- transition_span,
- error::msgs::DUPLICATE_TRANSITION,
- ));
- }
- }
- Ok(Self { name, methods })
- }
- pub(crate) fn name(&self) -> &Ident {
- self.name.as_ref()
- }
- pub(crate) fn methods(&self) -> &HashMap<Rc<Ident>, MethodModel> {
- &self.methods
- }
- }
- impl GetSpan for StateModel {
- fn span(&self) -> Span {
- self.name.span()
- }
- }
- #[cfg_attr(test, derive(Debug))]
- pub(crate) struct MethodModel {
- def: Rc<Transition>,
- name: Rc<Ident>,
- inputs: Vec<InputModel>,
- outputs: Vec<OutputModel>,
- future: Ident,
- }
- impl MethodModel {
- fn new(def: Rc<Transition>, messages: &MsgLookup, part_of_client: bool) -> syn::Result<Self> {
- let name = Rc::new(Self::new_name(def.as_ref())?);
- let type_prefix = name.snake_to_pascal();
- Ok(Self {
- name,
- inputs: Self::new_inputs(def.as_ref(), messages, part_of_client),
- outputs: Self::new_outputs(def.as_ref(), &type_prefix, messages, part_of_client),
- future: format_ident!("{type_prefix}Fut"),
- def,
- })
- }
- fn new_name(def: &Transition) -> syn::Result<Ident> {
- let name = if let Some(msg) = def.in_msg() {
- format_ident!("handle_{}", msg.variant().pascal_to_snake())
- } else {
- let mut dests = def.out_msgs.as_ref().iter();
- let mut msg_names = String::new();
- if let Some(dest) = dests.next() {
- msg_names.push_str(dest.msg.variant().pascal_to_snake().as_str());
- } else {
- return Err(syn::Error::new(
- def.span(),
- error::msgs::NO_MSG_SENT_OR_RECEIVED,
- ));
- }
- for dest in dests {
- msg_names.push('_');
- msg_names.push_str(dest.msg.variant().pascal_to_snake().as_str());
- }
- format_ident!("send_{msg_names}")
- };
- Ok(name)
- }
- fn new_inputs(def: &Transition, messages: &MsgLookup, part_of_client: bool) -> Vec<InputModel> {
- let mut inputs = Vec::new();
- if let Some(in_msg) = def.in_msg() {
- let msg_info = messages.lookup(in_msg);
- inputs.push(InputModel::new(
- msg_info.msg_name().clone(),
- msg_info.msg_type.clone(),
- ))
- }
- if part_of_client {
- for out_msg in def.out_msgs.as_ref().iter() {
- let msg_info = messages.lookup(&out_msg.msg);
- inputs.push(InputModel::new(
- msg_info.msg_name().clone(),
- msg_info.msg_type.clone(),
- ))
- }
- }
- inputs
- }
- fn new_outputs(
- def: &Transition,
- type_prefix: &str,
- messages: &MsgLookup,
- part_of_client: bool,
- ) -> Vec<OutputModel> {
- let mut outputs = Vec::new();
- for state in def.out_states.as_ref().iter() {
- outputs.push(OutputModel::new(
- OutputKind::State { def: state.clone() },
- type_prefix,
- ));
- }
- for dest in def.out_msgs.as_ref().iter() {
- let msg_info = messages.lookup(&dest.msg);
- outputs.push(OutputModel::new(
- OutputKind::Msg {
- def: dest.clone(),
- msg_type: msg_info.msg_type.clone(),
- is_call: msg_info.is_call(),
- part_of_client,
- },
- type_prefix,
- ))
- }
- outputs
- }
- pub(crate) fn def(&self) -> &Transition {
- self.def.as_ref()
- }
- pub(crate) fn name(&self) -> &Rc<Ident> {
- &self.name
- }
- pub(crate) fn inputs(&self) -> &Vec<InputModel> {
- &self.inputs
- }
- pub(crate) fn outputs(&self) -> &Vec<OutputModel> {
- &self.outputs
- }
- pub(crate) fn future(&self) -> &Ident {
- &self.future
- }
- }
- impl GetSpan for MethodModel {
- fn span(&self) -> Span {
- self.def.span()
- }
- }
- #[cfg_attr(test, derive(Debug))]
- pub(crate) struct InputModel {
- name: Ident,
- arg_type: Rc<TokenStream>,
- }
- impl InputModel {
- fn new(type_name: Rc<Ident>, arg_type: Rc<TokenStream>) -> Self {
- let name = format_ident!("{}_arg", type_name.to_string().pascal_to_snake());
- Self { name, arg_type }
- }
- }
- impl ToTokens for InputModel {
- fn to_tokens(&self, tokens: &mut TokenStream) {
- let name = &self.name;
- let arg_type = self.arg_type.as_ref();
- tokens.extend(quote! { #name : #arg_type })
- }
- }
- #[cfg_attr(test, derive(Debug))]
- pub(crate) struct OutputModel {
- type_name: Option<TokenStream>,
- decl: Option<TokenStream>,
- #[allow(dead_code)]
- kind: OutputKind,
- }
- impl OutputModel {
- fn new(kind: OutputKind, type_prefix: &str) -> Self {
- let (decl, type_name) = match &kind {
- OutputKind::State { def, .. } => {
- let state_trait = def.state_trait.as_ref();
- if state_trait == End::ident() {
- let end_ident = format_ident!("{}", End::ident());
- (None, Some(quote! { ::btrun::#end_ident }))
- } else {
- let type_name = format_ident!("{type_prefix}{}", state_trait);
- (
- Some(quote! { type #type_name: #state_trait; }),
- Some(quote! { Self::#type_name }),
- )
- }
- }
- OutputKind::Msg {
- msg_type,
- part_of_client,
- is_call,
- ..
- } => {
- let type_name = if *part_of_client {
- if *is_call {
- Some(quote! {
- <#msg_type as ::btrun::CallMsg>::Reply
- })
- } else {
- None
- }
- } else {
- Some(quote! { #msg_type })
- };
- (None, type_name)
- }
- };
- Self {
- type_name,
- decl,
- kind,
- }
- }
- pub(crate) fn type_name(&self) -> Option<&TokenStream> {
- self.type_name.as_ref()
- }
- pub(crate) fn decl(&self) -> Option<&TokenStream> {
- self.decl.as_ref()
- }
- }
- #[cfg_attr(test, derive(Debug))]
- pub(crate) enum OutputKind {
- State {
- def: Rc<State>,
- },
- Msg {
- #[allow(dead_code)]
- def: Rc<Dest>,
- msg_type: Rc<TokenStream>,
- is_call: bool,
- part_of_client: bool,
- },
- }
- pub(crate) struct ActorLookup {
- actor_states: HashMap<Rc<Ident>, HashSet<Rc<Ident>>>,
- }
- impl ActorLookup {
- fn new<'a>(actor_defs: impl IntoIterator<Item = &'a ActorDef>) -> Self {
- let mut actor_states = HashMap::new();
- for actor_def in actor_defs.into_iter() {
- let mut states = HashSet::new();
- for state in actor_def.states.as_ref().iter() {
- states.insert(state.clone());
- }
- actor_states.insert(actor_def.actor.clone(), states);
- }
- Self { actor_states }
- }
- /// Returns the set of states associated with the given actor.
- ///
- /// This method **panics** if you call it with a non-existent actor name.
- pub(crate) fn actor_states(&self, actor_name: &Ident) -> &HashSet<Rc<Ident>> {
- self.actor_states.get(actor_name).unwrap_or_else(|| {
- panic!("Unknown actor. This indicates there is a bug in the btproto crate.")
- })
- }
- }
- pub(crate) struct MsgLookup {
- messages: HashMap<Rc<Ident>, MsgInfo>,
- }
- impl MsgLookup {
- fn new<'a>(transitions: impl IntoIterator<Item = &'a Transition>) -> Self {
- let mut messages = HashMap::new();
- for transition in transitions.into_iter() {
- let in_state = &transition.in_state.state_trait;
- if let Some(in_msg) = transition.in_msg() {
- let msg_name = &in_msg.msg_type;
- let msg_info = messages
- .entry(msg_name.clone())
- .or_insert_with(|| MsgInfo::empty(in_msg.clone()));
- msg_info.record_receiver(in_msg, in_state.clone());
- }
- for dest in transition.out_msgs.as_ref().iter() {
- let msg = &dest.msg;
- let msg_info = messages
- .entry(msg.msg_type.clone())
- .or_insert_with(|| MsgInfo::empty(msg.clone()));
- msg_info.record_sender(&dest.msg, in_state.clone());
- }
- }
- Self { messages }
- }
- pub(crate) fn lookup(&self, msg: &Message) -> &MsgInfo {
- self.messages
- .get(msg.msg_type.as_ref())
- // Since a message is either sent or received, and we've added all such messages in
- // the new method, this unwrap should not panic.
- .unwrap_or_else(|| {
- panic!("Failed to find message info. There is a bug in MessageLookup::new.")
- })
- .info_for(msg)
- }
- pub(crate) fn msg_iter(&self) -> impl Iterator<Item = &MsgInfo> {
- self.messages
- .values()
- .flat_map(|msg_info| [Some(msg_info), msg_info.reply()])
- .flatten()
- }
- }
- impl AsRef<HashMap<Rc<Ident>, MsgInfo>> for MsgLookup {
- fn as_ref(&self) -> &HashMap<Rc<Ident>, MsgInfo> {
- &self.messages
- }
- }
- #[cfg_attr(test, derive(Debug))]
- pub(crate) struct MsgInfo {
- def: Rc<Message>,
- msg_name: Rc<Ident>,
- msg_type: Rc<TokenStream>,
- is_reply: bool,
- senders: HashSet<Rc<Ident>>,
- receivers: HashSet<Rc<Ident>>,
- reply: Option<Box<MsgInfo>>,
- }
- impl MsgInfo {
- fn empty(def: Rc<Message>) -> Self {
- let msg_name = def.msg_type.as_ref();
- Self {
- msg_name: def.msg_type.clone(),
- msg_type: Rc::new(quote! { #msg_name }),
- is_reply: false,
- senders: HashSet::new(),
- receivers: HashSet::new(),
- reply: None,
- def,
- }
- }
- fn info_for(&self, msg: &Message) -> &Self {
- if msg.is_reply() {
- // If this message is a reply, then we should have seen it in MsgLookup::new and
- // initialized its reply pointer. So this unwrap shouldn't panic.
- self.reply.as_ref().unwrap_or_else(|| {
- panic!(
- "A reply message was not properly recorded. There is a bug in MessageLookup::new."
- )
- })
- } else {
- self
- }
- }
- fn info_for_mut(&mut self, msg: &Rc<Message>) -> &mut Self {
- if msg.is_reply() {
- self.reply.get_or_insert_with(|| {
- let mut reply = MsgInfo::empty(msg.clone());
- reply.is_reply = true;
- reply.msg_name = Rc::new(format_ident!(
- "{}{}",
- msg.msg_type.as_ref(),
- MessageReplyPart::REPLY_IDENT
- ));
- let parent = self.msg_name.as_ref();
- reply.msg_type = Rc::new(quote! { <#parent as ::btrun::CallMsg>::Reply });
- Box::new(reply)
- })
- } else {
- self
- }
- }
- fn record_receiver(&mut self, msg: &Rc<Message>, receiver: Rc<Ident>) {
- let target = self.info_for_mut(msg);
- target.receivers.insert(receiver);
- }
- fn record_sender(&mut self, msg: &Rc<Message>, sender: Rc<Ident>) {
- let target = self.info_for_mut(msg);
- target.senders.insert(sender);
- }
- pub(crate) fn def(&self) -> &Rc<Message> {
- &self.def
- }
- /// The unique name of this message. If it is a reply, it will end in
- /// `MessageReplyPart::REPLY_IDENT`.
- pub(crate) fn msg_name(&self) -> &Rc<Ident> {
- &self.msg_name
- }
- /// The type of this message. If this message is not a reply, this is just `msg_name`. If it is
- /// a reply, it is `<#msg_name as ::btrun::CallMsg>::Reply`.
- pub(crate) fn msg_type(&self) -> &Rc<TokenStream> {
- &self.msg_type
- }
- pub(crate) fn reply(&self) -> Option<&MsgInfo> {
- self.reply.as_ref().map(|ptr| ptr.as_ref())
- }
- pub(crate) fn is_reply(&self) -> bool {
- self.is_reply
- }
- /// Returns true iff this message is a call.
- pub(crate) fn is_call(&self) -> bool {
- self.reply.is_some()
- }
- }
- impl PartialEq for MsgInfo {
- fn eq(&self, other: &Self) -> bool {
- self.msg_name.as_ref() == other.msg_name.as_ref()
- }
- }
- impl Eq for MsgInfo {}
- impl Hash for MsgInfo {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- self.msg_name.hash(state)
- }
- }
#[cfg(test)]
mod tests {
    use crate::{
        error::{assert_err, assert_ok},
        parsing::{DestinationState, NameDef},
    };

    use super::*;

    /// The minimal protocol must produce a model without errors.
    #[test]
    fn protocol_model_new_minimal_ok() {
        let input = Protocol::minimal();

        let result = ProtocolModel::new(input);

        assert_ok(result);
    }

    /// Two transitions handling the same message in the same state are rejected.
    #[test]
    fn protocol_model_new_dup_recv_transition_err() {
        let input = Protocol::new(
            NameDef::new("Undeclared"),
            [ActorDef::new("actor", ["Init", "Next"])],
            [
                Transition::new(
                    State::new("Init", []),
                    Some(Message::new("Query", false, [])),
                    [State::new("Next", [])],
                    [],
                ),
                Transition::new(
                    State::new("Init", []),
                    Some(Message::new("Query", false, [])),
                    [State::new("Init", [])],
                    [],
                ),
            ],
        );

        let result = ProtocolModel::new(input);

        assert_err(result, error::msgs::DUPLICATE_TRANSITION);
    }

    /// Two transitions sending the same message from the same state are rejected, even when
    /// the destinations differ.
    #[test]
    fn protocol_model_new_dup_send_transition_err() {
        let input = Protocol::new(
            NameDef::new("Undeclared"),
            [
                ActorDef::new("server", ["Init", "Next"]),
                ActorDef::new("client", ["Client"]),
            ],
            [
                Transition::new(
                    State::new("Client", []),
                    None,
                    [State::new("Client", [])],
                    [Dest::new(
                        DestinationState::Individual(State::new("Init", [])),
                        Message::new("Query", false, []),
                    )],
                ),
                Transition::new(
                    State::new("Client", []),
                    None,
                    [State::new("Client", [])],
                    [Dest::new(
                        DestinationState::Individual(State::new("Next", [])),
                        Message::new("Query", false, []),
                    )],
                ),
            ],
        );

        let result = ProtocolModel::new(input);

        assert_err(result, error::msgs::DUPLICATE_TRANSITION);
    }

    /// A transition which only receives a message is valid.
    #[test]
    fn msg_sent_or_received_msg_received_ok() {
        let input = Protocol::new(
            NameDef::new("Test"),
            [ActorDef::new("actor", ["Init"])],
            [Transition::new(
                State::new("Init", []),
                Some(Message::new("Activate", false, [])),
                [State::new("End", [])],
                [],
            )],
        );

        let result = ProtocolModel::new(input);

        assert_ok(result);
    }

    /// A transition which only sends a message is valid.
    #[test]
    fn msg_sent_or_received_msg_sent_ok() {
        let input = Protocol::new(
            NameDef::new("Test"),
            [ActorDef::new("actor", ["First", "Second"])],
            [Transition::new(
                State::new("First", []),
                None,
                [State::new("First", [])],
                [Dest::new(
                    DestinationState::Individual(State::new("Second", [])),
                    Message::new("Msg", false, []),
                )],
            )],
        );

        let result = ProtocolModel::new(input);

        assert_ok(result);
    }

    /// A transition which neither sends nor receives a message is rejected.
    #[test]
    fn msg_sent_or_received_neither_err() {
        let input = Protocol::new(
            NameDef::new("Test"),
            [ActorDef::new("actor", ["First"])],
            [Transition::new(
                State::new("First", []),
                None,
                [State::new("First", [])],
                [],
            )],
        );

        let result = ProtocolModel::new(input);

        assert_err(result, error::msgs::NO_MSG_SENT_OR_RECEIVED);
    }

    /// Of the two message outputs in a call/reply exchange, exactly one (the call) is
    /// flagged with `is_call`.
    #[test]
    fn reply_is_marked_in_output() {
        const MSG: &str = "Ping";
        let input = Protocol::new(
            NameDef::new("ReplyTest"),
            [
                ActorDef::new("server", ["Listening"]),
                ActorDef::new("client", ["Client", "Waiting"]),
            ],
            [
                Transition::new(
                    State::new("Client", []),
                    None,
                    [State::new("Waiting", [])],
                    [Dest::new(
                        DestinationState::Service(State::new("Listening", [])),
                        Message::new(MSG, false, []),
                    )],
                ),
                Transition::new(
                    State::new("Listening", []),
                    Some(Message::new(MSG, false, [])),
                    [State::new("Listening", [])],
                    [Dest::new(
                        DestinationState::Individual(State::new("Client", [])),
                        Message::new(MSG, true, []),
                    )],
                ),
                Transition::new(
                    State::new("Waiting", []),
                    Some(Message::new(MSG, true, [])),
                    [State::new(End::ident(), [])],
                    [],
                ),
            ],
        );

        let actual = ProtocolModel::new(input).unwrap();

        // Keep only the message outputs, reduced to their `is_call` flag.
        let outputs: Vec<bool> = actual
            .outputs_iter()
            .filter_map(|output| match output.kind {
                OutputKind::Msg { is_call, .. } => Some(is_call),
                OutputKind::State { .. } => None,
            })
            .collect();
        assert_eq!(2, outputs.len());
        assert_eq!(1, outputs.iter().filter(|is_call| **is_call).count());
        assert_eq!(1, outputs.iter().filter(|is_call| !**is_call).count());
    }
}
|