| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770 | use std::{    collections::{HashMap, HashSet},    hash::Hash,    rc::Rc,};use btrun::End;use proc_macro2::{Ident, Span, TokenStream};use quote::{format_ident, quote, ToTokens};use crate::{    case_convert::CaseConvert,    error,    parsing::MessageReplyPart,    parsing::{ActorDef, Dest, GetSpan, Message, Protocol, State, Transition},};pub(crate) struct ProtocolModel {    def: Protocol,    msg_lookup: MsgLookup,    actors: HashMap<Rc<Ident>, ActorModel>,}impl ProtocolModel {    pub(crate) fn new(def: Protocol) -> syn::Result<Self> {        let actor_lookup = ActorLookup::new(def.actor_defs.iter().map(|x| x.as_ref()));        let msg_lookup = MsgLookup::new(def.transitions.iter().map(|x| x.as_ref()));        let mut actors = HashMap::new();        for actor_def in def.actor_defs.iter() {            let actor_name = &actor_def.actor;            let actor_states = actor_lookup.actor_states(actor_name.as_ref());            let transitions_by_state = actor_states.iter().map(|state_name| {                let transitions = def                    .transitions                    .iter()                    .filter(|transition| {                        state_name.as_ref() == transition.in_state.state_trait.as_ref()                    })                    .cloned();                (state_name.clone(), transitions)            });            let actor = ActorModel::new(actor_def.clone(), &msg_lookup, transitions_by_state)?;            actors.insert(actor_name.clone(), actor);        }        Ok(Self {            def,            msg_lookup,            actors,        })    }    pub(crate) fn def(&self) -> &Protocol {        &self.def    }    pub(crate) fn msg_lookup(&self) -> &MsgLookup {        &self.msg_lookup    }    pub(crate) fn actors_iter(&self) -> impl Iterator<Item = &ActorModel> {        self.actors.values()    }    pub(crate) fn states_iter(&self) -> impl Iterator<Item = &StateModel> {        self.actors_iter().flat_map(|actor| actor.states().values())    }    #[cfg(test)]    pub(crate) fn methods_iter(&self) -> impl Iterator<Item = &MethodModel> {        self.states_iter()            .flat_map(|state| state.methods().values())    }    #[cfg(test)]    pub(crate) fn outputs_iter(&self) -> impl Iterator<Item = &OutputModel> {        self.methods_iter()            .flat_map(|method| method.outputs().iter())    }}pub(crate) struct ActorModel {    #[allow(dead_code)]    def: Rc<ActorDef>,    is_client: bool,    states: HashMap<Rc<Ident>, StateModel>,}impl ActorModel {    fn new<S, T>(def: Rc<ActorDef>, messages: &MsgLookup, state_iter: S) -> syn::Result<Self>    where        S: IntoIterator<Item = (Rc<Ident>, T)>,        T: IntoIterator<Item = Rc<Transition>>,    {        let transitions: HashMap<_, Vec<_>> = state_iter            .into_iter()            .map(|(name, transitions)| (name, transitions.into_iter().collect()))            .collect();        let is_client = transitions            .values()            .flatten()            .any(|transition| transition.is_client());        let mut states = HashMap::new();        for (name, transitions) in transitions.into_iter() {            let state = StateModel::new(name.clone(), messages, transitions, is_client)?;            if let Some(prev) = states.insert(name, state) {                panic!(                    "States are not being grouped by actor correctly. Duplicate state name: '{}'",                    prev.name                );            }        }        Ok(Self {            def,            is_client,            states,        })    }    pub(crate) fn is_client(&self) -> bool {        self.is_client    }    pub(crate) fn states(&self) -> &HashMap<Rc<Ident>, StateModel> {        &self.states    }}pub(crate) struct StateModel {    name: Rc<Ident>,    methods: HashMap<Rc<Ident>, MethodModel>,}impl StateModel {    fn new<T>(        name: Rc<Ident>,        messages: &MsgLookup,        transitions: T,        part_of_client: bool,    ) -> syn::Result<Self>    where        T: IntoIterator<Item = Rc<Transition>>,    {        let mut methods = HashMap::new();        for transition in transitions.into_iter() {            let transition_span = transition.span();            let method = MethodModel::new(transition, messages, part_of_client)?;            if methods.insert(method.name.clone(), method).is_some() {                return Err(syn::Error::new(                    transition_span,                    error::msgs::DUPLICATE_TRANSITION,                ));            }        }        Ok(Self { name, methods })    }    pub(crate) fn name(&self) -> &Ident {        self.name.as_ref()    }    pub(crate) fn methods(&self) -> &HashMap<Rc<Ident>, MethodModel> {        &self.methods    }}impl GetSpan for StateModel {    fn span(&self) -> Span {        self.name.span()    }}#[cfg_attr(test, derive(Debug))]pub(crate) struct MethodModel {    def: Rc<Transition>,    name: Rc<Ident>,    inputs: Vec<InputModel>,    outputs: Vec<OutputModel>,    future: Ident,}impl MethodModel {    fn new(def: Rc<Transition>, messages: &MsgLookup, part_of_client: bool) -> syn::Result<Self> {        let name = Rc::new(Self::new_name(def.as_ref())?);        let type_prefix = name.snake_to_pascal();        Ok(Self {            name,            inputs: Self::new_inputs(def.as_ref(), messages, part_of_client),            outputs: Self::new_outputs(def.as_ref(), &type_prefix, messages, part_of_client),            future: format_ident!("{type_prefix}Fut"),            def,        })    }    fn new_name(def: &Transition) -> syn::Result<Ident> {        let name = if let Some(msg) = def.in_msg() {            format_ident!("handle_{}", msg.variant().pascal_to_snake())        } else {            let mut dests = def.out_msgs.as_ref().iter();            let mut msg_names = String::new();            if let Some(dest) = dests.next() {                msg_names.push_str(dest.msg.variant().pascal_to_snake().as_str());            } else {                return Err(syn::Error::new(                    def.span(),                    error::msgs::NO_MSG_SENT_OR_RECEIVED,                ));            }            for dest in dests {                msg_names.push('_');                msg_names.push_str(dest.msg.variant().pascal_to_snake().as_str());            }            format_ident!("send_{msg_names}")        };        Ok(name)    }    fn new_inputs(def: &Transition, messages: &MsgLookup, part_of_client: bool) -> Vec<InputModel> {        let mut inputs = Vec::new();        if let Some(in_msg) = def.in_msg() {            let msg_info = messages.lookup(in_msg);            inputs.push(InputModel::new(                msg_info.msg_name().clone(),                msg_info.msg_type.clone(),            ))        }        if part_of_client {            for out_msg in def.out_msgs.as_ref().iter() {                let msg_info = messages.lookup(&out_msg.msg);                inputs.push(InputModel::new(                    msg_info.msg_name().clone(),                    msg_info.msg_type.clone(),                ))            }        }        inputs    }    fn new_outputs(        def: &Transition,        type_prefix: &str,        messages: &MsgLookup,        part_of_client: bool,    ) -> Vec<OutputModel> {        let mut outputs = Vec::new();        for state in def.out_states.as_ref().iter() {            outputs.push(OutputModel::new(                OutputKind::State { def: state.clone() },                type_prefix,            ));        }        for dest in def.out_msgs.as_ref().iter() {            let msg_info = messages.lookup(&dest.msg);            outputs.push(OutputModel::new(                OutputKind::Msg {                    def: dest.clone(),                    msg_type: msg_info.msg_type.clone(),                    is_call: msg_info.is_call(),                    part_of_client,                },                type_prefix,            ))        }        outputs    }    pub(crate) fn def(&self) -> &Transition {        self.def.as_ref()    }    pub(crate) fn name(&self) -> &Rc<Ident> {        &self.name    }    pub(crate) fn inputs(&self) -> &Vec<InputModel> {        &self.inputs    }    pub(crate) fn outputs(&self) -> &Vec<OutputModel> {        &self.outputs    }    pub(crate) fn future(&self) -> &Ident {        &self.future    }}impl GetSpan for MethodModel {    fn span(&self) -> Span {        self.def.span()    }}#[cfg_attr(test, derive(Debug))]pub(crate) struct InputModel {    name: Ident,    arg_type: Rc<TokenStream>,}impl InputModel {    fn new(type_name: Rc<Ident>, arg_type: Rc<TokenStream>) -> Self {        let name = format_ident!("{}_arg", type_name.to_string().pascal_to_snake());        Self { name, arg_type }    }}impl ToTokens for InputModel {    fn to_tokens(&self, tokens: &mut TokenStream) {        let name = &self.name;        let arg_type = self.arg_type.as_ref();        tokens.extend(quote! { #name : #arg_type })    }}#[cfg_attr(test, derive(Debug))]pub(crate) struct OutputModel {    type_name: Option<TokenStream>,    decl: Option<TokenStream>,    #[allow(dead_code)]    kind: OutputKind,}impl OutputModel {    fn new(kind: OutputKind, type_prefix: &str) -> Self {        let (decl, type_name) = match &kind {            OutputKind::State { def, .. } => {                let state_trait = def.state_trait.as_ref();                if state_trait == End::ident() {                    let end_ident = format_ident!("{}", End::ident());                    (None, Some(quote! { ::btrun::#end_ident }))                } else {                    let type_name = format_ident!("{type_prefix}{}", state_trait);                    (                        Some(quote! { type  #type_name: #state_trait; }),                        Some(quote! { Self::#type_name }),                    )                }            }            OutputKind::Msg {                msg_type,                part_of_client,                is_call,                ..            } => {                let type_name = if *part_of_client {                    if *is_call {                        Some(quote! {                            <#msg_type as ::btrun::CallMsg>::Reply                        })                    } else {                        None                    }                } else {                    Some(quote! { #msg_type })                };                (None, type_name)            }        };        Self {            type_name,            decl,            kind,        }    }    pub(crate) fn type_name(&self) -> Option<&TokenStream> {        self.type_name.as_ref()    }    pub(crate) fn decl(&self) -> Option<&TokenStream> {        self.decl.as_ref()    }}#[cfg_attr(test, derive(Debug))]pub(crate) enum OutputKind {    State {        def: Rc<State>,    },    Msg {        #[allow(dead_code)]        def: Rc<Dest>,        msg_type: Rc<TokenStream>,        is_call: bool,        part_of_client: bool,    },}pub(crate) struct ActorLookup {    actor_states: HashMap<Rc<Ident>, HashSet<Rc<Ident>>>,}impl ActorLookup {    fn new<'a>(actor_defs: impl IntoIterator<Item = &'a ActorDef>) -> Self {        let mut actor_states = HashMap::new();        for actor_def in actor_defs.into_iter() {            let mut states = HashSet::new();            for state in actor_def.states.as_ref().iter() {                states.insert(state.clone());            }            actor_states.insert(actor_def.actor.clone(), states);        }        Self { actor_states }    }    /// Returns the set of states associated with the given actor.    ///    /// This method **panics** if you call it with a non-existent actor name.    pub(crate) fn actor_states(&self, actor_name: &Ident) -> &HashSet<Rc<Ident>> {        self.actor_states.get(actor_name).unwrap_or_else(|| {            panic!("Unknown actor. This indicates there is a bug in the btproto crate.")        })    }}pub(crate) struct MsgLookup {    messages: HashMap<Rc<Ident>, MsgInfo>,}impl MsgLookup {    fn new<'a>(transitions: impl IntoIterator<Item = &'a Transition>) -> Self {        let mut messages = HashMap::new();        for transition in transitions.into_iter() {            let in_state = &transition.in_state.state_trait;            if let Some(in_msg) = transition.in_msg() {                let msg_name = &in_msg.msg_type;                let msg_info = messages                    .entry(msg_name.clone())                    .or_insert_with(|| MsgInfo::empty(in_msg.clone()));                msg_info.record_receiver(in_msg, in_state.clone());            }            for dest in transition.out_msgs.as_ref().iter() {                let msg = &dest.msg;                let msg_info = messages                    .entry(msg.msg_type.clone())                    .or_insert_with(|| MsgInfo::empty(msg.clone()));                msg_info.record_sender(&dest.msg, in_state.clone());            }        }        Self { messages }    }    pub(crate) fn lookup(&self, msg: &Message) -> &MsgInfo {        self.messages            .get(msg.msg_type.as_ref())            // Since a message is either sent or received, and we've added all such messages in            // the new method, this unwrap should not panic.            .unwrap_or_else(|| {                panic!("Failed to find message info. There is a bug in MessageLookup::new.")            })            .info_for(msg)    }    pub(crate) fn msg_iter(&self) -> impl Iterator<Item = &MsgInfo> {        self.messages            .values()            .flat_map(|msg_info| [Some(msg_info), msg_info.reply()])            .flatten()    }}impl AsRef<HashMap<Rc<Ident>, MsgInfo>> for MsgLookup {    fn as_ref(&self) -> &HashMap<Rc<Ident>, MsgInfo> {        &self.messages    }}#[cfg_attr(test, derive(Debug))]pub(crate) struct MsgInfo {    def: Rc<Message>,    msg_name: Rc<Ident>,    msg_type: Rc<TokenStream>,    is_reply: bool,    senders: HashSet<Rc<Ident>>,    receivers: HashSet<Rc<Ident>>,    reply: Option<Box<MsgInfo>>,}impl MsgInfo {    fn empty(def: Rc<Message>) -> Self {        let msg_name = def.msg_type.as_ref();        Self {            msg_name: def.msg_type.clone(),            msg_type: Rc::new(quote! { #msg_name }),            is_reply: false,            senders: HashSet::new(),            receivers: HashSet::new(),            reply: None,            def,        }    }    fn info_for(&self, msg: &Message) -> &Self {        if msg.is_reply() {            // If this message is a reply, then we should have seen it in MsgLookup::new and            // initialized its reply pointer. So this unwrap shouldn't panic.            self.reply.as_ref().unwrap_or_else(|| {                panic!(                "A reply message was not properly recorded. There is a bug in MessageLookup::new."            )            })        } else {            self        }    }    fn info_for_mut(&mut self, msg: &Rc<Message>) -> &mut Self {        if msg.is_reply() {            self.reply.get_or_insert_with(|| {                let mut reply = MsgInfo::empty(msg.clone());                reply.is_reply = true;                reply.msg_name = Rc::new(format_ident!(                    "{}{}",                    msg.msg_type.as_ref(),                    MessageReplyPart::REPLY_IDENT                ));                let parent = self.msg_name.as_ref();                reply.msg_type = Rc::new(quote! { <#parent as ::btrun::CallMsg>::Reply });                Box::new(reply)            })        } else {            self        }    }    fn record_receiver(&mut self, msg: &Rc<Message>, receiver: Rc<Ident>) {        let target = self.info_for_mut(msg);        target.receivers.insert(receiver);    }    fn record_sender(&mut self, msg: &Rc<Message>, sender: Rc<Ident>) {        let target = self.info_for_mut(msg);        target.senders.insert(sender);    }    pub(crate) fn def(&self) -> &Rc<Message> {        &self.def    }    /// The unique name of this message. If it is a reply, it will end in    /// `MessageReplyPart::REPLY_IDENT`.    pub(crate) fn msg_name(&self) -> &Rc<Ident> {        &self.msg_name    }    /// The type of this message. If this message is not a reply, this is just `msg_name`. If it is    /// a reply, it is `<#msg_name as ::btrun::CallMsg>::Reply`.    pub(crate) fn msg_type(&self) -> &Rc<TokenStream> {        &self.msg_type    }    pub(crate) fn reply(&self) -> Option<&MsgInfo> {        self.reply.as_ref().map(|ptr| ptr.as_ref())    }    pub(crate) fn is_reply(&self) -> bool {        self.is_reply    }    /// Returns true iff this message is a call.    pub(crate) fn is_call(&self) -> bool {        self.reply.is_some()    }}impl PartialEq for MsgInfo {    fn eq(&self, other: &Self) -> bool {        self.msg_name.as_ref() == other.msg_name.as_ref()    }}impl Eq for MsgInfo {}impl Hash for MsgInfo {    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {        self.msg_name.hash(state)    }}#[cfg(test)]mod tests {    use crate::{        error::{assert_err, assert_ok},        parsing::{DestinationState, NameDef},    };    use super::*;    #[test]    fn protocol_model_new_minimal_ok() {        let input = Protocol::minimal();        let result = ProtocolModel::new(input);        assert_ok(result);    }    #[test]    fn protocol_model_new_dup_recv_transition_err() {        let input = Protocol::new(            NameDef::new("Undeclared"),            [ActorDef::new("actor", ["Init", "Next"])],            [                Transition::new(                    State::new("Init", []),                    Some(Message::new("Query", false, [])),                    [State::new("Next", [])],                    [],                ),                Transition::new(                    State::new("Init", []),                    Some(Message::new("Query", false, [])),                    [State::new("Init", [])],                    [],                ),            ],        );        let result = ProtocolModel::new(input);        assert_err(result, error::msgs::DUPLICATE_TRANSITION);    }    #[test]    fn protocol_model_new_dup_send_transition_err() {        let input = Protocol::new(            NameDef::new("Undeclared"),            [                ActorDef::new("server", ["Init", "Next"]),                ActorDef::new("client", ["Client"]),            ],            [                Transition::new(                    State::new("Client", []),                    None,                    [State::new("Client", [])],                    [Dest::new(                        DestinationState::Individual(State::new("Init", [])),                        Message::new("Query", false, []),                    )],                ),                Transition::new(                    State::new("Client", []),                    None,                    [State::new("Client", [])],                    [Dest::new(                        DestinationState::Individual(State::new("Next", [])),                        Message::new("Query", false, []),                    )],                ),            ],        );        let result = ProtocolModel::new(input);        assert_err(result, error::msgs::DUPLICATE_TRANSITION);    }    #[test]    fn msg_sent_or_received_msg_received_ok() {        let input = Protocol::new(            NameDef::new("Test"),            [ActorDef::new("actor", ["Init"])],            [Transition::new(                State::new("Init", []),                Some(Message::new("Activate", false, [])),                [State::new("End", [])],                [],            )],        );        let result = ProtocolModel::new(input);        assert_ok(result);    }    #[test]    fn msg_sent_or_received_msg_sent_ok() {        let input = Protocol::new(            NameDef::new("Test"),            [ActorDef::new("actor", ["First", "Second"])],            [Transition::new(                State::new("First", []),                None,                [State::new("First", [])],                [Dest::new(                    DestinationState::Individual(State::new("Second", [])),                    Message::new("Msg", false, []),                )],            )],        );        let result = ProtocolModel::new(input);        assert_ok(result);    }    #[test]    fn msg_sent_or_received_neither_err() {        let input = Protocol::new(            NameDef::new("Test"),            [ActorDef::new("actor", ["First"])],            [Transition::new(                State::new("First", []),                None,                [State::new("First", [])],                [],            )],        );        let result = ProtocolModel::new(input);        assert_err(result, error::msgs::NO_MSG_SENT_OR_RECEIVED);    }    #[test]    fn reply_is_marked_in_output() {        const MSG: &str = "Ping";        let input = Protocol::new(            NameDef::new("ReplyTest"),            [                ActorDef::new("server", ["Listening"]),                ActorDef::new("client", ["Client", "Waiting"]),            ],            [                Transition::new(                    State::new("Client", []),                    None,                    [State::new("Waiting", [])],                    [Dest::new(                        DestinationState::Service(State::new("Listening", [])),                        Message::new(MSG, false, []),                    )],                ),                Transition::new(                    State::new("Listening", []),                    Some(Message::new(MSG, false, [])),                    [State::new("Listening", [])],                    [Dest::new(                        DestinationState::Individual(State::new("Client", [])),                        Message::new(MSG, true, []),                    )],                ),                Transition::new(                    State::new("Waiting", []),                    Some(Message::new(MSG, true, [])),                    [State::new(End::ident(), [])],                    [],                ),            ],        );        let actual = ProtocolModel::new(input).unwrap();        let outputs: Vec<_> = actual            .outputs_iter()            .map(|output| {                if let OutputKind::Msg { is_call, .. } = output.kind {                    Some(is_call)                } else {                    None                }            })            .filter(|x| x.is_some())            .map(|x| x.unwrap())            .collect();        assert_eq!(2, outputs.len());        assert_eq!(1, outputs.iter().filter(|is_reply| **is_reply).count());        assert_eq!(1, outputs.iter().filter(|is_reply| !*is_reply).count());    }}
 |