// SPDX-License-Identifier: AGPL-3.0-or-later
//! Integration tests for `btmsg`.
//!
//! Each test starts a receiver on the IPv6 loopback address using freshly
//! generated credentials (see `proc_creds`), then exchanges messages with it
//! through a transmitter and checks what the receiver's callback observed or
//! what it replied.
#![feature(impl_trait_in_assoc_type)]
use btmsg::*;
use btlib::{
    crypto::{ConcreteCreds, Creds, CredsPriv},
    BlockPath, Epoch, Principal, Principaled,
};
use core::future::{ready, Future, Ready};
use ctor::ctor;
use futures::join;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::{
    io::Write,
    net::{IpAddr, Ipv6Addr},
    sync::{Arc, Mutex as SyncMutex},
    time::Duration,
};
use tokio::sync::mpsc::{self, Sender};

/// Installs the global `env_logger` logger for this test binary.
///
/// Marked `#[ctor]` so it runs when the binary is loaded, before any test.
/// The filter defaults to `ERROR` unless the environment overrides it, and
/// each record is prefixed with an RFC 3339 UTC timestamp, its level, and its
/// source `file:line`.
#[ctor]
fn setup_logging() {
    use env_logger::Env;
    let env = Env::default().default_filter_or("ERROR");
    env_logger::builder()
        .format(|fmt, record| {
            writeln!(
                fmt,
                "[{} {} {}:{}] {}",
                chrono::Utc::now().to_rfc3339(),
                record.level(),
                record.file().unwrap_or("(unknown)"),
                // `u32::MAX` stands in for an unknown line number.
                record.line().unwrap_or(u32::MAX),
                record.args(),
            )
        })
        .parse_env(env)
        .init();
}

lazy_static! {
    // Root credentials. `NODE_CREDS` is issued by these, and `proc_creds` is
    // issued by `NODE_CREDS`, so every credential in this file chains back here.
    static ref ROOT_CREDS: ConcreteCreds = ConcreteCreds::generate().unwrap();
    // Node credentials holding a writecap issued by `ROOT_CREDS`. The time
    // argument is one hour from now (presumably the writecap's expiry).
    static ref NODE_CREDS: ConcreteCreds = {
        let mut creds = ConcreteCreds::generate().unwrap();
        let root_creds = &ROOT_CREDS;
        let writecap = root_creds
            .issue_writecap(
                creds.principal(),
                vec![],
                Epoch::now() + Duration::from_secs(3600),
            )
            .unwrap();
        creds.set_writecap(writecap);
        creds
    };
    // Principal of `ROOT_CREDS`.
    // NOTE(review): not referenced in the visible part of this file.
    static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
}

/// Replies sent by the `file_server` callback.
#[derive(Debug, Serialize, Deserialize)]
enum Reply {
    Success,
    Fail,
    /// The bytes that were read, along with the offset they were read from.
    ReadReply { offset: u64, buf: Vec },
}

/// Requests used by the `Delegate`-based tests. `Write` borrows its payload
/// for `'a` instead of owning it.
#[derive(Serialize, Deserialize)]
enum Msg<'a> {
    Ping,
    Success,
    Fail,
    Read { offset: u64, size: u64 },
    Write { offset: u64, buf: &'a [u8] },
}

// A call carrying a `Msg` is answered with a `Reply`.
impl<'a> CallMsg<'a> for Msg<'a> {
    type Reply<'b> = Reply;
}

// Lets `Msg` be sent with `send` (as the first two tests do) rather than only
// through `call_through`.
impl<'a> SendMsg<'a> for Msg<'a> {}

/// Shorthand for the closures wrapped by `Delegate`. Each one is called with
/// the received message and a clone of the delegate's channel sender, and
/// returns a future `Fut`.
trait TestFunc: Send + Sync + Fn(MsgReceived>, Sender) -> Fut {
}

// Blanket impl: any closure with a matching signature is a `TestFunc`.
impl<
        S: 'static + Send,
        Fut: Send + Future,
        T: Send + Sync + Fn(MsgReceived>, Sender) -> Fut,
    > TestFunc for T
{
}

/// A `MsgCallback` that forwards each received message to a shared closure.
///
/// The closure also gets a clone of `sender`, so it can report what it saw
/// back to the test body.
struct Delegate {
    func: Arc>,
    sender: Sender,
}

