tests.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. #![feature(type_alias_impl_trait)]
  2. use btmsg::*;
  3. use btlib::{
  4. crypto::{ConcreteCreds, Creds, CredsPriv},
  5. BlockPath, Epoch, Principal, Principaled,
  6. };
  7. use core::future::{ready, Future, Ready};
  8. use ctor::ctor;
  9. use lazy_static::lazy_static;
  10. use serde::{Deserialize, Serialize};
  11. use std::{
  12. io::Write,
  13. net::{IpAddr, Ipv6Addr},
  14. sync::{Arc, Mutex as SyncMutex},
  15. time::Duration,
  16. };
  17. use tokio::sync::mpsc::{self, Sender};
  18. #[ctor]
  19. fn setup_logging() {
  20. use env_logger::Env;
  21. let env = Env::default().default_filter_or("ERROR");
  22. env_logger::builder()
  23. .format(|fmt, record| {
  24. writeln!(
  25. fmt,
  26. "[{} {} {}:{}] {}",
  27. chrono::Utc::now().to_rfc3339(),
  28. record.level(),
  29. record.file().unwrap_or("(unknown)"),
  30. record.line().unwrap_or(u32::MAX),
  31. record.args(),
  32. )
  33. })
  34. .parse_env(env)
  35. .init();
  36. }
  37. lazy_static! {
  38. static ref ROOT_CREDS: ConcreteCreds = ConcreteCreds::generate().unwrap();
  39. static ref NODE_CREDS: ConcreteCreds = {
  40. let mut creds = ConcreteCreds::generate().unwrap();
  41. let root_creds = &ROOT_CREDS;
  42. let writecap = root_creds
  43. .issue_writecap(
  44. creds.principal(),
  45. vec![],
  46. Epoch::now() + Duration::from_secs(3600),
  47. )
  48. .unwrap();
  49. creds.set_writecap(writecap);
  50. creds
  51. };
  52. static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
  53. }
