tests.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. // SPDX-License-Identifier: AGPL-3.0-or-later
  2. #![feature(impl_trait_in_assoc_type)]
  3. use btmsg::*;
  4. use btlib::{
  5. crypto::{ConcreteCreds, Creds, CredsPriv},
  6. BlockPath, Epoch, Principal, Principaled,
  7. };
  8. use core::future::{ready, Future, Ready};
  9. use ctor::ctor;
  10. use futures::join;
  11. use lazy_static::lazy_static;
  12. use serde::{Deserialize, Serialize};
  13. use std::{
  14. io::Write,
  15. net::{IpAddr, Ipv6Addr},
  16. sync::{Arc, Mutex as SyncMutex},
  17. time::Duration,
  18. };
  19. use tokio::sync::mpsc::{self, Sender};
  20. #[ctor]
  21. fn setup_logging() {
  22. use env_logger::Env;
  23. let env = Env::default().default_filter_or("ERROR");
  24. env_logger::builder()
  25. .format(|fmt, record| {
  26. writeln!(
  27. fmt,
  28. "[{} {} {}:{}] {}",
  29. chrono::Utc::now().to_rfc3339(),
  30. record.level(),
  31. record.file().unwrap_or("(unknown)"),
  32. record.line().unwrap_or(u32::MAX),
  33. record.args(),
  34. )
  35. })
  36. .parse_env(env)
  37. .init();
  38. }
  39. lazy_static! {
  40. static ref ROOT_CREDS: ConcreteCreds = ConcreteCreds::generate().unwrap();
  41. static ref NODE_CREDS: ConcreteCreds = {
  42. let mut creds = ConcreteCreds::generate().unwrap();
  43. let root_creds = &ROOT_CREDS;
  44. let writecap = root_creds
  45. .issue_writecap(
  46. creds.principal(),
  47. vec![],
  48. Epoch::now() + Duration::from_secs(3600),
  49. )
  50. .unwrap();
  51. creds.set_writecap(writecap).unwrap();
  52. creds
  53. };
  54. static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
  55. }
/// Reply type for calls made with a [`Msg`].
#[derive(Debug, Serialize, Deserialize)]
enum Reply {
    /// The request was carried out; the file server in this file sends it after a write.
    Success,
    /// The request was not served; the file server in this file sends it, for example, for
    /// messages it does not handle.
    Fail,
    /// The bytes produced by a read, together with the offset they were read from.
    ReadReply { offset: u64, buf: Vec<u8> },
}
/// Messages sent by the tests in this file. `Write` borrows its payload for `'a`.
#[derive(Serialize, Deserialize)]
enum Msg<'a> {
    /// Message without a payload; the delivery tests send it.
    Ping,
    /// Not constructed in the visible part of this file.
    Success,
    /// Not constructed in the visible part of this file.
    Fail,
    /// Asks for `size` bytes starting at `offset`.
    Read { offset: u64, size: u64 },
    /// Asks for `buf` to be stored starting at `offset`.
    Write { offset: u64, buf: &'a [u8] },
}
/// Declares [`Reply`] as the reply type for calls made with a [`Msg`].
impl<'a> CallMsg<'a> for Msg<'a> {
    type Reply<'b> = Reply;
}
/// Implements `SendMsg` for [`Msg`]; the tests in this file pass `Msg::Ping` to `send`.
impl<'a> SendMsg<'a> for Msg<'a> {}
  74. trait TestFunc<S: 'static + Send, Fut: Send + Future>:
  75. Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut
  76. {
  77. }
  78. impl<
  79. S: 'static + Send,
  80. Fut: Send + Future,
  81. T: Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut,
  82. > TestFunc<S, Fut> for T
  83. {
  84. }
