// SPDX-License-Identifier: AGPL-3.0-or-later
//! Code which enables sending messages between processes in the blocktree system.
#![feature(impl_trait_in_assoc_type)]

mod tls;
use tls::*;
mod callback_framed;
use callback_framed::CallbackFramed;
pub use callback_framed::DeserCallback;

use btlib::{bterr, crypto::Creds, error::DisplayErr, BlockPath, Result, Writecap};
use btserde::{field_helpers::smart_ptr, write_to};
use bytes::{BufMut, BytesMut};
use core::{
    future::{ready, Future, Ready},
    marker::Send,
    pin::Pin,
};
use futures::{FutureExt, SinkExt};
use log::{debug, error};
use quinn::{Connection, ConnectionError, Endpoint, RecvStream, SendStream};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
    any::Any,
    fmt::Display,
    hash::Hash,
    io,
    marker::PhantomData,
    net::{IpAddr, Ipv6Addr, SocketAddr},
    result::Result as StdResult,
    sync::{Arc, Mutex as StdMutex},
};
use tokio::{
    select,
    sync::{broadcast, Mutex},
    task::{JoinError, JoinHandle},
};
use tokio_util::codec::{Encoder, Framed, FramedParts, FramedWrite};

/// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of
/// the [Writecap] in the given credentials. The returned type can be used to make
/// [Transmitter]s for any path.
pub fn receiver( ip_addr: IpAddr, creds: Arc, callback: F, ) -> Result { let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?; let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path()))); QuicReceiver::new(addr, Arc::new(CertResolver::new(creds)?), callback) } pub async fn transmitter( addr: Arc, creds: Arc, ) -> Result { let resolver = Arc::new(CertResolver::new(creds)?); let endpoint = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))?; QuicTransmitter::from_endpoint(endpoint, addr, resolver).await } pub trait MsgCallback: Clone + Send + Sync + Unpin { type Arg<'de>: CallMsg<'de> where Self: 'de; type CallFut<'de>: Future> + Send where Self: 'de; fn call<'de>(&'de self, arg: MsgReceived>) -> Self::CallFut<'de>; } impl MsgCallback for &T { type Arg<'de> = T::Arg<'de> where Self: 'de; type CallFut<'de> = T::CallFut<'de> where Self: 'de; fn call<'de>(&'de self, arg: MsgReceived>) -> Self::CallFut<'de> { (*self).call(arg) } } /// Trait for messages which can be transmitted using the call method. pub trait CallMsg<'de>: Serialize + Deserialize<'de> + Send + Sync { type Reply<'r>: Serialize + Deserialize<'r> + Send; } /// Trait for messages which can be transmitted using the send method. /// Types which implement this trait should specify `()` as their reply type. pub trait SendMsg<'de>: CallMsg<'de> {} /// An address which identifies a block on the network. An instance of this struct can be /// used to get a socket address for the block this address refers to. 
#[derive(PartialEq, Eq, Hash, Clone, Debug, Serialize, Deserialize)] pub struct BlockAddr { #[serde(rename = "ipaddr")] ip_addr: IpAddr, #[serde(with = "smart_ptr")] path: Arc, } impl BlockAddr { pub fn new(ip_addr: IpAddr, path: Arc) -> Self { Self { ip_addr, path } } pub fn ip_addr(&self) -> IpAddr { self.ip_addr } pub fn path(&self) -> &BlockPath { self.path.as_ref() } fn port(&self) -> Result { self.path.port() } /// Returns the socket address of the block this instance refers to. pub fn socket_addr(&self) -> Result { Ok(SocketAddr::new(self.ip_addr, self.port()?)) } } impl Display for BlockAddr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}@{}", self.path, self.ip_addr) } } /// Indicates whether a message was sent using `call` or `send`. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)] enum MsgKind { /// This message expects exactly one reply. Call, /// This message expects exactly zero replies. Send, } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)] struct Envelope { kind: MsgKind, msg: T, } impl Envelope { fn send(msg: T) -> Self { Self { msg, kind: MsgKind::Send, } } fn call(msg: T) -> Self { Self { msg, kind: MsgKind::Call, } } fn msg(&self) -> &T { &self.msg } } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)] enum ReplyEnvelope { Ok(T), Err { message: String, os_code: Option, }, } impl ReplyEnvelope { fn err(message: String, os_code: Option) -> Self { Self::Err { message, os_code } } } /// A message tagged with the block path that it was sent from. pub struct MsgReceived { from: Arc, msg: Envelope, replier: Option, } impl MsgReceived { fn new(from: Arc, msg: Envelope, replier: Option) -> Self { Self { from, msg, replier } } pub fn into_parts(self) -> (Arc, T, Option) { (self.from, self.msg.msg, self.replier) } /// The path from which this message was received. pub fn from(&self) -> &Arc { &self.from } /// Payload contained in this message. 
pub fn body(&self) -> &T { self.msg.msg() } /// Returns true if and only if this messages needs to be replied to. pub fn needs_reply(&self) -> bool { self.replier.is_some() } /// Takes the replier out of this struct and returns it, if it has not previously been returned. pub fn take_replier(&mut self) -> Option { self.replier.take() } } /// Trait for receiving messages and creating [Transmitter]s. pub trait Receiver { /// The address at which messages will be received. fn addr(&self) -> &Arc; type Transmitter: Transmitter + Send; type TransmitterFut<'a>: 'a + Future> + Send where Self: 'a; /// Creates a [Transmitter] which is connected to the given address. fn transmitter(&self, addr: Arc) -> Self::TransmitterFut<'_>; type CompleteErr: std::error::Error + Send; type CompleteFut<'a>: 'a + Future> + Send where Self: 'a; /// Returns a future which completes when this [Receiver] has completed (which may be never). fn complete(&self) -> Result>; type StopFut<'a>: 'a + Future> + Send where Self: 'a; fn stop(&self) -> Self::StopFut<'_>; } /// A type which can be used to transmit messages. pub trait Transmitter { type SendFut<'call, T>: 'call + Future> + Send where Self: 'call, T: 'call + SendMsg<'call>; /// Transmit a message to the connected [Receiver] without waiting for a reply. fn send<'call, T: 'call + SendMsg<'call>>(&'call self, msg: T) -> Self::SendFut<'call, T>; type CallFut<'call, T, F>: 'call + Future> + Send where Self: 'call, T: 'call + CallMsg<'call>, F: 'static + Send + DeserCallback; /// Transmit a message to the connected [Receiver], waits for a reply, then calls the given /// [DeserCallback] with the deserialized reply. /// /// ## WARNING /// The callback must be such that `F::Arg<'a> = T::Reply<'a>` for any `'a`. If this /// is violated, then a deserilization error will occur at runtime. /// /// ## TODO /// This issue needs to be fixed. 
Due to the fact that /// `F::Arg` is a Generic Associated Type (GAT) I have been unable to express this constraint in /// the where clause of this method. I'm not sure if the errors I've encountered are due to a /// lack of understanding on my part or due to the current limitations of the borrow checker in /// its handling of GATs. fn call<'call, T, F>(&'call self, msg: T, callback: F) -> Self::CallFut<'call, T, F> where T: 'call + CallMsg<'call>, F: 'static + Send + DeserCallback; /// Transmits a message to the connected [Receiver], waits for a reply, then passes back the /// the reply to the caller. fn call_through<'call, T>( &'call self, msg: T, ) -> Self::CallFut<'call, T, Passthrough>> where T: 'call + CallMsg<'call>, T::Reply<'call>: 'static + Send + Sync + DeserializeOwned, { self.call(msg, Passthrough::new()) } /// Returns the address that this instance is transmitting to. fn addr(&self) -> &Arc; } pub struct Passthrough { phantom: PhantomData, } impl Passthrough { pub fn new() -> Self { Self { phantom: PhantomData, } } } impl Default for Passthrough { fn default() -> Self { Self::new() } } impl Clone for Passthrough { fn clone(&self) -> Self { Self::new() } } impl DeserCallback for Passthrough { type Arg<'de> = T; type Return = T; type CallFut<'de> = Ready; fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> { ready(arg) } } /// Encodes messages using [btserde]. 
#[derive(Debug)] struct MsgEncoder; impl MsgEncoder { fn new() -> Self { Self } } impl Encoder for MsgEncoder { type Error = btlib::Error; fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<()> { const U64_LEN: usize = std::mem::size_of::(); let payload = dst.split_off(U64_LEN); let mut writer = payload.writer(); write_to(&item, &mut writer)?; let payload = writer.into_inner(); let payload_len = payload.len() as u64; let mut writer = dst.writer(); write_to(&payload_len, &mut writer)?; let dst = writer.into_inner(); dst.unsplit(payload); Ok(()) } } type FramedMsg = FramedWrite; type ArcMutex = Arc>; #[derive(Clone)] pub struct Replier { stream: ArcMutex, } impl Replier { fn new(stream: ArcMutex) -> Self { Self { stream } } pub async fn reply(&mut self, reply: T) -> Result<()> { let mut guard = self.stream.lock().await; guard.send(ReplyEnvelope::Ok(reply)).await?; Ok(()) } pub async fn reply_err(&mut self, err: String, os_code: Option) -> Result<()> { let mut guard = self.stream.lock().await; guard.send(ReplyEnvelope::<()>::err(err, os_code)).await?; Ok(()) } } struct MsgRecvdCallback { path: Arc, replier: Replier, inner: F, } impl MsgRecvdCallback { fn new(path: Arc, framed_msg: Arc>, inner: F) -> Self { Self { path, replier: Replier::new(framed_msg), inner, } } } impl DeserCallback for MsgRecvdCallback { type Arg<'de> = Envelope> where Self: 'de; type Return = Result<()>; type CallFut<'de> = impl 'de + Future + Send where F: 'de, Self: 'de; fn call<'de>(&'de mut self, arg: Envelope>) -> Self::CallFut<'de> { let replier = match arg.kind { MsgKind::Call => Some(self.replier.clone()), MsgKind::Send => None, }; async move { let result = self .inner .call(MsgReceived::new(self.path.clone(), arg, replier)) .await; match result { Ok(value) => Ok(value), Err(err) => match err.downcast::() { Ok(err) => { self.replier .reply_err(err.to_string(), err.raw_os_error()) .await } Err(err) => self.replier.reply_err(err.to_string(), None).await, }, } } } } macro_rules! 
handle_err { ($result:expr, $on_err:expr, $control_flow:expr) => { match $result { Ok(inner) => inner, Err(err) => { $on_err(err); $control_flow; } } }; } /// Unwraps the given result, or if the result is an error, returns from the enclosing function. macro_rules! unwrap_or_return { ($result:expr, $on_err:expr) => { handle_err!($result, $on_err, return) }; ($result:expr) => { unwrap_or_return!($result, |err| error!("{err}")) }; } /// Unwraps the given result, or if the result is an error, continues the enclosing loop. macro_rules! unwrap_or_continue { ($result:expr, $on_err:expr) => { handle_err!($result, $on_err, continue) }; ($result:expr) => { unwrap_or_continue!($result, |err| error!("{err}")) }; } /// Awaits its first argument, unless interrupted by its second argument, in which case the /// enclosing function returns. The second argument needs to be cancel safe, but the first /// need not be if it is discarded when the enclosing function returns (because losing messages /// from the first argument doesn't matter in this case). macro_rules! await_or_stop { ($future:expr, $stop_fut:expr) => { select! 
{ Some(connecting) = $future => connecting, _ = $stop_fut => break, } }; } struct QuicReceiver { recv_addr: Arc, stop_tx: broadcast::Sender<()>, endpoint: Endpoint, resolver: Arc, join_handle: StdMutex>>, } impl QuicReceiver { fn new( recv_addr: Arc, resolver: Arc, callback: F, ) -> Result { log::info!("starting QuicReceiver with address {}", recv_addr); let socket_addr = recv_addr.socket_addr()?; let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?; let (stop_tx, stop_rx) = broadcast::channel(1); let join_handle = tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx)); Ok(Self { recv_addr, stop_tx, endpoint, resolver, join_handle: StdMutex::new(Some(join_handle)), }) } async fn server_loop( endpoint: Endpoint, callback: F, mut stop_rx: broadcast::Receiver<()>, ) { loop { let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv()); let connection = unwrap_or_continue!(connecting.await, |err| error!( "error accepting QUIC connection: {err}" )); tokio::spawn(Self::handle_connection( connection, callback.clone(), stop_rx.resubscribe(), )); } } async fn handle_connection( connection: Connection, callback: F, mut stop_rx: broadcast::Receiver<()>, ) { let client_path = unwrap_or_return!( Self::client_path(connection.peer_identity()), |err| error!("failed to get client path from peer identity: {err}") ); loop { let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv()); let (send_stream, recv_stream) = match result { Ok(pair) => pair, Err(err) => match err { ConnectionError::ApplicationClosed(app) => { debug!("connection closed: {app}"); return; } _ => { error!("error accepting stream: {err}"); continue; } }, }; let client_path = client_path.clone(); let callback = callback.clone(); tokio::task::spawn(Self::handle_message( client_path, send_stream, recv_stream, callback, )); } } async fn handle_message( client_path: Arc, send_stream: SendStream, recv_stream: RecvStream, callback: F, ) { let framed_msg = 
Arc::new(Mutex::new(FramedWrite::new(send_stream, MsgEncoder::new()))); let callback = MsgRecvdCallback::new(client_path.clone(), framed_msg.clone(), callback.clone()); let mut msg_stream = CallbackFramed::new(recv_stream); let result = msg_stream .next(callback) .await .ok_or_else(|| bterr!("client closed stream before sending a message")); match unwrap_or_return!(result) { Err(err) => error!("msg_stream produced an error: {err}"), Ok(result) => { if let Err(err) = result { error!("callback returned an error: {err}"); } } } } /// Returns the path the client is bound to. fn client_path(peer_identity: Option>) -> Result> { let peer_identity = peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?; let client_certs = peer_identity .downcast::>() .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?; let first = client_certs .first() .ok_or_else(|| bterr!("no certificates were presented by the client"))?; let (writecap, ..) = Writecap::from_cert_chain(first, &client_certs[1..])?; Ok(Arc::new(writecap.bind_path())) } } impl Drop for QuicReceiver { fn drop(&mut self) { // This result will be a failure if the tasks have already returned, which is not a // problem. 
let _ = self.stop_tx.send(()); } } impl Receiver for QuicReceiver { fn addr(&self) -> &Arc { &self.recv_addr } type Transmitter = QuicTransmitter; type TransmitterFut<'a> = Pin> + Send>>; fn transmitter(&self, addr: Arc) -> Self::TransmitterFut<'_> { Box::pin(async { QuicTransmitter::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await }) } type CompleteErr = JoinError; type CompleteFut<'a> = JoinHandle<()>; fn complete(&self) -> Result> { let mut guard = self.join_handle.lock().display_err()?; let handle = guard .take() .ok_or_else(|| bterr!("join handle has already been taken"))?; Ok(handle) } type StopFut<'a> = Ready>; fn stop(&self) -> Self::StopFut<'_> { ready(self.stop_tx.send(()).map(|_| ()).map_err(|err| err.into())) } } macro_rules! cleanup_on_err { ($result:expr, $guard:ident, $parts:ident) => { match $result { Ok(value) => value, Err(err) => { *$guard = Some($parts); return Err(err.into()); } } }; } struct QuicTransmitter { addr: Arc, connection: Connection, send_parts: Mutex>>, recv_buf: Mutex>, } impl QuicTransmitter { async fn from_endpoint( endpoint: Endpoint, addr: Arc, resolver: Arc, ) -> Result { let socket_addr = addr.socket_addr()?; let connecting = endpoint.connect_with( client_config(addr.path.clone(), resolver)?, socket_addr, // The ServerCertVerifier ensures we connect to the correct path. 
"UNIMPORTANT", )?; let connection = connecting.await?; let send_parts = Mutex::new(None); let recv_buf = Mutex::new(Some(BytesMut::new())); Ok(Self { addr, connection, send_parts, recv_buf, }) } async fn transmit(&self, envelope: Envelope) -> Result { let mut guard = self.send_parts.lock().await; let (send_stream, recv_stream) = self.connection.open_bi().await?; let parts = match guard.take() { Some(mut parts) => { parts.io = send_stream; parts } None => FramedParts::new::>(send_stream, MsgEncoder::new()), }; let mut sink = Framed::from_parts(parts); let result = sink.send(envelope).await; let parts = sink.into_parts(); cleanup_on_err!(result, guard, parts); *guard = Some(parts); Ok(recv_stream) } async fn call<'ser, T, F>(&'ser self, msg: T, callback: F) -> Result where T: 'ser + CallMsg<'ser>, F: 'static + Send + DeserCallback, { let recv_stream = self.transmit(Envelope::call(msg)).await?; let mut guard = self.recv_buf.lock().await; let buffer = guard.take().unwrap(); let mut callback_framed = CallbackFramed::from_parts(recv_stream, buffer); let result = callback_framed .next(ReplyCallback::new(callback)) .await .ok_or_else(|| bterr!("server hung up before sending reply")); let (_, buffer) = callback_framed.into_parts(); let output = cleanup_on_err!(result, guard, buffer); *guard = Some(buffer); output? 
} } impl Transmitter for QuicTransmitter { fn addr(&self) -> &Arc { &self.addr } type SendFut<'ser, T> = impl 'ser + Future> + Send where T: 'ser + SendMsg<'ser>; fn send<'ser, 'r, T: 'ser + SendMsg<'ser>>(&'ser self, msg: T) -> Self::SendFut<'ser, T> { self.transmit(Envelope::send(msg)) .map(|result| result.map(|_| ())) } type CallFut<'ser, T, F> = impl 'ser + Future> + Send where Self: 'ser, T: 'ser + CallMsg<'ser>, F: 'static + Send + DeserCallback; fn call<'ser, T, F>(&'ser self, msg: T, callback: F) -> Self::CallFut<'ser, T, F> where T: 'ser + CallMsg<'ser>, F: 'static + Send + DeserCallback, { self.call(msg, callback) } } struct ReplyCallback { inner: F, } impl ReplyCallback { fn new(inner: F) -> Self { Self { inner } } } impl DeserCallback for ReplyCallback { type Arg<'de> = ReplyEnvelope>; type Return = Result; type CallFut<'de> = impl 'de + Future + Send; fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> { async move { match arg { ReplyEnvelope::Ok(msg) => Ok(self.inner.call(msg).await), ReplyEnvelope::Err { message, os_code } => { if let Some(os_code) = os_code { let err = bterr!(io::Error::from_raw_os_error(os_code)).context(message); Err(err) } else { Err(bterr!(message)) } } } } } }