//! Code which enables sending messages between processes in the blocktree system. use btlib::{ bterr, crypto::{rand_array, ConcreteCreds, CredsPriv, CredsPub}, error::BoxInIoErr, BlockPath, Result, }; use btserde::{read_from, write_to}; use bytes::{BufMut, BytesMut}; use core::{ future::Future, marker::Send, pin::Pin, task::{Context, Poll}, }; use futures::{ future::{ready, Ready}, sink::{Close, Send as SendFut, Sink}, stream::Stream, SinkExt, StreamExt, }; use lazy_static::lazy_static; use log::error; use quinn::{ClientConfig, Endpoint, SendStream, ServerConfig}; use rustls::{ Certificate, ConfigBuilder, ConfigSide, PrivateKey, WantsCipherSuites, WantsVerifier, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, io, marker::PhantomData, net::{IpAddr, Shutdown, SocketAddr}, path::PathBuf, sync::Arc, }; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, net::UnixDatagram, sync::{ broadcast::{self, error::TryRecvError}, mpsc, }, }; use tokio_stream::wrappers::ReceiverStream; use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite}; use zerocopy::FromBytes; pub use private::*; mod private { use super::*; /// Returns a [Router] which can be used to make a [Receiver] for the given path and /// [Sender] instances for any path. pub fn router(addr: Arc, creds: &ConcreteCreds) -> Result { QuicRouter::new(addr, creds) } lazy_static! { /// The default directory in which to place blocktree sockets. static ref SOCK_DIR: PathBuf = { let mut path: PathBuf = std::env::var("HOME").unwrap().into(); path.push(".btmsg"); path }; } /// Appends the given Blocktree path to the path of the given directory. 
#[allow(dead_code)] fn socket_path(fs_path: &mut PathBuf, addr: &BlockAddr) { fs_path.push(addr.path.to_string()); } fn common_config( builder: ConfigBuilder, ) -> Result> { builder .with_cipher_suites(&[rustls::cipher_suite::TLS13_AES_128_GCM_SHA256]) .with_kx_groups(&[&rustls::kx_group::SECP256R1]) .with_protocol_versions(&[&rustls::version::TLS13]) .map_err(|err| err.into()) } fn server_config(creds: &ConcreteCreds) -> Result { let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?; let chain = writecap.to_cert_chain(creds.public_sign())?; let mut cert_chain = Vec::with_capacity(chain.len()); for cert in chain { cert_chain.push(Certificate(cert)) } let key = PrivateKey(creds.private_sign().to_der()?); let server_config = common_config(rustls::ServerConfig::builder())? .with_no_client_auth() .with_single_cert(cert_chain, key)?; Ok(ServerConfig::with_crypto(Arc::new(server_config))) } fn client_config() -> Result { let client_config = common_config(rustls::ClientConfig::builder())? .with_custom_certificate_verifier(CertVerifier::new()) .with_no_client_auth(); Ok(ClientConfig::new(Arc::new(client_config))) } /// An identifier for a block. Persistent blocks (files, directories, and servers) are /// identified by the `Inode` variant and transient blocks (processes) are identified by the /// PID variant. 
#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Debug)] pub enum BlockNum { Inode(u64), Pid(u64), } impl BlockNum { pub fn value(&self) -> u64 { match self { BlockNum::Inode(value) => *value, BlockNum::Pid(value) => *value, } } } impl From for u64 { fn from(value: BlockNum) -> Self { value.value() } } #[derive(PartialEq, Eq, Hash, Clone, Debug)] pub struct BlockAddr { ip_addr: IpAddr, path: Arc, } impl BlockAddr { pub fn new(ip_addr: IpAddr, path: Arc) -> Self { Self { ip_addr, path } } fn port(&self) -> u16 { let mut hasher = DefaultHasher::new(); self.path.hash(&mut hasher); let hash = hasher.finish(); // We compute a port in the dynamic range [49152, 65535] as defined by RFC 6335. const NUM_RES_PORTS: u16 = 49153; const PORTS_AVAIL: u64 = (u16::MAX - NUM_RES_PORTS) as u64; NUM_RES_PORTS + (hash % PORTS_AVAIL) as u16 } pub fn socket_addr(&self) -> SocketAddr { let port = self.port(); SocketAddr::new(self.ip_addr, port) } } /// Generates and returns a new message ID. fn rand_msg_id() -> Result { const LEN: usize = std::mem::size_of::(); let bytes = rand_array::()?; let option = u128::read_from(bytes.as_slice()); // Safety: because LEN == size_of::(), read_from should have returned Some. Ok(option.unwrap()) } #[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)] pub struct Msg { pub id: u128, pub body: T, } impl Msg { pub fn new(id: u128, body: T) -> Self { Self { id, body } } pub fn with_rand_id(body: T) -> Result { Ok(Self { id: rand_msg_id()?, body, }) } } /// A type which can be used to send messages. /// Once the "Permit impl Trait in type aliases" https://github.com/rust-lang/rust/issues/63063 /// feature lands the future types in this trait should be rewritten to use it. 
pub trait Sender { type SendFut<'a, T>: 'a + Future> + Send where Self: 'a, T: 'a + Serialize + Send; fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: Msg) -> Self::SendFut<'a, T>; type FinishFut: Future> + Send; fn finish(self) -> Self::FinishFut; fn addr(&self) -> &BlockAddr; fn send_with_rand_id<'a, T: 'a + Serialize + Send>( &'a mut self, body: T, ) -> Self::SendFut<'a, T> { let msg = Msg::with_rand_id(body).unwrap(); self.send(msg) } } /// A type which can be used to receive messages. pub trait Receiver: Stream>> { fn addr(&self) -> &BlockAddr; } pub trait Router { type Sender: Sender + Send; type SenderFut<'a>: 'a + Future> + Send where Self: 'a; fn sender(&self, addr: Arc) -> Self::SenderFut<'_>; type Receiver: Receiver + Send + Unpin; type ReceiverFut<'a, T>: 'a + Future>> + Send where T: 'static + DeserializeOwned + Send + Unpin, Self: 'a; fn receiver( &self, ) -> Self::ReceiverFut<'_, T>; } /// Encodes and decodes messages using [btserde]. struct MsgEncoder; impl MsgEncoder { fn new() -> Self { Self } } impl Encoder> for MsgEncoder { type Error = btlib::Error; fn encode(&mut self, item: Msg, dst: &mut BytesMut) -> Result<()> { const U64_LEN: usize = std::mem::size_of::(); let payload = dst.split_off(U64_LEN); let mut writer = payload.writer(); write_to(&item, &mut writer)?; let payload = writer.into_inner(); let payload_len = payload.len() as u64; let mut writer = dst.writer(); write_to(&payload_len, &mut writer)?; let dst = writer.into_inner(); dst.unsplit(payload); Ok(()) } } struct MsgDecoder(PhantomData); impl MsgDecoder { fn new() -> Self { Self(PhantomData) } } impl Decoder for MsgDecoder { type Item = Msg; type Error = btlib::Error; fn decode(&mut self, src: &mut BytesMut) -> Result> { let mut slice: &[u8] = src.as_ref(); let payload_len: u64 = match read_from(&mut slice) { Ok(payload_len) => payload_len, Err(err) => { if let btserde::Error::Eof = err { return Ok(None); } return Err(err.into()); } }; let payload_len: usize = 
payload_len.try_into().box_err()?; if slice.len() < payload_len { src.reserve(payload_len - slice.len()); return Ok(None); } let msg = read_from(&mut slice)?; // Consume all the bytes that have been read out of the buffer. let _ = src.split_to(std::mem::size_of::() + payload_len); Ok(Some(msg)) } } /// Wraps a [UnixDatagram] and implements [AsyncRead] and [AsyncWrite] for it. Read operations /// are translated to calls to `recv_from` and write operations are translated to `send`. Note /// that this means that writes will fail unless the wrapped socket is connected to a peer. struct DatagramAdapter { socket: UnixDatagram, } impl DatagramAdapter { #[allow(dead_code)] fn new(socket: UnixDatagram) -> Self { Self { socket } } fn get_ref(&self) -> &UnixDatagram { &self.socket } fn get_mut(&mut self) -> &mut UnixDatagram { &mut self.socket } } impl AsRef for DatagramAdapter { fn as_ref(&self) -> &UnixDatagram { self.get_ref() } } impl AsMut for DatagramAdapter { fn as_mut(&mut self) -> &mut UnixDatagram { self.get_mut() } } impl AsyncRead for DatagramAdapter { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { self.socket.poll_recv(cx, buf) } } impl AsyncWrite for DatagramAdapter { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { self.socket.poll_send(cx, buf) } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(self.socket.shutdown(Shutdown::Write)) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } /// An implementation of [Receiver] which uses a Unix datagram socket for receiving messages. 
struct UnixReceiver<T> {
    addr: BlockAddr,
    socket: FramedRead<DatagramAdapter, MsgDecoder<T>>,
}