/// Reply type for calls made with `Msg` (see the `CallMsg` impl for `Msg`); produced by the
/// callback installed in `file_server`.
#[derive(Debug, Serialize, Deserialize)]
enum Reply {
    /// The request was carried out (sent after a `Msg::Write`).
    Success,
    /// Sent by `file_server` for message kinds it does not handle.
    Fail,
    /// The bytes read for a `Msg::Read`, together with the offset they were read from.
    ReadReply { offset: u64, buf: Vec<u8> },
}
/// Request messages used by most tests in this file.
///
/// `Write` borrows its payload for `'a` instead of owning it.
#[derive(Serialize, Deserialize)]
enum Msg<'a> {
    /// Carries no data; `message_received_is_message_sent` checks that it arrives intact.
    Ping,
    // `Success` and `Fail` are not sent as requests by the visible tests; `file_server`
    // answers them with `Reply::Fail`.
    Success,
    Fail,
    /// Ask for `size` bytes starting at `offset`.
    Read { offset: u64, size: u64 },
    /// Overwrite the bytes starting at `offset` with `buf`.
    Write { offset: u64, buf: &'a [u8] },
}
// A call made with a `Msg` is answered with an owned `Reply`, whatever lifetime `'b` is chosen.
impl<'a> CallMsg<'a> for Msg<'a> {
    type Reply<'b> = Reply;
}
// Marker impl; presumably what allows `Msg` to be passed to `send` (as the tests do) — TODO
// confirm against the `SendMsg` definition.
impl<'a> SendMsg<'a> for Msg<'a> {}
  72. trait TestFunc<S: 'static + Send, Fut: Send + Future>:
  73. Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut
  74. {
  75. }
  76. impl<
  77. S: 'static + Send,
  78. Fut: Send + Future,
  79. T: Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut,
  80. > TestFunc<S, Fut> for T
  81. {
  82. }
  83. struct Delegate<S, Fut> {
  84. func: Arc<dyn TestFunc<S, Fut>>,
  85. sender: Sender<S>,
  86. }
  87. impl<S, Fut> Clone for Delegate<S, Fut> {
  88. fn clone(&self) -> Self {
  89. Self {
  90. func: self.func.clone(),
  91. sender: self.sender.clone(),
  92. }
  93. }
  94. }
  95. impl<S: 'static + Send, Fut: Send + Future> Delegate<S, Fut> {
  96. fn new<F: 'static + TestFunc<S, Fut>>(sender: Sender<S>, func: F) -> Self {
  97. Self {
  98. func: Arc::new(func),
  99. sender,
  100. }
  101. }
  102. }
  103. impl<S: 'static + Send, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
  104. for Delegate<S, Fut>
  105. {
  106. type Arg<'de> = Msg<'de> where Self: 'de;
  107. type CallFut<'s> = Fut where Fut: 's;
  108. fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
  109. (self.func)(arg, self.sender.clone())
  110. }
  111. }
  112. fn proc_creds() -> impl Creds {
  113. let mut creds = ConcreteCreds::generate().unwrap();
  114. let writecap = NODE_CREDS
  115. .issue_writecap(
  116. creds.principal(),
  117. vec![],
  118. Epoch::now() + Duration::from_secs(3600),
  119. )
  120. .unwrap();
  121. creds.set_writecap(writecap);
  122. creds
  123. }
  124. fn proc_rx<F: 'static + MsgCallback>(callback: F) -> (impl Receiver, Arc<BlockAddr>) {
  125. let ip_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
  126. let creds = proc_creds();
  127. let writecap = creds.writecap().unwrap();
  128. let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
  129. (receiver(ip_addr, Arc::new(creds), callback).unwrap(), addr)
  130. }
  131. async fn proc_tx_rx<F: 'static + MsgCallback>(func: F) -> (impl Transmitter, impl Receiver) {
  132. let (receiver, addr) = proc_rx(func);
  133. let sender = receiver.transmitter(addr).await.unwrap();
  134. (sender, receiver)
  135. }
  136. async fn file_server() -> (impl Transmitter, impl Receiver) {
  137. let (sender, _) = mpsc::channel::<()>(1);
  138. let file = Arc::new(SyncMutex::new([0, 1, 2, 3, 4, 5, 6, 7]));
  139. proc_tx_rx(Delegate::new(
  140. sender,
  141. move |mut received: MsgReceived<Msg<'_>>, _| {
  142. let mut guard = file.lock().unwrap();
  143. let reply_body = match received.body() {
  144. Msg::Read { offset, size } => {
  145. let offset: usize = (*offset).try_into().unwrap();
  146. let size: usize = (*size).try_into().unwrap();
  147. let end: usize = offset + size;
  148. let mut buf = Vec::with_capacity(end - offset);
  149. buf.extend_from_slice(&guard[offset..end]);
  150. Reply::ReadReply {
  151. offset: offset as u64,
  152. buf,
  153. }
  154. }
  155. Msg::Write { offset, ref buf } => {
  156. let offset: usize = (*offset).try_into().unwrap();
  157. let end: usize = offset + buf.len();
  158. (&mut guard[offset..end]).copy_from_slice(buf);
  159. Reply::Success
  160. }
  161. _ => Reply::Fail,
  162. };
  163. let mut replier = received.take_replier().unwrap();
  164. async move { replier.reply(reply_body).await }
  165. },
  166. ))
  167. .await
  168. }
  169. async fn timeout<F: Future>(future: F) -> F::Output {
  170. tokio::time::timeout(Duration::from_millis(1000), future)
  171. .await
  172. .unwrap()
  173. }
  174. macro_rules! recv {
  175. ($rx:expr) => {
  176. timeout($rx.recv()).await.unwrap()
  177. };
  178. }
  179. #[tokio::test]
  180. async fn message_received_is_message_sent() {
  181. let (sender, mut passed) = mpsc::channel(1);
  182. let (mut sender, _receiver) = proc_tx_rx(Delegate::new(
  183. sender,
  184. |msg: MsgReceived<Msg<'_>>, sender: Sender<bool>| {
  185. let passed = if let Msg::Ping = msg.body() {
  186. true
  187. } else {
  188. false
  189. };
  190. let sender = sender.clone();
  191. async move {
  192. sender.send(passed).await.unwrap();
  193. Ok(())
  194. }
  195. },
  196. ))
  197. .await;
  198. sender.send(Msg::Ping).await.unwrap();
  199. assert!(recv!(passed));
  200. }
  201. #[tokio::test]
  202. async fn message_received_from_path_is_correct() {
  203. let (sender, mut path) = mpsc::channel(1);
  204. let (mut sender, receiver) = proc_tx_rx(Delegate::new(
  205. sender,
  206. |msg: MsgReceived<Msg<'_>>, sender: Sender<Arc<BlockPath>>| {
  207. let path = msg.from().clone();
  208. let sender = sender.clone();
  209. async move {
  210. sender.send(path).await.unwrap();
  211. Ok(())
  212. }
  213. },
  214. ))
  215. .await;
  216. sender.send(Msg::Ping).await.unwrap();
  217. assert_eq!(receiver.addr().path(), recv!(path).as_ref());
  218. }
  219. #[tokio::test]
  220. async fn reply_to_read() {
  221. let (mut sender, _receiver) = file_server().await;
  222. let reply = sender
  223. .call_through::<Msg>(Msg::Read { offset: 2, size: 2 })
  224. .await
  225. .unwrap();
  226. if let Reply::ReadReply { offset, buf } = reply {
  227. assert_eq!(2, offset);
  228. assert_eq!([2, 3].as_slice(), buf.as_slice());
  229. } else {
  230. panic!("reply was not the right type");
  231. };
  232. }
  233. #[tokio::test]
  234. async fn call_twice() {
  235. let (mut sender, _receiver) = file_server().await;
  236. let reply = sender
  237. .call_through::<Msg>(Msg::Write {
  238. offset: 1,
  239. buf: &[1, 1],
  240. })
  241. .await
  242. .unwrap();
  243. if let Reply::Success = reply {
  244. ()
  245. } else {
  246. panic!("reply was not the right type");
  247. };
  248. let reply = sender
  249. .call_through::<Msg>(Msg::Read { offset: 1, size: 2 })
  250. .await
  251. .unwrap();
  252. if let Reply::ReadReply { offset, buf } = reply {
  253. assert_eq!(1, offset);
  254. assert_eq!([1, 1].as_slice(), buf.as_slice());
  255. } else {
  256. panic!("second reply was not the right type");
  257. }
  258. }
  259. #[tokio::test]
  260. async fn separate_transmitter() {
  261. let (_sender, receiver) = file_server().await;
  262. let creds = proc_creds();
  263. let mut transmitter = transmitter(receiver.addr().clone(), Arc::new(creds))
  264. .await
  265. .unwrap();
  266. let reply = transmitter
  267. .call_through::<Msg>(Msg::Write {
  268. offset: 5,
  269. buf: &[7, 7, 7],
  270. })
  271. .await
  272. .unwrap();
  273. let matched = if let Reply::Success = reply {
  274. true
  275. } else {
  276. false
  277. };
  278. assert!(matched);
  279. }