// Written by hand rather than derived. Cloning only bumps the `Arc` refcount
// and clones the channel sender.
// NOTE(review): presumably hand-written to avoid the `Clone` bounds on type
// parameters that `#[derive(Clone)]` would add — confirm.
impl Clone for Delegate {
    fn clone(&self) -> Self {
        Self {
            func: self.func.clone(),
            sender: self.sender.clone(),
        }
    }
}

impl Delegate {
    /// Wraps `func`; on every call it receives a clone of `sender`.
    fn new>(sender: Sender, func: F) -> Self {
        Self {
            func: Arc::new(func),
            sender,
        }
    }
}

impl>> MsgCallback for Delegate {
    type Arg<'de> = Msg<'de>
    where
        Self: 'de;
    type CallFut<'s> = Fut
    where
        Fut: 's;
    /// Invokes the wrapped closure with `arg` and a fresh clone of the sender.
    fn call<'de>(&'de self, arg: MsgReceived>) -> Self::CallFut<'de> {
        (self.func)(arg, self.sender.clone())
    }
}

/// Generates fresh credentials for a test process.
///
/// The credentials hold a writecap issued by `NODE_CREDS` with a time one hour
/// from now (presumably its expiry).
fn proc_creds() -> impl Creds {
    let mut creds = ConcreteCreds::generate().unwrap();
    let writecap = NODE_CREDS
        .issue_writecap(
            creds.principal(),
            vec![],
            Epoch::now() + Duration::from_secs(3600),
        )
        .unwrap();
    creds.set_writecap(writecap);
    creds
}

/// Starts a receiver on the IPv6 loopback address with fresh `proc_creds()`,
/// handing it `callback`.
///
/// Returns the receiver together with a `BlockAddr` built from the loopback
/// address and the path bound by the credentials' writecap. Panics if the
/// credentials have no writecap or if the receiver cannot be created.
fn proc_rx(callback: F) -> (impl Receiver, Arc) {
    let ip_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
    let creds = proc_creds();
    let writecap = creds.writecap().unwrap();
    let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
    (receiver(ip_addr, Arc::new(creds), callback).unwrap(), addr)
}

/// Starts a receiver with `proc_rx(func)`, opens a transmitter to it via
/// `receiver.transmitter`, and returns `(transmitter, receiver)`.
///
/// Tests bind the returned receiver to a named variable (e.g. `_receiver`)
/// rather than `_`, presumably so it stays alive for the whole test.
async fn proc_tx_rx(func: F) -> (impl Transmitter, impl Receiver) {
    let (receiver, addr) = proc_rx(func);
    let sender = receiver.transmitter(addr).await.unwrap();
    (sender, receiver)
}

/// Starts a `Delegate`-based server over an in-memory 8-byte "file" that
/// initially holds `[0, 1, 2, 3, 4, 5, 6, 7]`.
///
/// - `Msg::Read { offset, size }` replies `Reply::ReadReply` with a copy of
///   `file[offset..offset + size]` and the offset.
/// - `Msg::Write { offset, buf }` copies `buf` into the file at `offset` and
///   replies `Reply::Success`.
/// - Any other message gets `Reply::Fail`.
///
/// Out-of-range requests panic inside the callback (index arithmetic and
/// slicing are unchecked).
async fn file_server() -> (impl Transmitter, impl Receiver) {
    // `Delegate` needs a sender, but this server never reports back to the
    // test body, so the receiving half is dropped immediately.
    let (sender, _) = mpsc::channel::<()>(1);
    let file = Arc::new(SyncMutex::new([0, 1, 2, 3, 4, 5, 6, 7]));
    proc_tx_rx(Delegate::new(
        sender,
        move |mut received: MsgReceived>, _| {
            // The `async move` block below does not capture `guard`. The lock
            // is therefore released when this closure returns, before the
            // reply future is polled.
            let mut guard = file.lock().unwrap();
            let reply_body = match received.body() {
                Msg::Read { offset, size } => {
                    let offset: usize = (*offset).try_into().unwrap();
                    let size: usize = (*size).try_into().unwrap();
                    let end: usize = offset + size;
                    let mut buf = Vec::with_capacity(end - offset);
                    buf.extend_from_slice(&guard[offset..end]);
                    Reply::ReadReply {
                        offset: offset as u64,
                        buf,
                    }
                }
                Msg::Write { offset, ref buf } => {
                    let offset: usize = (*offset).try_into().unwrap();
                    let end: usize = offset + buf.len();
                    (&mut guard[offset..end]).copy_from_slice(buf);
                    Reply::Success
                }
                _ => Reply::Fail,
            };
            let mut replier = received.take_replier().unwrap();
            async move { replier.reply(reply_body).await }
        },
    ))
    .await
}