impl<T: DeserializeOwned> UnixReceiver<T> {
    /// Binds a datagram socket at `fs_path` extended with the path in `addr`.
    ///
    /// # Errors
    /// Returns an error if the socket can't be bound.
    // Only the Unix datagram transport uses this, and `router` doesn't construct it.
    #[allow(dead_code)]
    fn new(mut fs_path: PathBuf, addr: BlockAddr) -> Result<Self> {
        socket_path(&mut fs_path, &addr);
        let socket = DatagramAdapter::new(UnixDatagram::bind(fs_path)?);
        let socket = FramedRead::new(socket, MsgDecoder(PhantomData));
        Ok(Self { addr, socket })
    }
}

impl<T: DeserializeOwned> Stream for UnixReceiver<T> {
    type Item = Result<Msg<T>>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.socket.poll_next_unpin(cx)
    }
}

impl<T: DeserializeOwned> Receiver<T> for UnixReceiver<T> {
    fn addr(&self) -> &BlockAddr {
        &self.addr
    }
}

/// An implementation of [Sender] which uses a Unix datagram socket to send messages.
struct UnixSender {
    addr: BlockAddr,
    socket: FramedWrite<DatagramAdapter, MsgEncoder>,
}

impl UnixSender {
    /// Creates an unbound datagram socket connected to the socket at `fs_path` extended with
    /// the path in `addr`.
    ///
    /// # Errors
    /// Returns an error if the socket can't be created or connected.
    // Only the Unix datagram transport uses this, and `router` doesn't construct it.
    #[allow(dead_code)]
    fn new(mut fs_path: PathBuf, addr: BlockAddr) -> Result<Self> {
        let socket = UnixDatagram::unbound()?;
        socket_path(&mut fs_path, &addr);
        socket.connect(fs_path)?;
        let socket = FramedWrite::new(DatagramAdapter::new(socket), MsgEncoder);
        Ok(Self { addr, socket })
    }
}

