tests.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. #![feature(type_alias_impl_trait)]
  2. use btmsg::*;
  3. use btlib::{
  4. crypto::{ConcreteCreds, Creds, CredsPriv},
  5. BlockPath, Epoch, Principal, Principaled,
  6. };
  7. use core::future::{ready, Future, Ready};
  8. use ctor::ctor;
  9. use futures::join;
  10. use lazy_static::lazy_static;
  11. use serde::{Deserialize, Serialize};
  12. use std::{
  13. io::Write,
  14. net::{IpAddr, Ipv6Addr},
  15. sync::{Arc, Mutex as SyncMutex},
  16. time::Duration,
  17. };
  18. use tokio::sync::mpsc::{self, Sender};
  19. #[ctor]
  20. fn setup_logging() {
  21. use env_logger::Env;
  22. let env = Env::default().default_filter_or("ERROR");
  23. env_logger::builder()
  24. .format(|fmt, record| {
  25. writeln!(
  26. fmt,
  27. "[{} {} {}:{}] {}",
  28. chrono::Utc::now().to_rfc3339(),
  29. record.level(),
  30. record.file().unwrap_or("(unknown)"),
  31. record.line().unwrap_or(u32::MAX),
  32. record.args(),
  33. )
  34. })
  35. .parse_env(env)
  36. .init();
  37. }
  38. lazy_static! {
  39. static ref ROOT_CREDS: ConcreteCreds = ConcreteCreds::generate().unwrap();
  40. static ref NODE_CREDS: ConcreteCreds = {
  41. let mut creds = ConcreteCreds::generate().unwrap();
  42. let root_creds = &ROOT_CREDS;
  43. let writecap = root_creds
  44. .issue_writecap(
  45. creds.principal(),
  46. vec![],
  47. Epoch::now() + Duration::from_secs(3600),
  48. )
  49. .unwrap();
  50. creds.set_writecap(writecap);
  51. creds
  52. };
  53. static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
  54. }
  55. #[derive(Debug, Serialize, Deserialize)]
  56. enum Reply {
  57. Success,
  58. Fail,
  59. ReadReply { offset: u64, buf: Vec<u8> },
  60. }
  61. #[derive(Serialize, Deserialize)]
  62. enum Msg<'a> {
  63. Ping,
  64. Success,
  65. Fail,
  66. Read { offset: u64, size: u64 },
  67. Write { offset: u64, buf: &'a [u8] },
  68. }
  69. impl<'a> CallMsg<'a> for Msg<'a> {
  70. type Reply<'b> = Reply;
  71. }
  72. impl<'a> SendMsg<'a> for Msg<'a> {}
  73. trait TestFunc<S: 'static + Send, Fut: Send + Future>:
  74. Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut
  75. {
  76. }
  77. impl<
  78. S: 'static + Send,
  79. Fut: Send + Future,
  80. T: Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut,
  81. > TestFunc<S, Fut> for T
  82. {
  83. }
  84. struct Delegate<S, Fut> {
  85. func: Arc<dyn TestFunc<S, Fut>>,
  86. sender: Sender<S>,
  87. }
  88. impl<S, Fut> Clone for Delegate<S, Fut> {
  89. fn clone(&self) -> Self {
  90. Self {
  91. func: self.func.clone(),
  92. sender: self.sender.clone(),
  93. }
  94. }
  95. }
  96. impl<S: 'static + Send, Fut: Send + Future> Delegate<S, Fut> {
  97. fn new<F: 'static + TestFunc<S, Fut>>(sender: Sender<S>, func: F) -> Self {
  98. Self {
  99. func: Arc::new(func),
  100. sender,
  101. }
  102. }
  103. }
  104. impl<S: 'static + Send, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
  105. for Delegate<S, Fut>
  106. {
  107. type Arg<'de> = Msg<'de> where Self: 'de;
  108. type CallFut<'s> = Fut where Fut: 's;
  109. fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
  110. (self.func)(arg, self.sender.clone())
  111. }
  112. }
  113. fn proc_creds() -> impl Creds {
  114. let mut creds = ConcreteCreds::generate().unwrap();
  115. let writecap = NODE_CREDS
  116. .issue_writecap(
  117. creds.principal(),
  118. vec![],
  119. Epoch::now() + Duration::from_secs(3600),
  120. )
  121. .unwrap();
  122. creds.set_writecap(writecap);
  123. creds
  124. }
  125. fn proc_rx<F: 'static + MsgCallback>(callback: F) -> (impl Receiver, Arc<BlockAddr>) {
  126. let ip_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
  127. let creds = proc_creds();
  128. let writecap = creds.writecap().unwrap();
  129. let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
  130. (receiver(ip_addr, Arc::new(creds), callback).unwrap(), addr)
  131. }
  132. async fn proc_tx_rx<F: 'static + MsgCallback>(func: F) -> (impl Transmitter, impl Receiver) {
  133. let (receiver, addr) = proc_rx(func);
  134. let sender = receiver.transmitter(addr).await.unwrap();
  135. (sender, receiver)
  136. }
  137. async fn file_server() -> (impl Transmitter, impl Receiver) {
  138. let (sender, _) = mpsc::channel::<()>(1);
  139. let file = Arc::new(SyncMutex::new([0, 1, 2, 3, 4, 5, 6, 7]));
  140. proc_tx_rx(Delegate::new(
  141. sender,
  142. move |mut received: MsgReceived<Msg<'_>>, _| {
  143. let mut guard = file.lock().unwrap();
  144. let reply_body = match received.body() {
  145. Msg::Read { offset, size } => {
  146. let offset: usize = (*offset).try_into().unwrap();
  147. let size: usize = (*size).try_into().unwrap();
  148. let end: usize = offset + size;
  149. let mut buf = Vec::with_capacity(end - offset);
  150. buf.extend_from_slice(&guard[offset..end]);
  151. Reply::ReadReply {
  152. offset: offset as u64,
  153. buf,
  154. }
  155. }
  156. Msg::Write { offset, ref buf } => {
  157. let offset: usize = (*offset).try_into().unwrap();
  158. let end: usize = offset + buf.len();
  159. (&mut guard[offset..end]).copy_from_slice(buf);
  160. Reply::Success
  161. }
  162. _ => Reply::Fail,
  163. };
  164. let mut replier = received.take_replier().unwrap();
  165. async move { replier.reply(reply_body).await }
  166. },
  167. ))
  168. .await
  169. }
  170. async fn timeout<F: Future>(future: F) -> F::Output {
  171. tokio::time::timeout(Duration::from_millis(1000), future)
  172. .await
  173. .unwrap()
  174. }
  175. macro_rules! recv {
  176. ($rx:expr) => {
  177. timeout($rx.recv()).await.unwrap()
  178. };
  179. }
  180. #[tokio::test]
  181. async fn message_received_is_message_sent() {
  182. let (sender, mut passed) = mpsc::channel(1);
  183. let (sender, _receiver) = proc_tx_rx(Delegate::new(
  184. sender,
  185. |msg: MsgReceived<Msg<'_>>, sender: Sender<bool>| {
  186. let passed = if let Msg::Ping = msg.body() {
  187. true
  188. } else {
  189. false
  190. };
  191. let sender = sender.clone();
  192. async move {
  193. sender.send(passed).await.unwrap();
  194. Ok(())
  195. }
  196. },
  197. ))
  198. .await;
  199. sender.send(Msg::Ping).await.unwrap();
  200. assert!(recv!(passed));
  201. }
  202. #[tokio::test]
  203. async fn message_received_from_path_is_correct() {
  204. let (sender, mut path) = mpsc::channel(1);
  205. let (sender, receiver) = proc_tx_rx(Delegate::new(
  206. sender,
  207. |msg: MsgReceived<Msg<'_>>, sender: Sender<Arc<BlockPath>>| {
  208. let path = msg.from().clone();
  209. let sender = sender.clone();
  210. async move {
  211. sender.send(path).await.unwrap();
  212. Ok(())
  213. }
  214. },
  215. ))
  216. .await;
  217. sender.send(Msg::Ping).await.unwrap();
  218. assert_eq!(receiver.addr().path(), recv!(path).as_ref());
  219. }
  220. #[tokio::test]
  221. async fn reply_to_read() {
  222. let (sender, _receiver) = file_server().await;
  223. let reply = sender
  224. .call_through::<Msg>(Msg::Read { offset: 2, size: 2 })
  225. .await
  226. .unwrap();
  227. if let Reply::ReadReply { offset, buf } = reply {
  228. assert_eq!(2, offset);
  229. assert_eq!([2, 3].as_slice(), buf.as_slice());
  230. } else {
  231. panic!("reply was not the right type");
  232. };
  233. }
  234. #[tokio::test]
  235. async fn call_twice() {
  236. let (sender, _receiver) = file_server().await;
  237. let reply = sender
  238. .call_through::<Msg>(Msg::Write {
  239. offset: 1,
  240. buf: &[1, 1],
  241. })
  242. .await
  243. .unwrap();
  244. if let Reply::Success = reply {
  245. ()
  246. } else {
  247. panic!("reply was not the right type");
  248. };
  249. let reply = sender
  250. .call_through::<Msg>(Msg::Read { offset: 1, size: 2 })
  251. .await
  252. .unwrap();
  253. if let Reply::ReadReply { offset, buf } = reply {
  254. assert_eq!(1, offset);
  255. assert_eq!([1, 1].as_slice(), buf.as_slice());
  256. } else {
  257. panic!("second reply was not the right type");
  258. }
  259. }
  260. #[tokio::test]
  261. async fn separate_transmitter() {
  262. let (_sender, receiver) = file_server().await;
  263. let creds = proc_creds();
  264. let transmitter = transmitter(receiver.addr().clone(), Arc::new(creds))
  265. .await
  266. .unwrap();
  267. let reply = transmitter
  268. .call_through::<Msg>(Msg::Write {
  269. offset: 5,
  270. buf: &[7, 7, 7],
  271. })
  272. .await
  273. .unwrap();
  274. let matched = if let Reply::Success = reply {
  275. true
  276. } else {
  277. false
  278. };
  279. assert!(matched);
  280. }
  281. #[derive(Serialize, Deserialize)]
  282. struct Read {
  283. offset: usize,
  284. size: usize,
  285. }
  286. #[derive(Serialize, Deserialize)]
  287. struct ReadReply<'a> {
  288. buf: &'a [u8],
  289. }
  290. impl<'a> CallMsg<'a> for Read {
  291. type Reply<'b> = ReadReply<'b>;
  292. }
  293. #[derive(Clone)]
  294. struct ReadChecker<'a> {
  295. expected: &'a [u8],
  296. }
  297. impl<'a> ReadChecker<'a> {
  298. fn new(expected: &'a [u8]) -> Self {
  299. Self { expected }
  300. }
  301. }
  302. impl<'a> DeserCallback for ReadChecker<'a> {
  303. type Arg<'de> = ReadReply<'de> where Self: 'de;
  304. type Return = bool;
  305. type CallFut<'s> = Ready<bool> where Self: 's;
  306. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  307. ready(self.expected == arg.buf)
  308. }
  309. }
  310. trait ActionFn<Arg, Fut: Send + Future>: Send + Sync + Fn(MsgReceived<Arg>) -> Fut {}
  311. impl<Arg, Fut: Send + Future, T: Send + Sync + Fn(MsgReceived<Arg>) -> Fut> ActionFn<Arg, Fut>
  312. for T
  313. {
  314. }
  315. struct Action<Arg, Fut> {
  316. func: Arc<dyn ActionFn<Arg, Fut>>,
  317. }
  318. impl<Arg, Fut: Send + Future> Action<Arg, Fut> {
  319. fn new<F: 'static + ActionFn<Arg, Fut>>(func: F) -> Self {
  320. Self {
  321. func: Arc::new(func),
  322. }
  323. }
  324. }
  325. impl<Arg, Fut> Clone for Action<Arg, Fut> {
  326. fn clone(&self) -> Self {
  327. Self {
  328. func: self.func.clone(),
  329. }
  330. }
  331. }
  332. impl<Arg: for<'a> CallMsg<'a>, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
  333. for Action<Arg, Fut>
  334. {
  335. type Arg<'de> = Arg where Arg: 'de, Fut: 'de;
  336. type CallFut<'de> = Fut where Arg: 'de, Fut: 'de;
  337. fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
  338. (self.func)(arg)
  339. }
  340. }
  341. async fn read_server() -> (impl Transmitter, impl Receiver) {
  342. let file = [0, 1, 2, 3, 4, 5, 6, 7];
  343. proc_tx_rx(Action::new(move |mut msg: MsgReceived<Read>| async move {
  344. let body = msg.body();
  345. let start = body.offset;
  346. let end = start + body.size;
  347. let buf = &file[start..end];
  348. let mut replier = msg.take_replier().unwrap();
  349. replier.reply(ReadReply { buf }).await
  350. }))
  351. .await
  352. }
  353. #[tokio::test]
  354. async fn call_with_lifetime() {
  355. let (sender, _receiver) = read_server().await;
  356. let correct_one = sender
  357. .call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]))
  358. .await
  359. .unwrap();
  360. let correct_two = sender
  361. .call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]))
  362. .await
  363. .unwrap();
  364. assert!(correct_one);
  365. assert!(correct_two);
  366. }
  367. #[tokio::test]
  368. async fn call_concurrently() {
  369. let (sender, _receiver) = read_server().await;
  370. let call_one = sender.call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]));
  371. let call_two = sender.call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]));
  372. let (result_one, result_two) = join!(call_one, call_two);
  373. assert!(result_one.unwrap());
  374. assert!(result_two.unwrap());
  375. }