tests.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. // SPDX-License-Identifier: AGPL-3.0-or-later
  2. #![feature(impl_trait_in_assoc_type)]
  3. use bttp::*;
  4. use btlib::{
  5. crypto::{ConcreteCreds, Creds, CredsPriv},
  6. BlockPath, Epoch, Principal, Principaled,
  7. };
  8. use core::future::{ready, Future, Ready};
  9. use ctor::ctor;
  10. use futures::join;
  11. use lazy_static::lazy_static;
  12. use serde::{Deserialize, Serialize};
  13. use std::{
  14. io::Write,
  15. net::{IpAddr, Ipv6Addr},
  16. sync::{Arc, Mutex as SyncMutex},
  17. time::Duration,
  18. };
  19. use tokio::sync::mpsc::{self, Sender};
  20. #[ctor]
  21. fn setup_logging() {
  22. use env_logger::Env;
  23. let env = Env::default().default_filter_or("ERROR");
  24. env_logger::builder()
  25. .format(|fmt, record| {
  26. writeln!(
  27. fmt,
  28. "[{} {} {}:{}] {}",
  29. chrono::Utc::now().to_rfc3339(),
  30. record.level(),
  31. record.file().unwrap_or("(unknown)"),
  32. record.line().unwrap_or(u32::MAX),
  33. record.args(),
  34. )
  35. })
  36. .parse_env(env)
  37. .init();
  38. }
  39. lazy_static! {
  40. static ref ROOT_CREDS: ConcreteCreds = ConcreteCreds::generate().unwrap();
  41. static ref NODE_CREDS: ConcreteCreds = {
  42. let mut creds = ConcreteCreds::generate().unwrap();
  43. let root_creds = &ROOT_CREDS;
  44. let writecap = root_creds
  45. .issue_writecap(
  46. creds.principal(),
  47. &mut std::iter::empty(),
  48. Epoch::now() + Duration::from_secs(3600),
  49. )
  50. .unwrap();
  51. creds.set_writecap(writecap).unwrap();
  52. creds
  53. };
  54. static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
  55. }
  56. #[derive(Debug, Serialize, Deserialize)]
  57. enum Reply {
  58. Success,
  59. Fail,
  60. ReadReply { offset: u64, buf: Vec<u8> },
  61. }
  62. #[derive(Serialize, Deserialize)]
  63. enum Msg<'a> {
  64. Ping,
  65. Success,
  66. Fail,
  67. Read { offset: u64, size: u64 },
  68. Write { offset: u64, buf: &'a [u8] },
  69. }
  70. impl<'a> CallMsg<'a> for Msg<'a> {
  71. type Reply<'b> = Reply;
  72. }
  73. impl<'a> SendMsg<'a> for Msg<'a> {}
  74. trait TestFunc<S: 'static + Send, Fut: Send + Future>:
  75. Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut
  76. {
  77. }
  78. impl<
  79. S: 'static + Send,
  80. Fut: Send + Future,
  81. T: Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut,
  82. > TestFunc<S, Fut> for T
  83. {
  84. }
  85. struct Delegate<S, Fut> {
  86. func: Arc<dyn TestFunc<S, Fut>>,
  87. sender: Sender<S>,
  88. }
  89. impl<S, Fut> Clone for Delegate<S, Fut> {
  90. fn clone(&self) -> Self {
  91. Self {
  92. func: self.func.clone(),
  93. sender: self.sender.clone(),
  94. }
  95. }
  96. }
  97. impl<S: 'static + Send, Fut: Send + Future> Delegate<S, Fut> {
  98. fn new<F: 'static + TestFunc<S, Fut>>(sender: Sender<S>, func: F) -> Self {
  99. Self {
  100. func: Arc::new(func),
  101. sender,
  102. }
  103. }
  104. }
  105. impl<S: 'static + Send, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
  106. for Delegate<S, Fut>
  107. {
  108. type Arg<'de> = Msg<'de> where Self: 'de;
  109. type CallFut<'s> = Fut where Fut: 's;
  110. fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
  111. (self.func)(arg, self.sender.clone())
  112. }
  113. }
  114. fn proc_creds() -> impl Creds {
  115. let mut creds = ConcreteCreds::generate().unwrap();
  116. let writecap = NODE_CREDS
  117. .issue_writecap(
  118. creds.principal(),
  119. &mut std::iter::empty(),
  120. Epoch::now() + Duration::from_secs(3600),
  121. )
  122. .unwrap();
  123. creds.set_writecap(writecap).unwrap();
  124. creds
  125. }
  126. fn proc_rx<F: 'static + MsgCallback>(callback: F) -> (Receiver, Arc<BlockAddr>) {
  127. let ip_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
  128. let creds = proc_creds();
  129. let writecap = creds.writecap().unwrap();
  130. let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
  131. (
  132. Receiver::new(ip_addr, Arc::new(creds), callback).unwrap(),
  133. addr,
  134. )
  135. }
  136. async fn proc_tx_rx<F: 'static + MsgCallback>(func: F) -> (Transmitter, Receiver) {
  137. let (receiver, addr) = proc_rx(func);
  138. let sender = receiver.transmitter(addr).await.unwrap();
  139. (sender, receiver)
  140. }
  141. async fn file_server() -> (Transmitter, Receiver) {
  142. let (sender, _) = mpsc::channel::<()>(1);
  143. let file = Arc::new(SyncMutex::new([0, 1, 2, 3, 4, 5, 6, 7]));
  144. proc_tx_rx(Delegate::new(
  145. sender,
  146. move |mut received: MsgReceived<Msg<'_>>, _| {
  147. let mut guard = file.lock().unwrap();
  148. let reply_body = match received.body() {
  149. Msg::Read { offset, size } => {
  150. let offset: usize = (*offset).try_into().unwrap();
  151. let size: usize = (*size).try_into().unwrap();
  152. let end: usize = offset + size;
  153. let mut buf = Vec::with_capacity(end - offset);
  154. buf.extend_from_slice(&guard[offset..end]);
  155. Reply::ReadReply {
  156. offset: offset as u64,
  157. buf,
  158. }
  159. }
  160. Msg::Write { offset, ref buf } => {
  161. let offset: usize = (*offset).try_into().unwrap();
  162. let end: usize = offset + buf.len();
  163. (&mut guard[offset..end]).copy_from_slice(buf);
  164. Reply::Success
  165. }
  166. _ => Reply::Fail,
  167. };
  168. let mut replier = received.take_replier().unwrap();
  169. async move { replier.reply(reply_body).await }
  170. },
  171. ))
  172. .await
  173. }
  174. async fn timeout<F: Future>(future: F) -> F::Output {
  175. tokio::time::timeout(Duration::from_millis(1000), future)
  176. .await
  177. .unwrap()
  178. }
  /// Awaits the next value from an `mpsc` receiver.
  ///
  /// Panics if nothing arrives within the `timeout` limit, or if the channel is closed
  /// (`recv()` returned `None`).
  macro_rules! recv {
      ($rx:expr) => {
          timeout($rx.recv())
              .await
              .expect("channel closed before a value was received")
      };
  }
  184. #[tokio::test]
  185. async fn message_received_is_message_sent() {
  186. let (sender, mut passed) = mpsc::channel(1);
  187. let (sender, _receiver) = proc_tx_rx(Delegate::new(
  188. sender,
  189. |msg: MsgReceived<Msg<'_>>, sender: Sender<bool>| {
  190. let passed = if let Msg::Ping = msg.body() {
  191. true
  192. } else {
  193. false
  194. };
  195. let sender = sender.clone();
  196. async move {
  197. sender.send(passed).await.unwrap();
  198. Ok(())
  199. }
  200. },
  201. ))
  202. .await;
  203. sender.send(Msg::Ping).await.unwrap();
  204. assert!(recv!(passed));
  205. }
  206. #[tokio::test]
  207. async fn message_received_from_path_is_correct() {
  208. let (sender, mut path) = mpsc::channel(1);
  209. let (sender, receiver) = proc_tx_rx(Delegate::new(
  210. sender,
  211. |msg: MsgReceived<Msg<'_>>, sender: Sender<Arc<BlockPath>>| {
  212. let path = msg.from().clone();
  213. let sender = sender.clone();
  214. async move {
  215. sender.send(path).await.unwrap();
  216. Ok(())
  217. }
  218. },
  219. ))
  220. .await;
  221. sender.send(Msg::Ping).await.unwrap();
  222. assert_eq!(receiver.addr().path().as_ref(), recv!(path).as_ref());
  223. }
  224. #[tokio::test]
  225. async fn reply_to_read() {
  226. let (sender, _receiver) = file_server().await;
  227. let reply = sender
  228. .call_through::<Msg>(Msg::Read { offset: 2, size: 2 })
  229. .await
  230. .unwrap();
  231. if let Reply::ReadReply { offset, buf } = reply {
  232. assert_eq!(2, offset);
  233. assert_eq!([2, 3].as_slice(), buf.as_slice());
  234. } else {
  235. panic!("reply was not the right type");
  236. };
  237. }
  238. #[tokio::test]
  239. async fn call_twice() {
  240. let (sender, _receiver) = file_server().await;
  241. let reply = sender
  242. .call_through::<Msg>(Msg::Write {
  243. offset: 1,
  244. buf: &[1, 1],
  245. })
  246. .await
  247. .unwrap();
  248. if let Reply::Success = reply {
  249. ()
  250. } else {
  251. panic!("reply was not the right type");
  252. };
  253. let reply = sender
  254. .call_through::<Msg>(Msg::Read { offset: 1, size: 2 })
  255. .await
  256. .unwrap();
  257. if let Reply::ReadReply { offset, buf } = reply {
  258. assert_eq!(1, offset);
  259. assert_eq!([1, 1].as_slice(), buf.as_slice());
  260. } else {
  261. panic!("second reply was not the right type");
  262. }
  263. }
  264. #[tokio::test]
  265. async fn separate_transmitter() {
  266. let (_sender, receiver) = file_server().await;
  267. let creds = proc_creds();
  268. let transmitter = Transmitter::new(receiver.addr().clone(), Arc::new(creds))
  269. .await
  270. .unwrap();
  271. let reply = transmitter
  272. .call_through::<Msg>(Msg::Write {
  273. offset: 5,
  274. buf: &[7, 7, 7],
  275. })
  276. .await
  277. .unwrap();
  278. let matched = if let Reply::Success = reply {
  279. true
  280. } else {
  281. false
  282. };
  283. assert!(matched);
  284. }
  285. #[derive(Serialize, Deserialize)]
  286. struct Read {
  287. offset: usize,
  288. size: usize,
  289. }
  290. #[derive(Serialize, Deserialize)]
  291. struct ReadReply<'a> {
  292. buf: &'a [u8],
  293. }
  294. impl<'a> CallMsg<'a> for Read {
  295. type Reply<'b> = ReadReply<'b>;
  296. }
  /// Reply callback that checks whether a `ReadReply` carries exactly `expected`.
  #[derive(Clone)]
  struct ReadChecker<'e> {
      expected: &'e [u8],
  }

  impl<'e> ReadChecker<'e> {
      /// Creates a checker that accepts only replies whose bytes equal `expected`.
      fn new(expected: &'e [u8]) -> Self {
          ReadChecker { expected }
      }
  }
  306. impl<'a> DeserCallback for ReadChecker<'a> {
  307. type Arg<'de> = ReadReply<'de> where Self: 'de;
  308. type Return = bool;
  309. type CallFut<'s> = Ready<bool> where Self: 's;
  310. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  311. ready(self.expected == arg.buf)
  312. }
  313. }
  314. trait ActionFn<Arg, Fut: Send + Future>: Send + Sync + Fn(MsgReceived<Arg>) -> Fut {}
  315. impl<Arg, Fut: Send + Future, T: Send + Sync + Fn(MsgReceived<Arg>) -> Fut> ActionFn<Arg, Fut>
  316. for T
  317. {
  318. }
  319. struct Action<Arg, Fut> {
  320. func: Arc<dyn ActionFn<Arg, Fut>>,
  321. }
  322. impl<Arg, Fut: Send + Future> Action<Arg, Fut> {
  323. fn new<F: 'static + ActionFn<Arg, Fut>>(func: F) -> Self {
  324. Self {
  325. func: Arc::new(func),
  326. }
  327. }
  328. }
  329. impl<Arg, Fut> Clone for Action<Arg, Fut> {
  330. fn clone(&self) -> Self {
  331. Self {
  332. func: self.func.clone(),
  333. }
  334. }
  335. }
  336. impl<Arg: for<'a> CallMsg<'a>, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
  337. for Action<Arg, Fut>
  338. {
  339. type Arg<'de> = Arg where Arg: 'de, Fut: 'de;
  340. type CallFut<'de> = Fut where Arg: 'de, Fut: 'de;
  341. fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
  342. (self.func)(arg)
  343. }
  344. }
  345. async fn read_server() -> (Transmitter, Receiver) {
  346. let file = [0, 1, 2, 3, 4, 5, 6, 7];
  347. proc_tx_rx(Action::new(move |mut msg: MsgReceived<Read>| async move {
  348. let body = msg.body();
  349. let start = body.offset;
  350. let end = start + body.size;
  351. let buf = &file[start..end];
  352. let mut replier = msg.take_replier().unwrap();
  353. replier.reply(ReadReply { buf }).await
  354. }))
  355. .await
  356. }
  357. #[tokio::test]
  358. async fn call_with_lifetime() {
  359. let (sender, _receiver) = read_server().await;
  360. let correct_one = sender
  361. .call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]))
  362. .await
  363. .unwrap();
  364. let correct_two = sender
  365. .call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]))
  366. .await
  367. .unwrap();
  368. assert!(correct_one);
  369. assert!(correct_two);
  370. }
  371. #[tokio::test]
  372. async fn call_concurrently() {
  373. let (sender, _receiver) = read_server().await;
  374. let call_one = sender.call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]));
  375. let call_two = sender.call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]));
  376. let (result_one, result_two) = join!(call_one, call_two);
  377. assert!(result_one.unwrap());
  378. assert!(result_two.unwrap());
  379. }