use std::collections::{HashMap, HashSet}; use crate::parsing::{Message, Transition}; use btrun::{Activate, End}; use super::Protocol; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote, ToTokens}; impl ToTokens for Protocol { fn to_tokens(&self, tokens: &mut TokenStream) { tokens.extend(self.generate_message_enum()); tokens.extend(self.generate_state_traits()); } } impl Protocol { fn generate_message_enum(&self) -> TokenStream { let mut msgs: HashSet<&Message> = HashSet::new(); for transition in self.transitions.iter() { // We only need to insert received messages because every sent message has a // corresponding receiver thanks to the validator. if let Some(msg) = transition.in_msg() { msgs.insert(msg); } } let variants = msgs.iter().map(|msg| msg.ident()); let msg_types = msgs.iter().map(|msg| msg.type_tokens()); let enum_name = format_ident!("{}Msgs", self.name_def.name); quote! { pub enum #enum_name { #( #variants(#msg_types) ),* } } } fn generate_state_traits(&self) -> TokenStream { let mut traits: HashMap<&Ident, Vec<&Transition>> = HashMap::new(); for transition in self.transitions.iter() { let vec = traits .entry(&transition.in_state.state_trait) .or_insert_with(Vec::new); vec.push(transition); } let mut tokens = TokenStream::new(); for (trait_ident, transitions) in traits { let transition_tokens = transitions.iter().map(|x| x.generate_tokens()); quote! { pub trait #trait_ident { #( #transition_tokens )* } } .to_tokens(&mut tokens); } tokens } } impl Message { /// Returns the tokens which represent the type of this message. fn type_tokens(&self) -> TokenStream { // We generate a fully-qualified path to the Activate message so that it doesn't need to be // imported for every protocol definition. let msg_type = if self.msg_type == Activate::ident() { quote! { ::btrun::Activate } } else { let msg_type = &self.msg_type; quote! { #msg_type } }; if self.is_reply() { quote! { <#msg_type as ::btrun::CallMsg>::Reply } } else { quote! { #msg_type } } } } impl Transition { /// Generates the tokens for the code which implements this transition. fn generate_tokens(&self) -> TokenStream { let (msg_arg, method_ident) = if let Some(msg) = self.in_msg() { let msg_type = if msg.msg_type == Activate::ident() { quote! { ::btrun::Activate } } else { msg.msg_type.to_token_stream() }; let method_ident = format_ident!("handle_{}", msg.ident().pascal_to_snake()); let msg_arg = quote! { , msg: #msg_type }; (msg_arg, method_ident) } else { let msg_arg = quote! {}; let msg_names = self .out_msgs .as_ref() .iter() .fold(Option::::None, |accum, curr| { let msg_name = curr.msg.ident().pascal_to_snake(); if let Some(mut accum) = accum { accum.push('_'); accum.push_str(&msg_name); Some(accum) } else { Some(msg_name) } }) // Since no message is being handled, the validator ensures that at least one // message is being sent. Hence this unwrap will not panic. .unwrap(); let method_ident = format_ident!("send_{}", msg_names); (msg_arg, method_ident) }; let method_type_prefix = method_ident.snake_to_pascal(); let output_pairs: Vec<_> = self .out_states .as_ref() .iter() .map(|state| { let state_trait = &state.state_trait; if state_trait == End::ident() { (quote! {}, quote! { ::btrun::End }) } else { let assoc_type = format_ident!("{}{}", method_type_prefix, state_trait); let output_decl = quote! { type #assoc_type: #state_trait; }; let output_type = quote! { Self::#assoc_type }; (output_decl, output_type) } }) .collect(); let output_decls = output_pairs.iter().map(|(decl, _)| decl); let output_types = output_pairs.iter().map(|(_, output_type)| output_type); let future_name = format_ident!("{}Fut", method_type_prefix); quote! { #( #output_decls )* type #future_name: ::std::future::Future>; fn #method_ident(self #msg_arg) -> Self::#future_name; } } } trait CaseConvert { /// Converts a name in snake_case to PascalCase. fn snake_to_pascal(&self) -> String; /// Converts a name in PascalCase to snake_case. fn pascal_to_snake(&self) -> String; } impl CaseConvert for String { fn snake_to_pascal(&self) -> String { let mut pascal = String::with_capacity(self.len()); let mut prev_underscore = true; for c in self.chars() { if '_' == c { prev_underscore = true; } else { if prev_underscore { pascal.extend(c.to_uppercase()); } else { pascal.push(c); } prev_underscore = false; } } pascal } fn pascal_to_snake(&self) -> String { let mut snake = String::with_capacity(self.len()); let mut prev_lower = false; for c in self.chars() { if c.is_uppercase() { if prev_lower { snake.push('_'); } snake.extend(c.to_lowercase()); prev_lower = false; } else { prev_lower = true; snake.push(c); } } snake } } impl CaseConvert for Ident { fn snake_to_pascal(&self) -> String { self.to_string().snake_to_pascal() } fn pascal_to_snake(&self) -> String { self.to_string().pascal_to_snake() } } #[cfg(test)] mod tests { use super::*; #[test] fn string_snake_to_pascal_multiple_segments() { const EXPECTED: &str = "FirstSecondThird"; let input = String::from("first_second_third"); let actual = input.snake_to_pascal(); assert_eq!(EXPECTED, actual); } #[test] fn string_snake_to_pascal_single_segment() { const EXPECTED: &str = "First"; let input = String::from("first"); let actual = input.snake_to_pascal(); assert_eq!(EXPECTED, actual); } #[test] fn string_snake_to_pascal_empty_string() { const EXPECTED: &str = ""; let input = String::from(EXPECTED); let actual = input.snake_to_pascal(); assert_eq!(EXPECTED, actual); } #[test] fn string_snake_to_pascal_leading_underscore() { const EXPECTED: &str = "First"; let input = String::from("_first"); let actual = input.snake_to_pascal(); assert_eq!(EXPECTED, actual); } #[test] fn string_snake_to_pascal_leading_underscores() { const EXPECTED: &str = "First"; let input = String::from("__first"); let actual = input.snake_to_pascal(); assert_eq!(EXPECTED, actual); } #[test] fn string_snake_to_pascal_multiple_underscores() { const EXPECTED: &str = "FirstSecondThird"; let input = String::from("first__second___third"); let actual = input.snake_to_pascal(); assert_eq!(EXPECTED, actual); } #[test] fn string_pascal_to_snake_multiple_segments() { const EXPECTED: &str = "first_second_third"; let input = String::from("FirstSecondThird"); let actual = input.pascal_to_snake(); assert_eq!(EXPECTED, actual); } #[test] fn string_pascal_to_snake_single_segment() { let input = String::from("First"); const EXPECTED: &str = "first"; let actual = input.pascal_to_snake(); assert_eq!(EXPECTED, actual); } #[test] fn string_pascal_to_snake_empty_string() { const EXPECTED: &str = ""; let input = String::from(EXPECTED); let actual = input.pascal_to_snake(); assert_eq!(EXPECTED, actual); } #[test] fn string_pascal_to_snake_consecutive_uppercase() { const EXPECTED: &str = "kernel_mc"; let input = String::from("KernelMC"); let actual = input.pascal_to_snake(); assert_eq!(EXPECTED, actual); } }