impl<T: Serialize> Sink<Msg<T>> for UnixSender {
    type Error = btlib::Error;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        <FramedWrite<DatagramAdapter, MsgEncoder> as futures::SinkExt<Msg<T>>>::poll_ready_unpin(
            &mut self.socket,
            cx,
        )
    }

    fn start_send(mut self: Pin<&mut Self>, item: Msg<T>) -> Result<()> {
        self.socket.start_send_unpin(item)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        <FramedWrite<DatagramAdapter, MsgEncoder> as futures::SinkExt<Msg<T>>>::poll_flush_unpin(
            &mut self.socket,
            cx,
        )
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        <FramedWrite<DatagramAdapter, MsgEncoder> as futures::SinkExt<Msg<T>>>::poll_close_unpin(
            &mut self.socket,
            cx,
        )
    }
}

impl Sender for UnixSender {
    fn addr(&self) -> &BlockAddr {
        &self.addr
    }

    type SendFut<'a, T> = SendFut<'a, FramedWrite<DatagramAdapter, MsgEncoder>, Msg<T>>
    where
        T: 'a + Serialize + Send;

    fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: Msg<T>) -> Self::SendFut<'a, T> {
        self.socket.send(msg)
    }

    type FinishFut = Pin<Box<dyn Future<Output = Result<()>> + Send>>;

    fn finish(mut self) -> Self::FinishFut {
        Box::pin(async move {
            // The item type only selects a `Sink` impl; closing doesn't depend on it.
            let fut: Close<'_, _, Msg<()>> = self.socket.close();
            fut.await
        })
    }
}

/// Causes the current function to return if the given `rx` has received a stop signal, or if
/// its channel has been closed.
macro_rules! check_stop {
    ($rx:expr) => {
        match $rx.try_recv() {
            Ok(_) => return,
            Err(err) => {
                if let TryRecvError::Closed = err {
                    return;
                }
            }
        }
    };
}

/// A [Router] which sends and receives messages over QUIC.
struct QuicRouter {
    recv_addr: Arc<BlockAddr>,
    endpoint: Endpoint,
}