/// Awaits `future`, panicking if it does not complete within 1000 ms.
async fn timeout(future: F) -> F::Output {
    tokio::time::timeout(Duration::from_millis(1000), future)
        .await
        .unwrap()
}

// Receives the next value from the channel receiver `$rx`. Panics if nothing
// arrives within `timeout`'s 1000 ms limit, or if `recv` yields `None`.
macro_rules! recv {
    ($rx:expr) => {
        timeout($rx.recv()).await.unwrap()
    };
}

/// Sends `Msg::Ping` with `send` and checks that the receiver's callback saw
/// a `Ping`.
#[tokio::test]
async fn message_received_is_message_sent() {
    // `sender` is shadowed below: first it is the report channel's sender,
    // then it is the transmitter.
    let (sender, mut passed) = mpsc::channel(1);
    let (sender, _receiver) = proc_tx_rx(Delegate::new(
        sender,
        |msg: MsgReceived>, sender: Sender| {
            // NOTE(review): equivalent to `matches!(msg.body(), Msg::Ping)`.
            let passed = if let Msg::Ping = msg.body() { true } else { false };
            // NOTE(review): `sender` is already an owned clone passed in by
            // `Delegate::call`, so this extra clone looks redundant.
            let sender = sender.clone();
            async move {
                sender.send(passed).await.unwrap();
                Ok(())
            }
        },
    ))
    .await;
    sender.send(Msg::Ping).await.unwrap();
    assert!(recv!(passed));
}

/// Checks that the path reported by `msg.from()` inside the callback equals
/// the receiver's own path.
///
/// The transmitter is obtained from `receiver.transmitter`, so presumably it
/// uses the receiver's own credentials.
#[tokio::test]
async fn message_received_from_path_is_correct() {
    let (sender, mut path) = mpsc::channel(1);
    let (sender, receiver) = proc_tx_rx(Delegate::new(
        sender,
        |msg: MsgReceived>, sender: Sender>| {
            let path = msg.from().clone();
            let sender = sender.clone();
            async move {
                sender.send(path).await.unwrap();
                Ok(())
            }
        },
    ))
    .await;
    sender.send(Msg::Ping).await.unwrap();
    assert_eq!(receiver.addr().path(), recv!(path).as_ref());
}

/// Reads 2 bytes at offset 2 from `file_server` and expects `[2, 3]`.
#[tokio::test]
async fn reply_to_read() {
    let (sender, _receiver) = file_server().await;
    let reply = sender
        .call_through::(Msg::Read { offset: 2, size: 2 })
        .await
        .unwrap();
    if let Reply::ReadReply { offset, buf } = reply {
        assert_eq!(2, offset);
        assert_eq!([2, 3].as_slice(), buf.as_slice());
    } else {
        panic!("reply was not the right type");
    };
}

/// Makes two calls over the same transmitter. It writes `[1, 1]` at offset 1,
/// then reads the same two bytes back.
#[tokio::test]
async fn call_twice() {
    let (sender, _receiver) = file_server().await;
    let reply = sender
        .call_through::(Msg::Write {
            offset: 1,
            buf: &[1, 1],
        })
        .await
        .unwrap();
    if let Reply::Success = reply {
        ()
    } else {
        panic!("reply was not the right type");
    };
    let reply = sender
        .call_through::(Msg::Read { offset: 1, size: 2 })
        .await
        .unwrap();
    if let Reply::ReadReply { offset, buf } = reply {
        assert_eq!(1, offset);
        assert_eq!([1, 1].as_slice(), buf.as_slice());
    } else {
        panic!("second reply was not the right type");
    }
}

/// Connects to a `file_server` with a standalone `transmitter` that uses its
/// own, separately generated `proc_creds()`, and checks that a `Write`
/// succeeds.
#[tokio::test]
async fn separate_transmitter() {
    let (_sender, receiver) = file_server().await;
    let creds = proc_creds();
    // The local binding shadows the free function `transmitter` only after
    // the call has been evaluated.
    let transmitter = transmitter(receiver.addr().clone(), Arc::new(creds))
        .await
        .unwrap();
    // Writes bytes 5..8, i.e. the last three bytes of the 8-byte file.
    let reply = transmitter
        .call_through::(Msg::Write {
            offset: 5,
            buf: &[7, 7, 7],
        })
        .await
        .unwrap();
    let matched = if let Reply::Success = reply { true } else { false };
    assert!(matched);
}

