// SPDX-License-Identifier: AGPL-3.0-or-later //! Types used for receiving messages over the network. Chiefly, the [Receiver] type. use std::{ any::Any, future::Future, io, net::IpAddr, sync::{Arc, Mutex as StdMutex}, }; use btlib::{bterr, crypto::Creds, error::DisplayErr, BlockPath, Writecap}; use futures::{FutureExt, SinkExt}; use log::{debug, error}; use quinn::{Connection, ConnectionError, Endpoint, RecvStream, SendStream}; use serde::{Deserialize, Serialize}; use tokio::{ select, sync::{broadcast, Mutex}, task::JoinHandle, }; use tokio_util::codec::FramedWrite; use crate::{ serialization::{CallbackFramed, MsgEncoder}, tls::{server_config, CertResolver}, BlockAddr, CallMsg, DeserCallback, Result, Transmitter, }; macro_rules! handle_err { ($result:expr, $on_err:expr, $control_flow:expr) => { match $result { Ok(inner) => inner, Err(err) => { $on_err(err); $control_flow; } } }; } /// Unwraps the given result, or if the result is an error, returns from the enclosing function. macro_rules! unwrap_or_return { ($result:expr, $on_err:expr) => { handle_err!($result, $on_err, return) }; ($result:expr) => { unwrap_or_return!($result, |err| error!("{err}")) }; } /// Unwraps the given result, or if the result is an error, continues the enclosing loop. macro_rules! unwrap_or_continue { ($result:expr, $on_err:expr) => { handle_err!($result, $on_err, continue) }; ($result:expr) => { unwrap_or_continue!($result, |err| error!("{err}")) }; } /// Awaits its first argument, unless interrupted by its second argument, in which case the /// enclosing function returns. The second argument needs to be cancel safe, but the first /// need not be if it is discarded when the enclosing function returns (because losing messages /// from the first argument doesn't matter in this case). macro_rules! await_or_stop { ($future:expr, $stop_fut:expr) => { select! { Some(connecting) = $future => connecting, _ = $stop_fut => break, } }; } /// Type which receives messages sent over the network sent by a [Transmitter]. pub struct Receiver { recv_addr: Arc, stop_tx: broadcast::Sender<()>, endpoint: Endpoint, resolver: Arc, join_handle: StdMutex>>, } impl Receiver { /// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of /// the [Writecap] in the given credentials. The returned type can be used to make /// [Transmitter]s for any path. pub fn new( ip_addr: IpAddr, creds: Arc, callback: F, ) -> Result { let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?; let recv_addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path()))); log::info!("starting Receiver with address {}", recv_addr); let socket_addr = recv_addr.socket_addr()?; let resolver = Arc::new(CertResolver::new(creds)?); let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?; let (stop_tx, stop_rx) = broadcast::channel(1); let join_handle = tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx)); Ok(Self { recv_addr, stop_tx, endpoint, resolver, join_handle: StdMutex::new(Some(join_handle)), }) } async fn server_loop( endpoint: Endpoint, callback: F, mut stop_rx: broadcast::Receiver<()>, ) { loop { let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv()); let connection = unwrap_or_continue!(connecting.await, |err| error!( "error accepting QUIC connection: {err}" )); tokio::spawn(Self::handle_connection( connection, callback.clone(), stop_rx.resubscribe(), )); } } async fn handle_connection( connection: Connection, callback: F, mut stop_rx: broadcast::Receiver<()>, ) { let client_path = unwrap_or_return!( Self::client_path(connection.peer_identity()), |err| error!("failed to get client path from peer identity: {err}") ); loop { let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv()); let (send_stream, recv_stream) = match result { Ok(pair) => pair, Err(err) => match err { ConnectionError::ApplicationClosed(app) => { debug!("connection closed: {app}"); return; } _ => { error!("error accepting stream: {err}"); continue; } }, }; let client_path = client_path.clone(); let callback = callback.clone(); tokio::task::spawn(Self::handle_message( client_path, send_stream, recv_stream, callback, )); } } async fn handle_message( client_path: Arc, send_stream: SendStream, recv_stream: RecvStream, callback: F, ) { let framed_msg = Arc::new(Mutex::new(FramedWrite::new(send_stream, MsgEncoder::new()))); let callback = MsgRecvdCallback::new(client_path.clone(), framed_msg.clone(), callback.clone()); let mut msg_stream = CallbackFramed::new(recv_stream); let result = msg_stream .next(callback) .await .ok_or_else(|| bterr!("client closed stream before sending a message")); match unwrap_or_return!(result) { Err(err) => error!("msg_stream produced an error: {err}"), Ok(result) => { if let Err(err) = result { error!("callback returned an error: {err}"); } } } } /// Returns the path the client is bound to. fn client_path(peer_identity: Option>) -> Result> { let peer_identity = peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?; let client_certs = peer_identity .downcast::>() .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?; let first = client_certs .first() .ok_or_else(|| bterr!("no certificates were presented by the client"))?; let (writecap, ..) = Writecap::from_cert_chain(first, &client_certs[1..])?; Ok(Arc::new(writecap.bind_path())) } /// The address at which messages will be received. pub fn addr(&self) -> &Arc { &self.recv_addr } /// Creates a [Transmitter] which is connected to the given address. pub async fn transmitter(&self, addr: Arc) -> Result { Transmitter::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await } /// Returns a future which completes when this [Receiver] has completed /// (which may be never). pub fn complete(&self) -> Result> { let mut guard = self.join_handle.lock().display_err()?; let handle = guard .take() .ok_or_else(|| bterr!("join handle has already been taken"))?; Ok(handle) } /// Sends a signal indicating that the task running the server loop should return. pub fn stop(&self) -> Result<()> { self.stop_tx.send(()).map(|_| ()).map_err(|err| err.into()) } } impl Drop for Receiver { fn drop(&mut self) { // This result will be a failure if the tasks have already returned, which is not a // problem. let _ = self.stop_tx.send(()); } } /// Trait for types which can be called to handle messages received over the network. The /// server loop in [Receiver] uses a type that implements this trait to react to messages it /// receives. pub trait MsgCallback: Clone + Send + Sync + Unpin { type Arg<'de>: CallMsg<'de> where Self: 'de; type CallFut<'de>: Future> + Send where Self: 'de; fn call<'de>(&'de self, arg: MsgReceived>) -> Self::CallFut<'de>; } impl MsgCallback for &T { type Arg<'de> = T::Arg<'de> where Self: 'de; type CallFut<'de> = T::CallFut<'de> where Self: 'de; fn call<'de>(&'de self, arg: MsgReceived>) -> Self::CallFut<'de> { (*self).call(arg) } } struct MsgRecvdCallback { path: Arc, replier: Replier, inner: F, } impl MsgRecvdCallback { fn new(path: Arc, framed_msg: Arc>, inner: F) -> Self { Self { path, replier: Replier::new(framed_msg), inner, } } } impl DeserCallback for MsgRecvdCallback { type Arg<'de> = Envelope> where Self: 'de; type Return = Result<()>; type CallFut<'de> = impl 'de + Future + Send where F: 'de, Self: 'de; fn call<'de>(&'de mut self, arg: Envelope>) -> Self::CallFut<'de> { let replier = match arg.kind { MsgKind::Call => Some(self.replier.clone()), MsgKind::Send => None, }; async move { let result = self .inner .call(MsgReceived::new(self.path.clone(), arg, replier)) .await; match result { Ok(value) => Ok(value), Err(err) => match err.downcast::() { Ok(err) => { self.replier .reply_err(err.to_string(), err.raw_os_error()) .await } Err(err) => self.replier.reply_err(err.to_string(), None).await, }, } } } } /// Indicates whether a message was sent using `call` or `send`. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)] enum MsgKind { /// This message expects exactly one reply. Call, /// This message expects exactly zero replies. Send, } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)] pub(crate) struct Envelope { kind: MsgKind, msg: T, } impl Envelope { pub(crate) fn send(msg: T) -> Self { Self { msg, kind: MsgKind::Send, } } pub(crate) fn call(msg: T) -> Self { Self { msg, kind: MsgKind::Call, } } fn msg(&self) -> &T { &self.msg } } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)] pub(crate) enum ReplyEnvelope { Ok(T), Err { message: String, os_code: Option, }, } impl ReplyEnvelope { fn err(message: String, os_code: Option) -> Self { Self::Err { message, os_code } } } /// A message tagged with the block path that it was sent from. pub struct MsgReceived { from: Arc, msg: Envelope, replier: Option, } impl MsgReceived { fn new(from: Arc, msg: Envelope, replier: Option) -> Self { Self { from, msg, replier } } pub fn into_parts(self) -> (Arc, T, Option) { (self.from, self.msg.msg, self.replier) } /// The path from which this message was received. pub fn from(&self) -> &Arc { &self.from } /// Payload contained in this message. pub fn body(&self) -> &T { self.msg.msg() } /// Returns true if and only if this messages needs to be replied to. pub fn needs_reply(&self) -> bool { self.replier.is_some() } /// Takes the replier out of this struct and returns it, if it has not previously been returned. pub fn take_replier(&mut self) -> Option { self.replier.take() } } type FramedMsg = FramedWrite; type ArcMutex = Arc>; /// A type for sending a reply to a message. Replies are sent over their own streams, so no two /// messages can interfere with one another. #[derive(Clone)] pub struct Replier { stream: ArcMutex, } impl Replier { fn new(stream: ArcMutex) -> Self { Self { stream } } pub async fn reply(&mut self, reply: T) -> Result<()> { let mut guard = self.stream.lock().await; guard.send(ReplyEnvelope::Ok(reply)).await?; Ok(()) } pub async fn reply_err(&mut self, err: String, os_code: Option) -> Result<()> { let mut guard = self.stream.lock().await; guard.send(ReplyEnvelope::<()>::err(err, os_code)).await?; Ok(()) } }