impl QuicRouter {
    /// Creates a QUIC server endpoint bound to the socket address of `recv_addr`.
    fn new(recv_addr: Arc<BlockAddr>, creds: &ConcreteCreds) -> Result<Self> {
        let socket_addr = recv_addr.socket_addr();
        let endpoint = Endpoint::server(server_config(creds)?, socket_addr)?;
        Ok(Self {
            endpoint,
            recv_addr,
        })
    }
}

impl Router for QuicRouter {
    type Receiver<T: 'static + DeserializeOwned + Send + Unpin> = QuicReceiver<T>;
    type ReceiverFut<'a, T: 'static + DeserializeOwned + Send + Unpin> =
        Ready<Result<QuicReceiver<T>>>;

    fn receiver<T: 'static + DeserializeOwned + Send + Unpin>(
        &self,
    ) -> Self::ReceiverFut<'_, T> {
        ready(QuicReceiver::new(
            self.endpoint.clone(),
            self.recv_addr.clone(),
        ))
    }

    type Sender = QuicSender;
    type SenderFut<'a> = Pin<Box<dyn 'a + Future<Output = Result<QuicSender>> + Send>>;

    fn sender(&self, addr: Arc<BlockAddr>) -> Self::SenderFut<'_> {
        Box::pin(async { QuicSender::from_endpoint(self.endpoint.clone(), addr).await })
    }
}

/// A [Receiver] which accepts QUIC connections and yields the messages read from them.
struct QuicReceiver<T> {
    recv_addr: Arc<BlockAddr>,
    stop_tx: broadcast::Sender<()>,
    stream: ReceiverStream<Result<Msg<T>>>,
}

impl<T: 'static + DeserializeOwned + Send + Unpin> QuicReceiver<T> {
    /// The size of the buffer to store received messages in.
    const MSG_BUF_SZ: usize = 64;

    /// Spawns a task which accepts connections on `endpoint`. For each connection it decodes
    /// the first unidirectional stream and forwards the messages to this receiver. Must be
    /// called from within a Tokio runtime, since it uses `tokio::spawn`.
    ///
    /// NOTE(review): the stop signal sent on drop is only checked before waiting for the next
    /// connection or message. The spawned tasks, and the endpoint clone they hold, live until
    /// one more connection or message arrives.
    fn new(endpoint: Endpoint, recv_addr: Arc<BlockAddr>) -> Result<Self> {
        let (stop_tx, mut stop_rx) = broadcast::channel(1);
        let (msg_tx, msg_rx) = mpsc::channel(Self::MSG_BUF_SZ);
        tokio::spawn(async move {
            loop {
                check_stop!(stop_rx);
                let connecting = match endpoint.accept().await {
                    Some(connecting) => connecting,
                    None => break,
                };
                let connection = match connecting.await {
                    Ok(connection) => connection,
                    Err(err) => {
                        error!("error accepting QUIC connection: {err}");
                        continue;
                    }
                };
                let conn_msg_tx = msg_tx.clone();
                let mut conn_stop_rx = stop_rx.resubscribe();
                tokio::spawn(async move {
                    let recv_stream = match connection.accept_uni().await {
                        Ok(recv_stream) => recv_stream,
                        Err(err) => {
                            error!("error accepting receive stream: {err}");
                            return;
                        }
                    };
                    let mut msg_stream = FramedRead::new(recv_stream, MsgDecoder::new());
                    loop {
                        check_stop!(conn_stop_rx);
                        let result = match msg_stream.next().await {
                            Some(result) => result,
                            None => return,
                        };
                        if let Err(err) = conn_msg_tx.send(result).await {
                            // `send` only fails once the receiving half has been dropped, so
                            // no later message could be delivered either.
                            error!("error sending message to mpsc queue: {err}");
                            return;
                        }
                    }
                });
            }
        });
        Ok(Self {
            recv_addr,
            stop_tx,
            stream: ReceiverStream::new(msg_rx),
        })
    }
}

impl<T> Drop for QuicReceiver<T> {
    fn drop(&mut self) {
        // This result will be a failure if the tasks have already returned, which is not a
        // problem.
        let _ = self.stop_tx.send(());
    }
}

impl<T: 'static + DeserializeOwned + Send + Unpin> Stream for QuicReceiver<T> {
    type Item = Result<Msg<T>>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.stream.poll_next_unpin(cx)
    }
}

impl<T: 'static + DeserializeOwned + Send + Unpin> Receiver<T> for QuicReceiver<T> {
    fn addr(&self) -> &BlockAddr {
        &self.recv_addr
    }
}