/// Read request for `read_server`: `size` bytes starting at `offset`.
#[derive(Serialize, Deserialize)]
struct Read {
    offset: usize,
    size: usize,
}

/// Reply to `Read`; it borrows the bytes for `'a` instead of owning them.
#[derive(Serialize, Deserialize)]
struct ReadReply<'a> {
    buf: &'a [u8],
}

// A call carrying a `Read` is answered with a `ReadReply` borrowing for `'b`.
impl<'a> CallMsg<'a> for Read {
    type Reply<'b> = ReadReply<'b>;
}

/// `DeserCallback` that resolves to whether the reply's bytes equal
/// `expected`.
#[derive(Clone)]
struct ReadChecker<'a> {
    expected: &'a [u8],
}

impl<'a> ReadChecker<'a> {
    fn new(expected: &'a [u8]) -> Self {
        Self { expected }
    }
}

impl<'a> DeserCallback for ReadChecker<'a> {
    type Arg<'de> = ReadReply<'de>
    where
        Self: 'de;
    type Return = bool;
    type CallFut<'s> = Ready
    where
        Self: 's;
    // Compares synchronously and returns an already-completed future.
    fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
        ready(self.expected == arg.buf)
    }
}

/// Like `TestFunc`, but the closure receives only the message and no report
/// channel.
trait ActionFn: Send + Sync + Fn(MsgReceived) -> Fut {}

// Blanket impl: any closure with a matching signature is an `ActionFn`.
impl) -> Fut> ActionFn for T {
}

/// A `MsgCallback` that forwards each received message to a shared closure.
struct Action {
    func: Arc>,
}

impl Action {
    /// Wraps `func` in an `Arc` so clones of the `Action` share it.
    fn new>(func: F) -> Self {
        Self {
            func: Arc::new(func),
        }
    }
}

// Written by hand rather than derived. Cloning only bumps the `Arc` refcount.
impl Clone for Action {
    fn clone(&self) -> Self {
        Self {
            func: self.func.clone(),
        }
    }
}

impl CallMsg<'a>, Fut: Send + Future>> MsgCallback for Action {
    type Arg<'de> = Arg
    where
        Arg: 'de,
        Fut: 'de;
    type CallFut<'de> = Fut
    where
        Arg: 'de,
        Fut: 'de;
    /// Invokes the wrapped closure with the received message.
    fn call<'de>(&'de self, arg: MsgReceived>) -> Self::CallFut<'de> {
        (self.func)(arg)
    }
}

/// Starts an `Action`-based server over the fixed array
/// `[0, 1, 2, 3, 4, 5, 6, 7]`.
///
/// Each `Read` gets a `ReadReply` that borrows `file[offset..offset + size]`.
/// Out-of-range requests panic (slicing is unchecked).
async fn read_server() -> (impl Transmitter, impl Receiver) {
    let file = [0, 1, 2, 3, 4, 5, 6, 7];
    proc_tx_rx(Action::new(move |mut msg: MsgReceived| async move {
        let body = msg.body();
        let start = body.offset;
        let end = start + body.size;
        let buf = &file[start..end];
        let mut replier = msg.take_replier().unwrap();
        replier.reply(ReadReply { buf }).await
    }))
    .await
}

/// Makes two sequential calls whose replies borrow their bytes. Each reply is
/// checked by a `ReadChecker`.
#[tokio::test]
async fn call_with_lifetime() {
    let (sender, _receiver) = read_server().await;
    let correct_one = sender
        .call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]))
        .await
        .unwrap();
    let correct_two = sender
        .call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]))
        .await
        .unwrap();
    assert!(correct_one);
    assert!(correct_two);
}

/// Same requests as `call_with_lifetime`, but both calls are in flight at
/// once and are driven together with `futures::join!`.
#[tokio::test]
async fn call_concurrently() {
    let (sender, _receiver) = read_server().await;
    let call_one = sender.call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]));
    let call_two = sender.call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]));
    let (result_one, result_two) = join!(call_one, call_two);
    assert!(result_one.unwrap());
    assert!(result_two.unwrap());
}