123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414 |
- // SPDX-License-Identifier: AGPL-3.0-or-later
- //! Types used for receiving messages over the network. Chiefly, the [Receiver] type.
- use std::{
- any::Any,
- future::Future,
- io,
- net::IpAddr,
- sync::{Arc, Mutex as StdMutex},
- };
- use btlib::{bterr, crypto::Creds, error::DisplayErr, BlockPath, Writecap};
- use futures::{FutureExt, SinkExt};
- use log::{debug, error};
- use quinn::{Connection, ConnectionError, Endpoint, RecvStream, SendStream};
- use serde::{Deserialize, Serialize};
- use tokio::{
- select,
- sync::{broadcast, Mutex},
- task::JoinHandle,
- };
- use tokio_util::codec::FramedWrite;
- use crate::{
- serialization::{CallbackFramed, MsgEncoder},
- tls::{server_config, CertResolver},
- BlockAddr, CallMsg, DeserCallback, Result, Transmitter,
- };
/// Evaluates `$result`. An `Ok` value becomes the value of the whole macro invocation. For an
/// `Err`, the error is first handed to `$on_err` and then `$control_flow` (for example `return`
/// or `continue`) is executed; `$control_flow` must therefore diverge.
macro_rules! handle_err {
    ($result:expr, $on_err:expr, $control_flow:expr) => {
        match $result {
            Err(error) => {
                $on_err(error);
                $control_flow;
            }
            Ok(value) => value,
        }
    };
}
/// Unwraps the given result, or if the result is an error, returns from the enclosing function.
///
/// A bare `return` is used, so the enclosing function must return `()`. When a second argument
/// is supplied it is called with the error before returning; otherwise the error is logged with
/// `error!`.
macro_rules! unwrap_or_return {
    ($result:expr, $on_err:expr) => {
        handle_err!($result, $on_err, return)
    };
    ($result:expr) => {
        handle_err!($result, |err| error!("{err}"), return)
    };
}
/// Unwraps the given result, or if the result is an error, continues the enclosing loop.
///
/// When a second argument is supplied it is called with the error before continuing; otherwise
/// the error is logged with `error!`.
macro_rules! unwrap_or_continue {
    ($result:expr, $on_err:expr) => {
        handle_err!($result, $on_err, continue)
    };
    ($result:expr) => {
        handle_err!($result, |err| error!("{err}"), continue)
    };
}
/// Awaits `$future` unless `$stop_fut` completes first. In that case the macro executes `break`,
/// which exits the enclosing loop (so it may only be used inside a loop); any code after the loop
/// still runs.
///
/// `$future` must resolve to an `Option`, and the macro evaluates to the value inside `Some`. If
/// `$future` resolves to `None`, `select!` disables that branch and the macro keeps waiting on
/// `$stop_fut`.
///
/// `$stop_fut` needs to be cancel safe, because it is dropped whenever `$future` wins. `$future`
/// need not be cancel safe if it is discarded once the loop exits, because losing values from it
/// doesn't matter in that case.
macro_rules! await_or_stop {
    ($future:expr, $stop_fut:expr) => {
        select! {
            Some(value) = $future => value,
            _ = $stop_fut => break,
        }
    };
}
/// Type which receives messages sent over the network by a [Transmitter].
pub struct Receiver {
    /// Address (IP address plus bind path) at which messages are received.
    recv_addr: Arc<BlockAddr>,
    /// Broadcasts the stop signal to the server loop and the per-connection tasks.
    stop_tx: broadcast::Sender<()>,
    /// QUIC endpoint used by the server loop and by [Transmitter]s created from this receiver.
    endpoint: Endpoint,
    /// Certificate resolver used for the server TLS config and handed to created [Transmitter]s.
    resolver: Arc<CertResolver>,
    /// Handle to the server loop task; `None` once it has been taken by `complete`.
    join_handle: StdMutex<Option<JoinHandle<()>>>,
}
- impl Receiver {
- /// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of
- /// the [Writecap] in the given credentials. The returned type can be used to make
- /// [Transmitter]s for any path.
- pub fn new<C: 'static + Creds + Send + Sync, F: 'static + MsgCallback>(
- ip_addr: IpAddr,
- creds: Arc<C>,
- callback: F,
- ) -> Result<Receiver> {
- let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
- let recv_addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
- log::info!("starting Receiver with address {}", recv_addr);
- let socket_addr = recv_addr.socket_addr()?;
- let resolver = Arc::new(CertResolver::new(creds)?);
- let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
- let (stop_tx, stop_rx) = broadcast::channel(1);
- let join_handle = tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx));
- Ok(Self {
- recv_addr,
- stop_tx,
- endpoint,
- resolver,
- join_handle: StdMutex::new(Some(join_handle)),
- })
- }
- async fn server_loop<F: 'static + MsgCallback>(
- endpoint: Endpoint,
- callback: F,
- mut stop_rx: broadcast::Receiver<()>,
- ) {
- loop {
- let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv());
- let connection = unwrap_or_continue!(connecting.await, |err| error!(
- "error accepting QUIC connection: {err}"
- ));
- tokio::spawn(Self::handle_connection(
- connection,
- callback.clone(),
- stop_rx.resubscribe(),
- ));
- }
- }
- async fn handle_connection<F: 'static + MsgCallback>(
- connection: Connection,
- callback: F,
- mut stop_rx: broadcast::Receiver<()>,
- ) {
- let client_path = unwrap_or_return!(
- Self::client_path(connection.peer_identity()),
- |err| error!("failed to get client path from peer identity: {err}")
- );
- loop {
- let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv());
- let (send_stream, recv_stream) = match result {
- Ok(pair) => pair,
- Err(err) => match err {
- ConnectionError::ApplicationClosed(app) => {
- debug!("connection closed: {app}");
- return;
- }
- _ => {
- error!("error accepting stream: {err}");
- continue;
- }
- },
- };
- let client_path = client_path.clone();
- let callback = callback.clone();
- tokio::task::spawn(Self::handle_message(
- client_path,
- send_stream,
- recv_stream,
- callback,
- ));
- }
- }
- async fn handle_message<F: 'static + MsgCallback>(
- client_path: Arc<BlockPath>,
- send_stream: SendStream,
- recv_stream: RecvStream,
- callback: F,
- ) {
- let framed_msg = Arc::new(Mutex::new(FramedWrite::new(send_stream, MsgEncoder::new())));
- let callback =
- MsgRecvdCallback::new(client_path.clone(), framed_msg.clone(), callback.clone());
- let mut msg_stream = CallbackFramed::new(recv_stream);
- let result = msg_stream
- .next(callback)
- .await
- .ok_or_else(|| bterr!("client closed stream before sending a message"));
- match unwrap_or_return!(result) {
- Err(err) => error!("msg_stream produced an error: {err}"),
- Ok(result) => {
- if let Err(err) = result {
- error!("callback returned an error: {err}");
- }
- }
- }
- }
- /// Returns the path the client is bound to.
- fn client_path(peer_identity: Option<Box<dyn Any>>) -> Result<Arc<BlockPath>> {
- let peer_identity =
- peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?;
- let client_certs = peer_identity
- .downcast::<Vec<rustls::Certificate>>()
- .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?;
- let first = client_certs
- .first()
- .ok_or_else(|| bterr!("no certificates were presented by the client"))?;
- let (writecap, ..) = Writecap::from_cert_chain(first, &client_certs[1..])?;
- Ok(Arc::new(writecap.bind_path()))
- }
- /// The address at which messages will be received.
- pub fn addr(&self) -> &Arc<BlockAddr> {
- &self.recv_addr
- }
- /// Creates a [Transmitter] which is connected to the given address.
- pub async fn transmitter(&self, addr: Arc<BlockAddr>) -> Result<Transmitter> {
- Transmitter::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
- }
- /// Returns a future which completes when this [Receiver] has completed
- /// (which may be never).
- pub fn complete(&self) -> Result<JoinHandle<()>> {
- let mut guard = self.join_handle.lock().display_err()?;
- let handle = guard
- .take()
- .ok_or_else(|| bterr!("join handle has already been taken"))?;
- Ok(handle)
- }
- /// Sends a signal indicating that the task running the server loop should return.
- pub fn stop(&self) -> Result<()> {
- self.stop_tx.send(()).map(|_| ()).map_err(|err| err.into())
- }
- }
- impl Drop for Receiver {
- fn drop(&mut self) {
- // This result will be a failure if the tasks have already returned, which is not a
- // problem.
- let _ = self.stop_tx.send(());
- }
- }
/// Trait for types which can be called to handle messages received over the network. The
/// server loop in [Receiver] uses a type that implements this trait to react to messages it
/// receives.
///
/// The server loop clones the callback for every accepted connection and again for every
/// received message, so cloning should presumably be cheap (e.g. an `Arc` inside).
pub trait MsgCallback: Clone + Send + Sync + Unpin {
    /// The type of message passed to [MsgCallback::call] (wrapped in a [MsgReceived]).
    type Arg<'de>: CallMsg<'de>
    where
        Self: 'de;
    /// The future returned by [MsgCallback::call].
    type CallFut<'de>: Future<Output = Result<()>> + Send
    where
        Self: 'de;
    /// Handles one received message. `arg` carries the sender's path, the message body and,
    /// for messages which expect a reply, a [Replier].
    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de>;
}
- impl<T: MsgCallback> MsgCallback for &T {
- type Arg<'de> = T::Arg<'de> where Self: 'de;
- type CallFut<'de> = T::CallFut<'de> where Self: 'de;
- fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
- (*self).call(arg)
- }
- }
/// Adapter that wraps a [MsgCallback] so it can be driven as a [DeserCallback]. It tags each
/// received message with the sender's path and supplies a [Replier] to messages that expect one.
struct MsgRecvdCallback<F> {
    /// Bind path of the client which sent the message.
    path: Arc<BlockPath>,
    /// Writes replies (including error replies) to the stream the message arrived on.
    replier: Replier,
    /// The wrapped user callback.
    inner: F,
}
- impl<F: MsgCallback> MsgRecvdCallback<F> {
- fn new(path: Arc<BlockPath>, framed_msg: Arc<Mutex<FramedMsg>>, inner: F) -> Self {
- Self {
- path,
- replier: Replier::new(framed_msg),
- inner,
- }
- }
- }
- impl<F: 'static + MsgCallback> DeserCallback for MsgRecvdCallback<F> {
- type Arg<'de> = Envelope<F::Arg<'de>> where Self: 'de;
- type Return = Result<()>;
- type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send where F: 'de, Self: 'de;
- fn call<'de>(&'de mut self, arg: Envelope<F::Arg<'de>>) -> Self::CallFut<'de> {
- let replier = match arg.kind {
- MsgKind::Call => Some(self.replier.clone()),
- MsgKind::Send => None,
- };
- async move {
- let result = self
- .inner
- .call(MsgReceived::new(self.path.clone(), arg, replier))
- .await;
- match result {
- Ok(value) => Ok(value),
- Err(err) => match err.downcast::<io::Error>() {
- Ok(err) => {
- self.replier
- .reply_err(err.to_string(), err.raw_os_error())
- .await
- }
- Err(err) => self.replier.reply_err(err.to_string(), None).await,
- },
- }
- }
- }
- }
/// Indicates whether a message was sent using `call` or `send`.
///
/// NOTE(review): this type is serialized inside [Envelope], so the variant names/order are
/// presumably part of the wire format — confirm before reordering or renaming.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
enum MsgKind {
    /// This message expects exactly one reply.
    Call,
    /// This message expects exactly zero replies.
    Send,
}
/// A message together with its [MsgKind]. This is the type the [Receiver] deserializes from
/// each incoming stream.
///
/// NOTE(review): the field order may be significant to the serializer's wire format — confirm
/// before reordering.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub(crate) struct Envelope<T> {
    /// Whether the sender expects a reply.
    kind: MsgKind,
    /// The message payload.
    msg: T,
}
- impl<T> Envelope<T> {
- pub(crate) fn send(msg: T) -> Self {
- Self {
- msg,
- kind: MsgKind::Send,
- }
- }
- pub(crate) fn call(msg: T) -> Self {
- Self {
- msg,
- kind: MsgKind::Call,
- }
- }
- fn msg(&self) -> &T {
- &self.msg
- }
- }
/// The reply written back to the sender of a `call` message by a [Replier].
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub(crate) enum ReplyEnvelope<T> {
    /// The request succeeded and produced this value.
    Ok(T),
    /// The request failed.
    Err {
        /// Human-readable description of the error.
        message: String,
        /// OS error code, when the error came from an `io::Error` that carried one.
        os_code: Option<i32>,
    },
}
- impl<T> ReplyEnvelope<T> {
- fn err(message: String, os_code: Option<i32>) -> Self {
- Self::Err { message, os_code }
- }
- }
/// A message tagged with the block path that it was sent from.
pub struct MsgReceived<T> {
    /// Bind path of the sender.
    from: Arc<BlockPath>,
    /// The received envelope, which holds the message payload.
    msg: Envelope<T>,
    /// Used to reply to the message; `None` if the message does not expect a reply or the
    /// replier has already been taken.
    replier: Option<Replier>,
}
- impl<T> MsgReceived<T> {
- fn new(from: Arc<BlockPath>, msg: Envelope<T>, replier: Option<Replier>) -> Self {
- Self { from, msg, replier }
- }
- pub fn into_parts(self) -> (Arc<BlockPath>, T, Option<Replier>) {
- (self.from, self.msg.msg, self.replier)
- }
- /// The path from which this message was received.
- pub fn from(&self) -> &Arc<BlockPath> {
- &self.from
- }
- /// Payload contained in this message.
- pub fn body(&self) -> &T {
- self.msg.msg()
- }
- /// Returns true if and only if this messages needs to be replied to.
- pub fn needs_reply(&self) -> bool {
- self.replier.is_some()
- }
- /// Takes the replier out of this struct and returns it, if it has not previously been returned.
- pub fn take_replier(&mut self) -> Option<Replier> {
- self.replier.take()
- }
- }
/// Sink which encodes values with [MsgEncoder] and writes them to a QUIC [SendStream].
type FramedMsg = FramedWrite<SendStream, MsgEncoder>;
/// Shorthand for a value with shared ownership behind an async (tokio) [Mutex].
type ArcMutex<T> = Arc<Mutex<T>>;
/// A type for sending a reply to a message. Replies are sent over their own streams, so no two
/// messages can interfere with one another.
///
/// Clones share the same underlying stream.
#[derive(Clone)]
pub struct Replier {
    /// Send half of the stream the message arrived on; locked for the duration of each reply.
    stream: ArcMutex<FramedMsg>,
}
- impl Replier {
- fn new(stream: ArcMutex<FramedMsg>) -> Self {
- Self { stream }
- }
- pub async fn reply<T: Serialize + Send>(&mut self, reply: T) -> Result<()> {
- let mut guard = self.stream.lock().await;
- guard.send(ReplyEnvelope::Ok(reply)).await?;
- Ok(())
- }
- pub async fn reply_err(&mut self, err: String, os_code: Option<i32>) -> Result<()> {
- let mut guard = self.stream.lock().await;
- guard.send(ReplyEnvelope::<()>::err(err, os_code)).await?;
- Ok(())
- }
- }
|