/// Verifies the certificates presented by QUIC servers.
///
/// SECURITY: verification is not implemented yet. Every certificate is accepted, so
/// connections made by `QuicSender` are not authenticated.
struct CertVerifier;

impl CertVerifier {
    fn new() -> Arc<Self> {
        Arc::new(Self)
    }
}

impl rustls::client::ServerCertVerifier for CertVerifier {
    fn verify_server_cert(
        &self,
        _end_entity: &Certificate,
        _intermediates: &[Certificate],
        _server_name: &rustls::ServerName,
        _scts: &mut dyn Iterator<Item = &[u8]>,
        _ocsp_response: &[u8],
        _now: std::time::SystemTime,
    ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
        // TODO: Implement certificate verification.
        Ok(rustls::client::ServerCertVerified::assertion())
    }
}

/// A [Sender] which writes messages to a unidirectional QUIC stream.
struct QuicSender {
    addr: Arc<BlockAddr>,
    sink: FramedWrite<SendStream, MsgEncoder>,
}

impl QuicSender {
    /// Connects to the block at `addr` using `endpoint` and opens a unidirectional stream to
    /// it.
    async fn from_endpoint(endpoint: Endpoint, addr: Arc<BlockAddr>) -> Result<Self> {
        let socket_addr = addr.socket_addr();
        // NOTE(review): the server name is hard-coded. `CertVerifier` ignores it today, but it
        // will matter once certificate verification is implemented.
        let connecting = endpoint.connect_with(client_config()?, socket_addr, "localhost")?;
        let connection = connecting.await?;
        let send_stream = connection.open_uni().await?;
        let sink = FramedWrite::new(send_stream, MsgEncoder::new());
        Ok(Self { addr, sink })
    }
}

impl Sender for QuicSender {
    fn addr(&self) -> &BlockAddr {
        &self.addr
    }

    type SendFut<'a, T> = SendFut<'a, FramedWrite<SendStream, MsgEncoder>, Msg<T>>
    where
        T: 'a + Serialize + Send;

    fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: Msg<T>) -> Self::SendFut<'a, T> {
        self.sink.send(msg)
    }

    type FinishFut = Pin<Box<dyn Future<Output = Result<()>> + Send>>;

    /// Flushes any frames still buffered in the sink, then finishes the QUIC stream.
    fn finish(mut self) -> Self::FinishFut {
        Box::pin(async move {
            // Finishing the stream directly would drop frames which were encoded but not yet
            // written. The item type only selects a `Sink` impl; flushing doesn't depend on it.
            SinkExt::<Msg<()>>::flush(&mut self.sink).await?;
            let stream: &mut SendStream = self.sink.get_mut();
            stream.finish().await.map_err(|err| bterr!(err))
        })
    }
}

/// An identity function which lets the caller specify the output type of a future.
///
/// This works around a type-inference failure for types with higher-rank lifetimes. Once
/// <https://github.com/rust-lang/rust/issues/102211> is resolved it can be removed.
pub fn assert_send<'a, T>(
    fut: impl 'a + Future<Output = T> + Send,
) -> impl 'a + Future<Output = T> + Send {
    fut
}
}

#[cfg(test)]
mod tests {
    use super::*;

    use btlib::{crypto::Creds, Epoch, Principal, Principaled};
    use ctor::ctor;
    use std::{
        net::Ipv6Addr,
        sync::atomic::{AtomicU64, Ordering},
        time::Duration,
    };

    #[ctor]
    fn setup_logging() {
        env_logger::init();
    }

    lazy_static! {
        static ref ROOT_CREDS: ConcreteCreds = ConcreteCreds::generate().unwrap();
        static ref NODE_CREDS: ConcreteCreds = {
            let mut creds = ConcreteCreds::generate().unwrap();
            let root_creds = &ROOT_CREDS;
            let writecap = root_creds
                .issue_writecap(
                    creds.principal(),
                    vec![],
                    Epoch::now() + Duration::from_secs(3600),
                )
                .unwrap();
            creds.set_writecap(writecap);
            creds
        };
        static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
    }

