|
- // SPDX-License-Identifier: AGPL-3.0-or-later
- #![feature(type_alias_impl_trait)]
- use btmsg::*;
- use btlib::{
- crypto::{ConcreteCreds, Creds, CredsPriv},
- BlockPath, Epoch, Principal, Principaled,
- };
- use core::future::{ready, Future, Ready};
- use ctor::ctor;
- use futures::join;
- use lazy_static::lazy_static;
- use serde::{Deserialize, Serialize};
- use std::{
- io::Write,
- net::{IpAddr, Ipv6Addr},
- sync::{Arc, Mutex as SyncMutex},
- time::Duration,
- };
- use tokio::sync::mpsc::{self, Sender};
- #[ctor]
- fn setup_logging() {
- use env_logger::Env;
- let env = Env::default().default_filter_or("ERROR");
- env_logger::builder()
- .format(|fmt, record| {
- writeln!(
- fmt,
- "[{} {} {}:{}] {}",
- chrono::Utc::now().to_rfc3339(),
- record.level(),
- record.file().unwrap_or("(unknown)"),
- record.line().unwrap_or(u32::MAX),
- record.args(),
- )
- })
- .parse_env(env)
- .init();
- }
- lazy_static! {
- static ref ROOT_CREDS: ConcreteCreds = ConcreteCreds::generate().unwrap();
- static ref NODE_CREDS: ConcreteCreds = {
- let mut creds = ConcreteCreds::generate().unwrap();
- let root_creds = &ROOT_CREDS;
- let writecap = root_creds
- .issue_writecap(
- creds.principal(),
- vec![],
- Epoch::now() + Duration::from_secs(3600),
- )
- .unwrap();
- creds.set_writecap(writecap);
- creds
- };
- static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
- }
- #[derive(Debug, Serialize, Deserialize)]
- enum Reply {
- Success,
- Fail,
- ReadReply { offset: u64, buf: Vec<u8> },
- }
- #[derive(Serialize, Deserialize)]
- enum Msg<'a> {
- Ping,
- Success,
- Fail,
- Read { offset: u64, size: u64 },
- Write { offset: u64, buf: &'a [u8] },
- }
- impl<'a> CallMsg<'a> for Msg<'a> {
- type Reply<'b> = Reply;
- }
- impl<'a> SendMsg<'a> for Msg<'a> {}
- trait TestFunc<S: 'static + Send, Fut: Send + Future>:
- Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut
- {
- }
- impl<
- S: 'static + Send,
- Fut: Send + Future,
- T: Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut,
- > TestFunc<S, Fut> for T
- {
- }
- struct Delegate<S, Fut> {
- func: Arc<dyn TestFunc<S, Fut>>,
- sender: Sender<S>,
- }
- impl<S, Fut> Clone for Delegate<S, Fut> {
- fn clone(&self) -> Self {
- Self {
- func: self.func.clone(),
- sender: self.sender.clone(),
- }
- }
- }
- impl<S: 'static + Send, Fut: Send + Future> Delegate<S, Fut> {
- fn new<F: 'static + TestFunc<S, Fut>>(sender: Sender<S>, func: F) -> Self {
- Self {
- func: Arc::new(func),
- sender,
- }
- }
- }
- impl<S: 'static + Send, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
- for Delegate<S, Fut>
- {
- type Arg<'de> = Msg<'de> where Self: 'de;
- type CallFut<'s> = Fut where Fut: 's;
- fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
- (self.func)(arg, self.sender.clone())
- }
- }
- fn proc_creds() -> impl Creds {
- let mut creds = ConcreteCreds::generate().unwrap();
- let writecap = NODE_CREDS
- .issue_writecap(
- creds.principal(),
- vec![],
- Epoch::now() + Duration::from_secs(3600),
- )
- .unwrap();
- creds.set_writecap(writecap);
- creds
- }
- fn proc_rx<F: 'static + MsgCallback>(callback: F) -> (impl Receiver, Arc<BlockAddr>) {
- let ip_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
- let creds = proc_creds();
- let writecap = creds.writecap().unwrap();
- let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
- (receiver(ip_addr, Arc::new(creds), callback).unwrap(), addr)
- }
- async fn proc_tx_rx<F: 'static + MsgCallback>(func: F) -> (impl Transmitter, impl Receiver) {
- let (receiver, addr) = proc_rx(func);
- let sender = receiver.transmitter(addr).await.unwrap();
- (sender, receiver)
- }
- async fn file_server() -> (impl Transmitter, impl Receiver) {
- let (sender, _) = mpsc::channel::<()>(1);
- let file = Arc::new(SyncMutex::new([0, 1, 2, 3, 4, 5, 6, 7]));
- proc_tx_rx(Delegate::new(
- sender,
- move |mut received: MsgReceived<Msg<'_>>, _| {
- let mut guard = file.lock().unwrap();
- let reply_body = match received.body() {
- Msg::Read { offset, size } => {
- let offset: usize = (*offset).try_into().unwrap();
- let size: usize = (*size).try_into().unwrap();
- let end: usize = offset + size;
- let mut buf = Vec::with_capacity(end - offset);
- buf.extend_from_slice(&guard[offset..end]);
- Reply::ReadReply {
- offset: offset as u64,
- buf,
- }
- }
- Msg::Write { offset, ref buf } => {
- let offset: usize = (*offset).try_into().unwrap();
- let end: usize = offset + buf.len();
- (&mut guard[offset..end]).copy_from_slice(buf);
- Reply::Success
- }
- _ => Reply::Fail,
- };
- let mut replier = received.take_replier().unwrap();
- async move { replier.reply(reply_body).await }
- },
- ))
- .await
- }
- async fn timeout<F: Future>(future: F) -> F::Output {
- tokio::time::timeout(Duration::from_millis(1000), future)
- .await
- .unwrap()
- }
/// Receives the next value from the channel receiver `$rx`, bounded by `timeout`.
///
/// Panics if the time limit elapses or if `recv` yields `None`. Must be used in an async context
/// because the expansion contains an `.await`.
macro_rules! recv {
    ($rx:expr) => {{
        let received = timeout($rx.recv()).await;
        received.unwrap()
    }};
}
- #[tokio::test]
- async fn message_received_is_message_sent() {
- let (sender, mut passed) = mpsc::channel(1);
- let (sender, _receiver) = proc_tx_rx(Delegate::new(
- sender,
- |msg: MsgReceived<Msg<'_>>, sender: Sender<bool>| {
- let passed = if let Msg::Ping = msg.body() {
- true
- } else {
- false
- };
- let sender = sender.clone();
- async move {
- sender.send(passed).await.unwrap();
- Ok(())
- }
- },
- ))
- .await;
- sender.send(Msg::Ping).await.unwrap();
- assert!(recv!(passed));
- }
- #[tokio::test]
- async fn message_received_from_path_is_correct() {
- let (sender, mut path) = mpsc::channel(1);
- let (sender, receiver) = proc_tx_rx(Delegate::new(
- sender,
- |msg: MsgReceived<Msg<'_>>, sender: Sender<Arc<BlockPath>>| {
- let path = msg.from().clone();
- let sender = sender.clone();
- async move {
- sender.send(path).await.unwrap();
- Ok(())
- }
- },
- ))
- .await;
- sender.send(Msg::Ping).await.unwrap();
- assert_eq!(receiver.addr().path(), recv!(path).as_ref());
- }
- #[tokio::test]
- async fn reply_to_read() {
- let (sender, _receiver) = file_server().await;
- let reply = sender
- .call_through::<Msg>(Msg::Read { offset: 2, size: 2 })
- .await
- .unwrap();
- if let Reply::ReadReply { offset, buf } = reply {
- assert_eq!(2, offset);
- assert_eq!([2, 3].as_slice(), buf.as_slice());
- } else {
- panic!("reply was not the right type");
- };
- }
- #[tokio::test]
- async fn call_twice() {
- let (sender, _receiver) = file_server().await;
- let reply = sender
- .call_through::<Msg>(Msg::Write {
- offset: 1,
- buf: &[1, 1],
- })
- .await
- .unwrap();
- if let Reply::Success = reply {
- ()
- } else {
- panic!("reply was not the right type");
- };
- let reply = sender
- .call_through::<Msg>(Msg::Read { offset: 1, size: 2 })
- .await
- .unwrap();
- if let Reply::ReadReply { offset, buf } = reply {
- assert_eq!(1, offset);
- assert_eq!([1, 1].as_slice(), buf.as_slice());
- } else {
- panic!("second reply was not the right type");
- }
- }
- #[tokio::test]
- async fn separate_transmitter() {
- let (_sender, receiver) = file_server().await;
- let creds = proc_creds();
- let transmitter = transmitter(receiver.addr().clone(), Arc::new(creds))
- .await
- .unwrap();
- let reply = transmitter
- .call_through::<Msg>(Msg::Write {
- offset: 5,
- buf: &[7, 7, 7],
- })
- .await
- .unwrap();
- let matched = if let Reply::Success = reply {
- true
- } else {
- false
- };
- assert!(matched);
- }
- #[derive(Serialize, Deserialize)]
- struct Read {
- offset: usize,
- size: usize,
- }
- #[derive(Serialize, Deserialize)]
- struct ReadReply<'a> {
- buf: &'a [u8],
- }
- impl<'a> CallMsg<'a> for Read {
- type Reply<'b> = ReadReply<'b>;
- }
- #[derive(Clone)]
- struct ReadChecker<'a> {
- expected: &'a [u8],
- }
- impl<'a> ReadChecker<'a> {
- fn new(expected: &'a [u8]) -> Self {
- Self { expected }
- }
- }
- impl<'a> DeserCallback for ReadChecker<'a> {
- type Arg<'de> = ReadReply<'de> where Self: 'de;
- type Return = bool;
- type CallFut<'s> = Ready<bool> where Self: 's;
- fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
- ready(self.expected == arg.buf)
- }
- }
- trait ActionFn<Arg, Fut: Send + Future>: Send + Sync + Fn(MsgReceived<Arg>) -> Fut {}
- impl<Arg, Fut: Send + Future, T: Send + Sync + Fn(MsgReceived<Arg>) -> Fut> ActionFn<Arg, Fut>
- for T
- {
- }
- struct Action<Arg, Fut> {
- func: Arc<dyn ActionFn<Arg, Fut>>,
- }
- impl<Arg, Fut: Send + Future> Action<Arg, Fut> {
- fn new<F: 'static + ActionFn<Arg, Fut>>(func: F) -> Self {
- Self {
- func: Arc::new(func),
- }
- }
- }
- impl<Arg, Fut> Clone for Action<Arg, Fut> {
- fn clone(&self) -> Self {
- Self {
- func: self.func.clone(),
- }
- }
- }
- impl<Arg: for<'a> CallMsg<'a>, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
- for Action<Arg, Fut>
- {
- type Arg<'de> = Arg where Arg: 'de, Fut: 'de;
- type CallFut<'de> = Fut where Arg: 'de, Fut: 'de;
- fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
- (self.func)(arg)
- }
- }
- async fn read_server() -> (impl Transmitter, impl Receiver) {
- let file = [0, 1, 2, 3, 4, 5, 6, 7];
- proc_tx_rx(Action::new(move |mut msg: MsgReceived<Read>| async move {
- let body = msg.body();
- let start = body.offset;
- let end = start + body.size;
- let buf = &file[start..end];
- let mut replier = msg.take_replier().unwrap();
- replier.reply(ReadReply { buf }).await
- }))
- .await
- }
- #[tokio::test]
- async fn call_with_lifetime() {
- let (sender, _receiver) = read_server().await;
- let correct_one = sender
- .call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]))
- .await
- .unwrap();
- let correct_two = sender
- .call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]))
- .await
- .unwrap();
- assert!(correct_one);
- assert!(correct_two);
- }
- #[tokio::test]
- async fn call_concurrently() {
- let (sender, _receiver) = read_server().await;
- let call_one = sender.call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]));
- let call_two = sender.call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]));
- let (result_one, result_two) = join!(call_one, call_two);
- assert!(result_one.unwrap());
- assert!(result_two.unwrap());
- }
|