/// Pairs a shared test function with the channel sender that is handed to it on every call.
struct Delegate<S, Fut> {
    /// The function invoked for each received message.
    func: Arc<dyn TestFunc<S, Fut>>,
    /// Cloned and passed to `func` on each call.
    sender: Sender<S>,
}
  89. impl<S, Fut> Clone for Delegate<S, Fut> {
  90. fn clone(&self) -> Self {
  91. Self {
  92. func: self.func.clone(),
  93. sender: self.sender.clone(),
  94. }
  95. }
  96. }
  97. impl<S: 'static + Send, Fut: Send + Future> Delegate<S, Fut> {
  98. fn new<F: 'static + TestFunc<S, Fut>>(sender: Sender<S>, func: F) -> Self {
  99. Self {
  100. func: Arc::new(func),
  101. sender,
  102. }
  103. }
  104. }
/// Lets a [`Delegate`] be used as a message callback by forwarding every message to the
/// wrapped function.
impl<S: 'static + Send, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
    for Delegate<S, Fut>
{
    /// The callback receives [`Msg`] values.
    type Arg<'de> = Msg<'de> where Self: 'de;
    /// The future returned by the wrapped function is returned unchanged.
    type CallFut<'s> = Fut where Fut: 's;
    /// Calls the wrapped function with `arg` and a fresh clone of the stored sender.
    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
        (self.func)(arg, self.sender.clone())
    }
}
  114. fn proc_creds() -> impl Creds {
  115. let mut creds = ConcreteCreds::generate().unwrap();
  116. let writecap = NODE_CREDS
  117. .issue_writecap(
  118. creds.principal(),
  119. vec![],
  120. Epoch::now() + Duration::from_secs(3600),
  121. )
  122. .unwrap();
  123. creds.set_writecap(writecap).unwrap();
  124. creds
  125. }
  126. fn proc_rx<F: 'static + MsgCallback>(callback: F) -> (impl Receiver, Arc<BlockAddr>) {
  127. let ip_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
  128. let creds = proc_creds();
  129. let writecap = creds.writecap().unwrap();
  130. let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
  131. (receiver(ip_addr, Arc::new(creds), callback).unwrap(), addr)
  132. }
  133. async fn proc_tx_rx<F: 'static + MsgCallback>(func: F) -> (impl Transmitter, impl Receiver) {
  134. let (receiver, addr) = proc_rx(func);
  135. let sender = receiver.transmitter(addr).await.unwrap();
  136. (sender, receiver)
  137. }
  138. async fn file_server() -> (impl Transmitter, impl Receiver) {
  139. let (sender, _) = mpsc::channel::<()>(1);
  140. let file = Arc::new(SyncMutex::new([0, 1, 2, 3, 4, 5, 6, 7]));
  141. proc_tx_rx(Delegate::new(
  142. sender,
  143. move |mut received: MsgReceived<Msg<'_>>, _| {
  144. let mut guard = file.lock().unwrap();
  145. let reply_body = match received.body() {
  146. Msg::Read { offset, size } => {
  147. let offset: usize = (*offset).try_into().unwrap();
  148. let size: usize = (*size).try_into().unwrap();
  149. let end: usize = offset + size;
  150. let mut buf = Vec::with_capacity(end - offset);
  151. buf.extend_from_slice(&guard[offset..end]);
  152. Reply::ReadReply {
  153. offset: offset as u64,
  154. buf,
  155. }
  156. }
  157. Msg::Write { offset, ref buf } => {
  158. let offset: usize = (*offset).try_into().unwrap();
  159. let end: usize = offset + buf.len();
  160. (&mut guard[offset..end]).copy_from_slice(buf);
  161. Reply::Success
  162. }
  163. _ => Reply::Fail,
  164. };
  165. let mut replier = received.take_replier().unwrap();
  166. async move { replier.reply(reply_body).await }
  167. },
  168. ))
  169. .await
  170. }
  171. async fn timeout<F: Future>(future: F) -> F::Output {
  172. tokio::time::timeout(Duration::from_millis(1000), future)
  173. .await
  174. .unwrap()
  175. }
  176. macro_rules! recv {
  177. ($rx:expr) => {
  178. timeout($rx.recv()).await.unwrap()
  179. };
  180. }
  181. #[tokio::test]
  182. async fn message_received_is_message_sent() {
  183. let (sender, mut passed) = mpsc::channel(1);
  184. let (sender, _receiver) = proc_tx_rx(Delegate::new(
  185. sender,
  186. |msg: MsgReceived<Msg<'_>>, sender: Sender<bool>| {
  187. let passed = if let Msg::Ping = msg.body() {
  188. true
  189. } else {
  190. false
  191. };
  192. let sender = sender.clone();
  193. async move {
  194. sender.send(passed).await.unwrap();
  195. Ok(())
  196. }
  197. },
  198. ))
  199. .await;
  200. sender.send(Msg::Ping).await.unwrap();
  201. assert!(recv!(passed));
  202. }
  203. #[tokio::test]
  204. async fn message_received_from_path_is_correct() {
  205. let (sender, mut path) = mpsc::channel(1);
  206. let (sender, receiver) = proc_tx_rx(Delegate::new(
  207. sender,
  208. |msg: MsgReceived<Msg<'_>>, sender: Sender<Arc<BlockPath>>| {
  209. let path = msg.from().clone();
  210. let sender = sender.clone();
  211. async move {
  212. sender.send(path).await.unwrap();
  213. Ok(())
  214. }
  215. },
  216. ))
  217. .await;
  218. sender.send(Msg::Ping).await.unwrap();
  219. assert_eq!(receiver.addr().path(), recv!(path).as_ref());
  220. }
  221. #[tokio::test]
  222. async fn reply_to_read() {
  223. let (sender, _receiver) = file_server().await;
  224. let reply = sender
  225. .call_through::<Msg>(Msg::Read { offset: 2, size: 2 })
  226. .await
  227. .unwrap();
  228. if let Reply::ReadReply { offset, buf } = reply {
  229. assert_eq!(2, offset);
  230. assert_eq!([2, 3].as_slice(), buf.as_slice());
  231. } else {
  232. panic!("reply was not the right type");
  233. };
  234. }
  235. #[tokio::test]
  236. async fn call_twice() {
  237. let (sender, _receiver) = file_server().await;
  238. let reply = sender
  239. .call_through::<Msg>(Msg::Write {
  240. offset: 1,
  241. buf: &[1, 1],
  242. })
  243. .await
  244. .unwrap();
  245. if let Reply::Success = reply {
  246. ()
  247. } else {
  248. panic!("reply was not the right type");
  249. };
  250. let reply = sender
  251. .call_through::<Msg>(Msg::Read { offset: 1, size: 2 })
  252. .await
  253. .unwrap();
  254. if let Reply::ReadReply { offset, buf } = reply {
  255. assert_eq!(1, offset);
  256. assert_eq!([1, 1].as_slice(), buf.as_slice());
  257. } else {
  258. panic!("second reply was not the right type");
  259. }
  260. }
  261. #[tokio::test]
  262. async fn separate_transmitter() {
  263. let (_sender, receiver) = file_server().await;
  264. let creds = proc_creds();
  265. let transmitter = transmitter(receiver.addr().clone(), Arc::new(creds))
  266. .await
  267. .unwrap();
  268. let reply = transmitter
  269. .call_through::<Msg>(Msg::Write {
  270. offset: 5,
  271. buf: &[7, 7, 7],
  272. })
  273. .await
  274. .unwrap();
  275. let matched = if let Reply::Success = reply {
  276. true
  277. } else {
  278. false
  279. };
  280. assert!(matched);
  281. }