/// Request for `size` bytes starting at `offset`; served by `read_server`.
#[derive(Serialize, Deserialize)]
struct Read {
    offset: usize,
    size: usize,
}
/// Reply to a `Read`. `buf` borrows the bytes for `'a` instead of owning them, which is what
/// `call_with_lifetime` exercises.
#[derive(Serialize, Deserialize)]
struct ReadReply<'a> {
    buf: &'a [u8],
}
// A `Read` call is answered with a `ReadReply` whose lifetime parameter `'b` is independent
// of `'a`.
impl<'a> CallMsg<'a> for Read {
    type Reply<'b> = ReadReply<'b>;
}
  292. #[derive(Clone)]
  293. struct ReadChecker<'a> {
  294. expected: &'a [u8],
  295. }
  296. impl<'a> ReadChecker<'a> {
  297. fn new(expected: &'a [u8]) -> Self {
  298. Self { expected }
  299. }
  300. }
  301. impl<'a> DeserCallback for ReadChecker<'a> {
  302. type Arg<'de> = ReadReply<'de> where Self: 'de;
  303. type Return = bool;
  304. type CallFut<'s> = Ready<bool> where Self: 's;
  305. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  306. ready(self.expected == arg.buf)
  307. }
  308. }
  309. trait ActionFn<Arg, Fut: Send + Future>: Send + Sync + Fn(MsgReceived<Arg>) -> Fut {}
  310. impl<Arg, Fut: Send + Future, T: Send + Sync + Fn(MsgReceived<Arg>) -> Fut> ActionFn<Arg, Fut>
  311. for T
  312. {
  313. }
  314. struct Action<Arg, Fut> {
  315. func: Arc<dyn ActionFn<Arg, Fut>>,
  316. }
  317. impl<Arg, Fut: Send + Future> Action<Arg, Fut> {
  318. fn new<F: 'static + ActionFn<Arg, Fut>>(func: F) -> Self {
  319. Self {
  320. func: Arc::new(func),
  321. }
  322. }
  323. }
  324. impl<Arg, Fut> Clone for Action<Arg, Fut> {
  325. fn clone(&self) -> Self {
  326. Self {
  327. func: self.func.clone(),
  328. }
  329. }
  330. }
  331. impl<Arg: for<'a> CallMsg<'a>, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
  332. for Action<Arg, Fut>
  333. {
  334. type Arg<'de> = Arg where Arg: 'de, Fut: 'de;
  335. type CallFut<'de> = Fut where Arg: 'de, Fut: 'de;
  336. fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
  337. (self.func)(arg)
  338. }
  339. }
  340. async fn read_server() -> (impl Transmitter, impl Receiver) {
  341. let file = [0, 1, 2, 3, 4, 5, 6, 7];
  342. proc_tx_rx(Action::new(move |mut msg: MsgReceived<Read>| async move {
  343. let body = msg.body();
  344. let start = body.offset;
  345. let end = start + body.size;
  346. let buf = &file[start..end];
  347. let mut replier = msg.take_replier().unwrap();
  348. replier.reply(ReadReply { buf }).await
  349. }))
  350. .await
  351. }
  352. #[tokio::test]
  353. async fn call_with_lifetime() {
  354. let (mut sender, _receiver) = read_server().await;
  355. let correct_one = sender
  356. .call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]))
  357. .await
  358. .unwrap();
  359. let correct_two = sender
  360. .call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]))
  361. .await
  362. .unwrap();
  363. assert!(correct_one);
  364. assert!(correct_two);
  365. }