validation.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. use std::collections::HashSet;
  2. use proc_macro2::{Ident, Span};
  3. use crate::{
  4. error::MaybeErr,
  5. parsing::{DestinationState, Message, State},
  6. Protocol,
  7. };
  8. impl Protocol {
  9. pub(crate) fn validate(&self) -> syn::Result<()> {
  10. self.all_states_declared_and_used()
  11. .combine(self.match_receivers_and_senders())
  12. .combine(self.no_undeliverable_msgs())
  13. .combine(self.valid_replies())
  14. .into()
  15. }
  16. const UNDECLARED_STATE_ERR: &str = "State was not declared.";
  17. const UNUSED_STATE_ERR: &str = "State was declared but never used.";
  18. const END_STATE: &str = "End";
  19. /// Verifies that every state which is used has been declared, except for the End state.
  20. fn all_states_declared_and_used(&self) -> MaybeErr {
  21. let end = Ident::new(Self::END_STATE, Span::call_site());
  22. let declared: HashSet<&Ident> = self
  23. .states_def
  24. .states
  25. .as_ref()
  26. .iter()
  27. .chain([&end].into_iter())
  28. .collect();
  29. let mut used: HashSet<&Ident> = HashSet::with_capacity(declared.len());
  30. for transition in self.transitions.iter() {
  31. let in_state = &transition.in_state;
  32. used.insert(&in_state.state_trait);
  33. used.extend(in_state.owned_states.as_ref().iter());
  34. if let Some(in_msg) = &transition.in_msg {
  35. used.extend(in_msg.owned_states.as_ref().iter());
  36. }
  37. for out_states in transition.out_states.as_ref().iter() {
  38. used.insert(&out_states.state_trait);
  39. used.extend(out_states.owned_states.as_ref().iter());
  40. }
  41. // We don't have to check the states referred to in out_msgs because the
  42. // match_receivers_and_senders method ensures that each of these exists in a receiver
  43. // position.
  44. }
  45. let undeclared: MaybeErr = used
  46. .difference(&declared)
  47. .map(|ident| syn::Error::new(ident.span(), Self::UNDECLARED_STATE_ERR))
  48. .collect();
  49. let unused: MaybeErr = declared
  50. .difference(&used)
  51. .filter(|ident| **ident != Self::END_STATE)
  52. .map(|ident| syn::Error::new(ident.span(), Self::UNUSED_STATE_ERR))
  53. .collect();
  54. undeclared.combine(unused)
  55. }
  56. const UNMATCHED_SENDER_ERR: &str = "No receiver found for message type.";
  57. const UNMATCHED_RECEIVER_ERR: &str = "No sender found for message type.";
  58. const ACTIVATE_MSG: &str = "Activate";
  59. /// Ensures that the recipient state for every sent message has a receiving transition
  60. /// defined, and every receiver has a sender (except for the Activate message which is sent
  61. /// by the runtime).
  62. fn match_receivers_and_senders(&self) -> MaybeErr {
  63. let mut senders: HashSet<(&State, &Message)> = HashSet::new();
  64. let mut receivers: HashSet<(&State, &Message)> = HashSet::new();
  65. for transition in self.transitions.iter() {
  66. if let Some(msg) = &transition.in_msg {
  67. receivers.insert((&transition.in_state, msg));
  68. if msg.msg_type == Self::ACTIVATE_MSG {
  69. // The Activate message is sent by the run time, so a sender is created to
  70. // represent it.
  71. senders.insert((&transition.in_state, msg));
  72. }
  73. }
  74. for dest in transition.out_msgs.as_ref().iter() {
  75. let dest_state = match &dest.state {
  76. DestinationState::Individual(dest_state) => dest_state,
  77. DestinationState::Service(dest_state) => dest_state,
  78. };
  79. senders.insert((dest_state, &dest.msg));
  80. }
  81. }
  82. let extra_senders: MaybeErr = senders
  83. .difference(&receivers)
  84. .map(|pair| syn::Error::new(pair.1.msg_type.span(), Self::UNMATCHED_SENDER_ERR))
  85. .collect();
  86. let extra_receivers: MaybeErr = receivers
  87. .difference(&senders)
  88. .map(|pair| syn::Error::new(pair.1.msg_type.span(), Self::UNMATCHED_RECEIVER_ERR))
  89. .collect();
  90. extra_senders.combine(extra_receivers)
  91. }
  92. const UNDELIVERABLE_ERR: &str =
  93. "Receiver must either be a service, an owned state, or an out state, or the message must be a reply.";
  94. /// Checks that messages are only sent to destinations which are either services, states
  95. /// which are owned by the sender, listed in the output states, or that the message is a
  96. /// reply.
  97. fn no_undeliverable_msgs(&self) -> MaybeErr {
  98. let mut err = MaybeErr::none();
  99. for transition in self.transitions.iter() {
  100. let mut allowed_states: Option<HashSet<&Ident>> = None;
  101. for dest in transition.out_msgs.as_ref().iter() {
  102. if dest.msg.is_reply {
  103. continue;
  104. }
  105. match &dest.state {
  106. DestinationState::Service(_) => continue,
  107. DestinationState::Individual(dest_state) => {
  108. let allowed = allowed_states.get_or_insert_with(|| {
  109. transition
  110. .out_states
  111. .as_ref()
  112. .iter()
  113. .map(|state| &state.state_trait)
  114. .chain(transition.in_state.owned_states.as_ref().iter())
  115. .collect()
  116. });
  117. if !allowed.contains(&dest_state.state_trait) {
  118. err = err.combine(
  119. syn::Error::new(
  120. dest_state.state_trait.span(),
  121. Self::UNDELIVERABLE_ERR,
  122. )
  123. .into(),
  124. );
  125. }
  126. }
  127. }
  128. }
  129. }
  130. err
  131. }
  132. const INVALID_REPLY_ERR: &str =
  133. "Replies can only be used in transitions which handle messages.";
  134. const MULTIPLE_REPLIES_ERR: &str =
  135. "Only a single reply can be sent in response to any message.";
  136. /// Verifies that replies are only sent in response to messages.
  137. fn valid_replies(&self) -> MaybeErr {
  138. let mut err = MaybeErr::none();
  139. for transition in self.transitions.iter() {
  140. let replies: Vec<_> = transition
  141. .out_msgs
  142. .as_ref()
  143. .iter()
  144. .map(|dest| &dest.msg)
  145. .filter(|msg| msg.is_reply)
  146. .collect();
  147. if replies.is_empty() {
  148. continue;
  149. }
  150. if replies.len() > 1 {
  151. err = err.combine(
  152. replies
  153. .iter()
  154. .map(|reply| {
  155. syn::Error::new(reply.msg_type.span(), Self::MULTIPLE_REPLIES_ERR)
  156. })
  157. .collect(),
  158. );
  159. }
  160. if transition.in_msg.is_none() {
  161. err = err.combine(
  162. replies
  163. .iter()
  164. .map(|reply| {
  165. syn::Error::new(reply.msg_type.span(), Self::INVALID_REPLY_ERR)
  166. })
  167. .collect(),
  168. );
  169. }
  170. }
  171. err
  172. }
  173. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use syn::parse_str;

      /// Parses `input` into a [`Protocol`], panicking if the definition is malformed.
      fn parse_protocol(input: &str) -> Protocol {
          parse_str::<Protocol>(input).unwrap()
      }

      /// Fails the test, printing the error, unless `outcome` converts into `Ok(())`.
      fn assert_ok(outcome: impl Into<syn::Result<()>>) {
          let result: syn::Result<()> = outcome.into();
          if let Err(error) = result {
              panic!("{}", error);
          }
      }

      /// Fails the test unless `outcome` converts into an error whose message is `expected_msg`.
      fn assert_err(outcome: impl Into<syn::Result<()>>, expected_msg: &str) {
          let result: syn::Result<()> = outcome.into();
          assert!(result.is_err());
          let error = result.unwrap_err();
          assert_eq!(expected_msg, error.to_string());
      }

      /// A minimal valid protocol definition.
      const MIN_PROTOCOL: &str = "
let name = Test;
let states = [Init];
Init?Activate -> End;
";

      #[test]
      fn all_states_declared_and_used_ok() {
          assert_ok(parse_protocol(MIN_PROTOCOL).all_states_declared_and_used());
      }

      #[test]
      fn all_states_declared_and_used_end_not_used_ok() {
          const INPUT: &str = "
let name = Test;
let states = [Init];
Init?Activate -> Init;
";
          assert_ok(parse_protocol(INPUT).all_states_declared_and_used());
      }

      #[test]
      fn all_states_declared_and_used_undeclared_err() {
          const INPUT: &str = "
let name = Undeclared;
let states = [Init];
Init?Activate -> Next;
";
          assert_err(
              parse_protocol(INPUT).all_states_declared_and_used(),
              Protocol::UNDECLARED_STATE_ERR,
          );
      }

      #[test]
      fn all_states_declared_and_used_undeclared_out_state_owned_err() {
          const INPUT: &str = "
let name = Undeclared;
let states = [Init, Next];
Init?Activate -> Init, Next[Undeclared];
";
          assert_err(
              parse_protocol(INPUT).all_states_declared_and_used(),
              Protocol::UNDECLARED_STATE_ERR,
          );
      }

      #[test]
      fn all_states_declared_and_used_undeclared_in_state_owned_err() {
          const INPUT: &str = "
let name = Undeclared;
let states = [Init, Next];
Init[Undeclared]?Activate -> Next;
";
          assert_err(
              parse_protocol(INPUT).all_states_declared_and_used(),
              Protocol::UNDECLARED_STATE_ERR,
          );
      }

      #[test]
      fn all_states_declared_and_used_unused_err() {
          const INPUT: &str = "
let name = Unused;
let states = [Init, Extra];
Init?Activate -> End;
";
          assert_err(
              parse_protocol(INPUT).all_states_declared_and_used(),
              Protocol::UNUSED_STATE_ERR,
          );
      }

      #[test]
      fn match_receivers_and_senders_ok() {
          assert_ok(parse_protocol(MIN_PROTOCOL).match_receivers_and_senders());
      }

      #[test]
      fn match_receivers_and_senders_send_activate_ok() {
          const INPUT: &str = "
let name = Unbalanced;
let states = [First, Second];
First?Activate -> First, >Second!Activate;
Second?Activate -> Second;
";
          assert_ok(parse_protocol(INPUT).match_receivers_and_senders());
      }

      #[test]
      fn match_receivers_and_senders_unmatched_sender_err() {
          const INPUT: &str = "
let name = Unbalanced;
let states = [Init, Other];
Init?Activate -> Init, >Other!Activate;
";
          assert_err(
              parse_protocol(INPUT).match_receivers_and_senders(),
              Protocol::UNMATCHED_SENDER_ERR,
          );
      }

      #[test]
      fn match_receivers_and_senders_unmatched_receiver_err() {
          const INPUT: &str = "
let name = Unbalanced;
let states = [Init];
Init?NotExists -> Init;
";
          assert_err(
              parse_protocol(INPUT).match_receivers_and_senders(),
              Protocol::UNMATCHED_RECEIVER_ERR,
          );
      }

      #[test]
      fn no_undeliverable_msgs_ok() {
          assert_ok(parse_protocol(MIN_PROTOCOL).no_undeliverable_msgs());
      }

      #[test]
      fn no_undeliverable_msgs_reply_ok() {
          const INPUT: &str = "
let name = Undeliverable;
let states = [Listening, Client];
Listening?Msg -> Listening, >Client!Msg::Reply;
";
          assert_ok(parse_protocol(INPUT).no_undeliverable_msgs());
      }

      #[test]
      fn no_undeliverable_msgs_service_ok() {
          const INPUT: &str = "
let name = Undeliverable;
let states = [Client, Server];
Client -> Client, >service(Server)!Msg;
";
          assert_ok(parse_protocol(INPUT).no_undeliverable_msgs());
      }

      #[test]
      fn no_undeliverable_msgs_owned_ok() {
          const INPUT: &str = "
let name = Undeliverable;
let states = [FileClient, FileHandle];
FileClient[FileHandle] -> FileClient, >FileHandle!FileOp;
";
          assert_ok(parse_protocol(INPUT).no_undeliverable_msgs());
      }

      #[test]
      fn no_undeliverable_msgs_err() {
          const INPUT: &str = "
let name = Undeliverable;
let states = [Client, Server];
Client -> Client, >Server!Msg;
";
          assert_err(
              parse_protocol(INPUT).no_undeliverable_msgs(),
              Protocol::UNDELIVERABLE_ERR,
          );
      }

      #[test]
      fn valid_replies_ok() {
          const INPUT: &str = "
let name = ValidReplies;
let states = [Client, Server];
Server?Msg -> Server, >Client!Msg::Reply;
";
          assert_ok(parse_protocol(INPUT).valid_replies());
      }

      #[test]
      fn valid_replies_invalid_reply_err() {
          const INPUT: &str = "
let name = ValidReplies;
let states = [Client, Server];
Client -> Client, >Server!Msg::Reply;
";
          assert_err(
              parse_protocol(INPUT).valid_replies(),
              Protocol::INVALID_REPLY_ERR,
          );
      }

      #[test]
      fn valid_replies_multiple_replies_err() {
          const INPUT: &str = "
let name = ValidReplies;
let states = [Client, OtherClient, Server];
Server?Msg -> Server, >Client!Msg::Reply, OtherClient!Msg::Reply;
";
          assert_err(
              parse_protocol(INPUT).valid_replies(),
              Protocol::MULTIPLE_REPLIES_ERR,
          );
      }
  }