/// Request handled by `read_server`: asks for `size` bytes starting at `offset`.
#[derive(Serialize, Deserialize)]
struct Read {
    /// Index of the first byte to read.
    offset: usize,
    /// Number of bytes to read.
    size: usize,
}
/// Reply to a [`Read`]; it borrows the returned bytes for `'a` instead of owning them.
#[derive(Serialize, Deserialize)]
struct ReadReply<'a> {
    /// The bytes that were read.
    buf: &'a [u8],
}
/// Declares [`ReadReply`] as the reply type for calls made with a [`Read`].
impl<'a> CallMsg<'a> for Read {
    type Reply<'b> = ReadReply<'b>;
}
  294. #[derive(Clone)]
  295. struct ReadChecker<'a> {
  296. expected: &'a [u8],
  297. }
  298. impl<'a> ReadChecker<'a> {
  299. fn new(expected: &'a [u8]) -> Self {
  300. Self { expected }
  301. }
  302. }
/// Lets a [`ReadChecker`] be passed to `call` as the callback that receives the reply.
impl<'a> DeserCallback for ReadChecker<'a> {
    /// The callback receives a [`ReadReply`].
    type Arg<'de> = ReadReply<'de> where Self: 'de;
    /// Whether the reply matched the expected bytes.
    type Return = bool;
    /// The comparison is synchronous, so the result is wrapped in an already-completed future.
    type CallFut<'s> = Ready<bool> where Self: 's;
    /// Compares the reply's bytes with the expected bytes.
    fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
        ready(self.expected == arg.buf)
    }
}
  311. trait ActionFn<Arg, Fut: Send + Future>: Send + Sync + Fn(MsgReceived<Arg>) -> Fut {}
  312. impl<Arg, Fut: Send + Future, T: Send + Sync + Fn(MsgReceived<Arg>) -> Fut> ActionFn<Arg, Fut>
  313. for T
  314. {
  315. }