    /// Builds a localhost address for the path under the root principal made of `components`.
    fn block_addr<'a, I: Iterator<Item = &'a S>, S: 'a + ToString>(components: I) -> BlockAddr {
        let components = components.map(|e| e.to_string()).collect();
        let path = BlockPath::new(ROOT_CREDS.principal(), components);
        let path = Arc::new(path);
        BlockAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), path)
    }

    #[derive(Serialize, Deserialize)]
    enum MsgError {
        Unknown,
    }

    #[derive(Deserialize)]
    enum BodyOwned {
        Ping,
        Success,
        Fail(MsgError),
        Read { offset: u64, size: u64 },
        Write { offset: u64, buf: Vec<u8> },
    }

    #[derive(Serialize)]
    enum BodyRef<'a> {
        Ping,
        Success,
        Fail(MsgError),
        Read { offset: u64, size: u64 },
        Write { offset: u64, buf: &'a [u8] },
    }

    /// Gives each test its own instance number so that the addresses, and hence the ports,
    /// used by concurrently running tests differ.
    struct TestCase {
        instance_num: u64,
    }

    impl TestCase {
        fn new() -> TestCase {
            static INSTANCE_NUM: AtomicU64 = AtomicU64::new(0);
            let instance_num = INSTANCE_NUM.fetch_add(1, Ordering::SeqCst);
            Self { instance_num }
        }

        /// Returns a sender and a receiver for the same address, derived from `inode`.
        async fn endpoint(&self, inode: u64) -> (impl Sender, impl Receiver<BodyOwned>) {
            let addr = Arc::new(block_addr([self.instance_num, inode].iter()));
            let router = router(addr.clone(), &NODE_CREDS).unwrap();
            let receiver = router.receiver::<BodyOwned>().await.unwrap();
            let sender = router.sender(addr).await.unwrap();
            (sender, receiver)
        }
    }

    #[tokio::test]
    async fn message_received_is_message_sent() {
        let case = TestCase::new();
        let (mut sender, mut receiver) = case.endpoint(1).await;

        sender.send_with_rand_id(BodyRef::Ping).await.unwrap();
        let actual = receiver.next().await.unwrap().unwrap();

        assert!(matches!(actual.body, BodyOwned::Ping));
    }

    #[tokio::test]
    async fn ping_pong() {
        let case = TestCase::new();
        let (mut sender_one, mut receiver_one) = case.endpoint(1).await;
        let (mut sender_two, mut receiver_two) = case.endpoint(2).await;

        let handle = tokio::spawn(async move {
            let msg = receiver_one.next().await.unwrap().unwrap();
            let reply_body = if let BodyOwned::Ping = msg.body {
                BodyRef::Success
            } else {
                BodyRef::Fail(MsgError::Unknown)
            };
            let fut = assert_send::<'_, Result<()>>(sender_two.send_with_rand_id(reply_body));
            fut.await.unwrap();
            sender_two.finish().await.unwrap();
        });

        sender_one.send_with_rand_id(BodyRef::Ping).await.unwrap();
        // Surface a panic in the responder task instead of waiting forever for its reply.
        handle.await.unwrap();
        let reply = receiver_two.next().await.unwrap().unwrap();

        assert!(matches!(reply.body, BodyOwned::Success));
    }

    #[tokio::test]
    async fn read_write() {
        let case = TestCase::new();
        let (mut sender_one, mut receiver_one) = case.endpoint(1).await;
        let (mut sender_two, mut receiver_two) = case.endpoint(2).await;

        let handle = tokio::spawn(async move {
            let data: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
            let msg = receiver_one.next().await.unwrap().unwrap();
            let reply_body = if let BodyOwned::Read { offset, size } = msg.body {
                let offset: usize = offset.try_into().unwrap();
                let size: usize = size.try_into().unwrap();
                let end: usize = offset + size;
                BodyRef::Write {
                    offset: offset as u64,
                    buf: &data[offset..end],
                }
            } else {
                BodyRef::Fail(MsgError::Unknown)
            };
            // Reply with the request's ID.
            let msg = Msg::new(msg.id, reply_body);
            let fut = assert_send::<'_, Result<()>>(sender_two.send(msg));
            fut.await.unwrap();
            sender_two.finish().await.unwrap();
        });

        sender_one
            .send_with_rand_id(BodyRef::Read { offset: 2, size: 2 })
            .await
            .unwrap();
        handle.await.unwrap();
        let reply = receiver_two.next().await.unwrap().unwrap();

        if let BodyOwned::Write { offset, buf } = reply.body {
            assert_eq!(2, offset);
            assert_eq!([2, 3].as_slice(), buf.as_slice());
        } else {
            panic!("reply was not the right type");
        };
    }
}