/// Wraps a shared function so it can be used as a message callback for messages of type `Arg`.
struct Action<Arg, Fut> {
    /// The function invoked for each received message.
    func: Arc<dyn ActionFn<Arg, Fut>>,
}
  319. impl<Arg, Fut: Send + Future> Action<Arg, Fut> {
  320. fn new<F: 'static + ActionFn<Arg, Fut>>(func: F) -> Self {
  321. Self {
  322. func: Arc::new(func),
  323. }
  324. }
  325. }
  326. impl<Arg, Fut> Clone for Action<Arg, Fut> {
  327. fn clone(&self) -> Self {
  328. Self {
  329. func: self.func.clone(),
  330. }
  331. }
  332. }
/// Lets an [`Action`] be used as a message callback by forwarding every message to the wrapped
/// function.
impl<Arg: for<'a> CallMsg<'a>, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
    for Action<Arg, Fut>
{
    /// The callback receives `Arg` itself; the type does not depend on `'de`.
    type Arg<'de> = Arg where Arg: 'de, Fut: 'de;
    /// The future returned by the wrapped function is returned unchanged.
    type CallFut<'de> = Fut where Arg: 'de, Fut: 'de;
    /// Calls the wrapped function with `arg`.
    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
        (self.func)(arg)
    }
}
  342. async fn read_server() -> (impl Transmitter, impl Receiver) {
  343. let file = [0, 1, 2, 3, 4, 5, 6, 7];
  344. proc_tx_rx(Action::new(move |mut msg: MsgReceived<Read>| async move {
  345. let body = msg.body();
  346. let start = body.offset;
  347. let end = start + body.size;
  348. let buf = &file[start..end];
  349. let mut replier = msg.take_replier().unwrap();
  350. replier.reply(ReadReply { buf }).await
  351. }))
  352. .await
  353. }
  354. #[tokio::test]
  355. async fn call_with_lifetime() {
  356. let (sender, _receiver) = read_server().await;
  357. let correct_one = sender
  358. .call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]))
  359. .await
  360. .unwrap();
  361. let correct_two = sender
  362. .call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]))
  363. .await
  364. .unwrap();
  365. assert!(correct_one);
  366. assert!(correct_two);
  367. }
  368. #[tokio::test]
  369. async fn call_concurrently() {
  370. let (sender, _receiver) = read_server().await;
  371. let call_one = sender.call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]));
  372. let call_two = sender.call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]));
  373. let (result_one, result_two) = join!(call_one, call_two);
  374. assert!(result_one.unwrap());
  375. assert!(result_two.unwrap());
  376. }