|
@@ -2,764 +2,18 @@
|
|
|
//! Code which enables sending messages between processes in the blocktree system.
|
|
|
#![feature(impl_trait_in_assoc_type)]
|
|
|
|
|
|
-mod tls;
|
|
|
-use tls::*;
|
|
|
-mod callback_framed;
|
|
|
-use callback_framed::CallbackFramed;
|
|
|
-pub use callback_framed::DeserCallback;
|
|
|
-
|
|
|
-use btlib::{bterr, crypto::Creds, error::DisplayErr, BlockPath, Result, Writecap};
|
|
|
-use btserde::{field_helpers::smart_ptr, write_to};
|
|
|
-use bytes::{BufMut, BytesMut};
|
|
|
-use core::{
|
|
|
- future::{ready, Future, Ready},
|
|
|
- marker::Send,
|
|
|
- pin::Pin,
|
|
|
-};
|
|
|
-use futures::{FutureExt, SinkExt};
|
|
|
-use log::{debug, error};
|
|
|
-use quinn::{Connection, ConnectionError, Endpoint, RecvStream, SendStream};
|
|
|
-use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
|
|
-use std::{
|
|
|
- any::Any,
|
|
|
- fmt::Display,
|
|
|
- hash::Hash,
|
|
|
- io,
|
|
|
- marker::PhantomData,
|
|
|
- net::{IpAddr, Ipv6Addr, SocketAddr},
|
|
|
- result::Result as StdResult,
|
|
|
- sync::{Arc, Mutex as StdMutex},
|
|
|
-};
|
|
|
-use tokio::{
|
|
|
- select,
|
|
|
- sync::{broadcast, Mutex},
|
|
|
- task::{JoinError, JoinHandle},
|
|
|
-};
|
|
|
-use tokio_util::codec::{Encoder, Framed, FramedParts, FramedWrite};
|
|
|
-
|
|
|
-/// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of
|
|
|
-/// the [Writecap] in the given credentials. The returned type can be used to make
|
|
|
-/// [Transmitter]s for any path.
|
|
|
-pub fn receiver<C: 'static + Creds + Send + Sync, F: 'static + MsgCallback>(
|
|
|
- ip_addr: IpAddr,
|
|
|
- creds: Arc<C>,
|
|
|
- callback: F,
|
|
|
-) -> Result<impl Receiver> {
|
|
|
- let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
|
|
|
- let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
|
|
|
- QuicReceiver::new(addr, Arc::new(CertResolver::new(creds)?), callback)
|
|
|
-}
|
|
|
-
|
|
|
-pub async fn transmitter<C: 'static + Creds + Send + Sync>(
|
|
|
- addr: Arc<BlockAddr>,
|
|
|
- creds: Arc<C>,
|
|
|
-) -> Result<impl Transmitter> {
|
|
|
- let resolver = Arc::new(CertResolver::new(creds)?);
|
|
|
- let endpoint = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))?;
|
|
|
- QuicTransmitter::from_endpoint(endpoint, addr, resolver).await
|
|
|
-}
|
|
|
-
|
|
|
-pub trait MsgCallback: Clone + Send + Sync + Unpin {
|
|
|
- type Arg<'de>: CallMsg<'de>
|
|
|
- where
|
|
|
- Self: 'de;
|
|
|
- type CallFut<'de>: Future<Output = Result<()>> + Send
|
|
|
- where
|
|
|
- Self: 'de;
|
|
|
- fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de>;
|
|
|
-}
|
|
|
-
|
|
|
-impl<T: MsgCallback> MsgCallback for &T {
|
|
|
- type Arg<'de> = T::Arg<'de> where Self: 'de;
|
|
|
- type CallFut<'de> = T::CallFut<'de> where Self: 'de;
|
|
|
- fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
|
|
|
- (*self).call(arg)
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-/// Trait for messages which can be transmitted using the call method.
|
|
|
-pub trait CallMsg<'de>: Serialize + Deserialize<'de> + Send + Sync {
|
|
|
- type Reply<'r>: Serialize + Deserialize<'r> + Send;
|
|
|
-}
|
|
|
-
|
|
|
-/// Trait for messages which can be transmitted using the send method.
|
|
|
-/// Types which implement this trait should specify `()` as their reply type.
|
|
|
-pub trait SendMsg<'de>: CallMsg<'de> {}
|
|
|
-
|
|
|
-/// An address which identifies a block on the network. An instance of this struct can be
|
|
|
-/// used to get a socket address for the block this address refers to.
|
|
|
-#[derive(PartialEq, Eq, Hash, Clone, Debug, Serialize, Deserialize)]
|
|
|
-pub struct BlockAddr {
|
|
|
- #[serde(rename = "ipaddr")]
|
|
|
- ip_addr: IpAddr,
|
|
|
- #[serde(with = "smart_ptr")]
|
|
|
- path: Arc<BlockPath>,
|
|
|
-}
|
|
|
-
|
|
|
-impl BlockAddr {
|
|
|
- pub fn new(ip_addr: IpAddr, path: Arc<BlockPath>) -> Self {
|
|
|
- Self { ip_addr, path }
|
|
|
- }
|
|
|
-
|
|
|
- pub fn ip_addr(&self) -> IpAddr {
|
|
|
- self.ip_addr
|
|
|
- }
|
|
|
-
|
|
|
- pub fn path(&self) -> &BlockPath {
|
|
|
- self.path.as_ref()
|
|
|
- }
|
|
|
-
|
|
|
- fn port(&self) -> Result<u16> {
|
|
|
- self.path.port()
|
|
|
- }
|
|
|
-
|
|
|
- /// Returns the socket address of the block this instance refers to.
|
|
|
- pub fn socket_addr(&self) -> Result<SocketAddr> {
|
|
|
- Ok(SocketAddr::new(self.ip_addr, self.port()?))
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-impl Display for BlockAddr {
|
|
|
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
|
- write!(f, "{}@{}", self.path, self.ip_addr)
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-/// Indicates whether a message was sent using `call` or `send`.
|
|
|
-#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
|
|
|
-enum MsgKind {
|
|
|
- /// This message expects exactly one reply.
|
|
|
- Call,
|
|
|
- /// This message expects exactly zero replies.
|
|
|
- Send,
|
|
|
-}
|
|
|
-
|
|
|
-#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
|
|
|
-struct Envelope<T> {
|
|
|
- kind: MsgKind,
|
|
|
- msg: T,
|
|
|
-}
|
|
|
-
|
|
|
-impl<T> Envelope<T> {
|
|
|
- fn send(msg: T) -> Self {
|
|
|
- Self {
|
|
|
- msg,
|
|
|
- kind: MsgKind::Send,
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- fn call(msg: T) -> Self {
|
|
|
- Self {
|
|
|
- msg,
|
|
|
- kind: MsgKind::Call,
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- fn msg(&self) -> &T {
|
|
|
- &self.msg
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
|
|
|
-enum ReplyEnvelope<T> {
|
|
|
- Ok(T),
|
|
|
- Err {
|
|
|
- message: String,
|
|
|
- os_code: Option<i32>,
|
|
|
- },
|
|
|
-}
|
|
|
-
|
|
|
-impl<T> ReplyEnvelope<T> {
|
|
|
- fn err(message: String, os_code: Option<i32>) -> Self {
|
|
|
- Self::Err { message, os_code }
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-/// A message tagged with the block path that it was sent from.
|
|
|
-pub struct MsgReceived<T> {
|
|
|
- from: Arc<BlockPath>,
|
|
|
- msg: Envelope<T>,
|
|
|
- replier: Option<Replier>,
|
|
|
-}
|
|
|
-
|
|
|
-impl<T> MsgReceived<T> {
|
|
|
- fn new(from: Arc<BlockPath>, msg: Envelope<T>, replier: Option<Replier>) -> Self {
|
|
|
- Self { from, msg, replier }
|
|
|
- }
|
|
|
-
|
|
|
- pub fn into_parts(self) -> (Arc<BlockPath>, T, Option<Replier>) {
|
|
|
- (self.from, self.msg.msg, self.replier)
|
|
|
- }
|
|
|
-
|
|
|
- /// The path from which this message was received.
|
|
|
- pub fn from(&self) -> &Arc<BlockPath> {
|
|
|
- &self.from
|
|
|
- }
|
|
|
-
|
|
|
- /// Payload contained in this message.
|
|
|
- pub fn body(&self) -> &T {
|
|
|
- self.msg.msg()
|
|
|
- }
|
|
|
-
|
|
|
- /// Returns true if and only if this messages needs to be replied to.
|
|
|
- pub fn needs_reply(&self) -> bool {
|
|
|
- self.replier.is_some()
|
|
|
- }
|
|
|
-
|
|
|
- /// Takes the replier out of this struct and returns it, if it has not previously been returned.
|
|
|
- pub fn take_replier(&mut self) -> Option<Replier> {
|
|
|
- self.replier.take()
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-/// Trait for receiving messages and creating [Transmitter]s.
|
|
|
-pub trait Receiver {
|
|
|
- /// The address at which messages will be received.
|
|
|
- fn addr(&self) -> &Arc<BlockAddr>;
|
|
|
-
|
|
|
- type Transmitter: Transmitter + Send;
|
|
|
- type TransmitterFut<'a>: 'a + Future<Output = Result<Self::Transmitter>> + Send
|
|
|
- where
|
|
|
- Self: 'a;
|
|
|
-
|
|
|
- /// Creates a [Transmitter] which is connected to the given address.
|
|
|
- fn transmitter(&self, addr: Arc<BlockAddr>) -> Self::TransmitterFut<'_>;
|
|
|
-
|
|
|
- type CompleteErr: std::error::Error + Send;
|
|
|
- type CompleteFut<'a>: 'a + Future<Output = StdResult<(), Self::CompleteErr>> + Send
|
|
|
- where
|
|
|
- Self: 'a;
|
|
|
- /// Returns a future which completes when this [Receiver] has completed (which may be never).
|
|
|
- fn complete(&self) -> Result<Self::CompleteFut<'_>>;
|
|
|
-
|
|
|
- type StopFut<'a>: 'a + Future<Output = Result<()>> + Send
|
|
|
- where
|
|
|
- Self: 'a;
|
|
|
- fn stop(&self) -> Self::StopFut<'_>;
|
|
|
-}
|
|
|
-
|
|
|
-/// A type which can be used to transmit messages.
|
|
|
-pub trait Transmitter {
|
|
|
- type SendFut<'call, T>: 'call + Future<Output = Result<()>> + Send
|
|
|
- where
|
|
|
- Self: 'call,
|
|
|
- T: 'call + SendMsg<'call>;
|
|
|
-
|
|
|
- /// Transmit a message to the connected [Receiver] without waiting for a reply.
|
|
|
- fn send<'call, T: 'call + SendMsg<'call>>(&'call self, msg: T) -> Self::SendFut<'call, T>;
|
|
|
-
|
|
|
- type CallFut<'call, T, F>: 'call + Future<Output = Result<F::Return>> + Send
|
|
|
- where
|
|
|
- Self: 'call,
|
|
|
- T: 'call + CallMsg<'call>,
|
|
|
- F: 'static + Send + DeserCallback;
|
|
|
-
|
|
|
- /// Transmit a message to the connected [Receiver], waits for a reply, then calls the given
|
|
|
- /// [DeserCallback] with the deserialized reply.
|
|
|
- ///
|
|
|
- /// ## WARNING
|
|
|
- /// The callback must be such that `F::Arg<'a> = T::Reply<'a>` for any `'a`. If this
|
|
|
- /// is violated, then a deserilization error will occur at runtime.
|
|
|
- ///
|
|
|
- /// ## TODO
|
|
|
- /// This issue needs to be fixed. Due to the fact that
|
|
|
- /// `F::Arg` is a Generic Associated Type (GAT) I have been unable to express this constraint in
|
|
|
- /// the where clause of this method. I'm not sure if the errors I've encountered are due to a
|
|
|
- /// lack of understanding on my part or due to the current limitations of the borrow checker in
|
|
|
- /// its handling of GATs.
|
|
|
- fn call<'call, T, F>(&'call self, msg: T, callback: F) -> Self::CallFut<'call, T, F>
|
|
|
- where
|
|
|
- T: 'call + CallMsg<'call>,
|
|
|
- F: 'static + Send + DeserCallback;
|
|
|
-
|
|
|
- /// Transmits a message to the connected [Receiver], waits for a reply, then passes back the
|
|
|
- /// the reply to the caller.
|
|
|
- fn call_through<'call, T>(
|
|
|
- &'call self,
|
|
|
- msg: T,
|
|
|
- ) -> Self::CallFut<'call, T, Passthrough<T::Reply<'call>>>
|
|
|
- where
|
|
|
- T: 'call + CallMsg<'call>,
|
|
|
- T::Reply<'call>: 'static + Send + Sync + DeserializeOwned,
|
|
|
- {
|
|
|
- self.call(msg, Passthrough::new())
|
|
|
- }
|
|
|
-
|
|
|
- /// Returns the address that this instance is transmitting to.
|
|
|
- fn addr(&self) -> &Arc<BlockAddr>;
|
|
|
-}
|
|
|
-
|
|
|
-pub struct Passthrough<T> {
|
|
|
- phantom: PhantomData<T>,
|
|
|
-}
|
|
|
+pub use btlib::Result;
|
|
|
|
|
|
-impl<T> Passthrough<T> {
|
|
|
- pub fn new() -> Self {
|
|
|
- Self {
|
|
|
- phantom: PhantomData,
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
+mod common;
|
|
|
+pub use common::{BlockAddr, CallMsg, SendMsg};
|
|
|
|
|
|
-impl<T> Default for Passthrough<T> {
|
|
|
- fn default() -> Self {
|
|
|
- Self::new()
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-impl<T> Clone for Passthrough<T> {
|
|
|
- fn clone(&self) -> Self {
|
|
|
- Self::new()
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-impl<T: 'static + Send + DeserializeOwned> DeserCallback for Passthrough<T> {
|
|
|
- type Arg<'de> = T;
|
|
|
- type Return = T;
|
|
|
- type CallFut<'de> = Ready<T>;
|
|
|
- fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
|
|
|
- ready(arg)
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-/// Encodes messages using [btserde].
|
|
|
-#[derive(Debug)]
|
|
|
-struct MsgEncoder;
|
|
|
-
|
|
|
-impl MsgEncoder {
|
|
|
- fn new() -> Self {
|
|
|
- Self
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-impl<T: Serialize> Encoder<T> for MsgEncoder {
|
|
|
- type Error = btlib::Error;
|
|
|
-
|
|
|
- fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<()> {
|
|
|
- const U64_LEN: usize = std::mem::size_of::<u64>();
|
|
|
- let payload = dst.split_off(U64_LEN);
|
|
|
- let mut writer = payload.writer();
|
|
|
- write_to(&item, &mut writer)?;
|
|
|
- let payload = writer.into_inner();
|
|
|
- let payload_len = payload.len() as u64;
|
|
|
- let mut writer = dst.writer();
|
|
|
- write_to(&payload_len, &mut writer)?;
|
|
|
- let dst = writer.into_inner();
|
|
|
- dst.unsplit(payload);
|
|
|
- Ok(())
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-type FramedMsg = FramedWrite<SendStream, MsgEncoder>;
|
|
|
-type ArcMutex<T> = Arc<Mutex<T>>;
|
|
|
-
|
|
|
-#[derive(Clone)]
|
|
|
-pub struct Replier {
|
|
|
- stream: ArcMutex<FramedMsg>,
|
|
|
-}
|
|
|
-
|
|
|
-impl Replier {
|
|
|
- fn new(stream: ArcMutex<FramedMsg>) -> Self {
|
|
|
- Self { stream }
|
|
|
- }
|
|
|
-
|
|
|
- pub async fn reply<T: Serialize + Send>(&mut self, reply: T) -> Result<()> {
|
|
|
- let mut guard = self.stream.lock().await;
|
|
|
- guard.send(ReplyEnvelope::Ok(reply)).await?;
|
|
|
- Ok(())
|
|
|
- }
|
|
|
-
|
|
|
- pub async fn reply_err(&mut self, err: String, os_code: Option<i32>) -> Result<()> {
|
|
|
- let mut guard = self.stream.lock().await;
|
|
|
- guard.send(ReplyEnvelope::<()>::err(err, os_code)).await?;
|
|
|
- Ok(())
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-struct MsgRecvdCallback<F> {
|
|
|
- path: Arc<BlockPath>,
|
|
|
- replier: Replier,
|
|
|
- inner: F,
|
|
|
-}
|
|
|
-
|
|
|
-impl<F: MsgCallback> MsgRecvdCallback<F> {
|
|
|
- fn new(path: Arc<BlockPath>, framed_msg: Arc<Mutex<FramedMsg>>, inner: F) -> Self {
|
|
|
- Self {
|
|
|
- path,
|
|
|
- replier: Replier::new(framed_msg),
|
|
|
- inner,
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-impl<F: 'static + MsgCallback> DeserCallback for MsgRecvdCallback<F> {
|
|
|
- type Arg<'de> = Envelope<F::Arg<'de>> where Self: 'de;
|
|
|
- type Return = Result<()>;
|
|
|
- type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send where F: 'de, Self: 'de;
|
|
|
- fn call<'de>(&'de mut self, arg: Envelope<F::Arg<'de>>) -> Self::CallFut<'de> {
|
|
|
- let replier = match arg.kind {
|
|
|
- MsgKind::Call => Some(self.replier.clone()),
|
|
|
- MsgKind::Send => None,
|
|
|
- };
|
|
|
- async move {
|
|
|
- let result = self
|
|
|
- .inner
|
|
|
- .call(MsgReceived::new(self.path.clone(), arg, replier))
|
|
|
- .await;
|
|
|
- match result {
|
|
|
- Ok(value) => Ok(value),
|
|
|
- Err(err) => match err.downcast::<io::Error>() {
|
|
|
- Ok(err) => {
|
|
|
- self.replier
|
|
|
- .reply_err(err.to_string(), err.raw_os_error())
|
|
|
- .await
|
|
|
- }
|
|
|
- Err(err) => self.replier.reply_err(err.to_string(), None).await,
|
|
|
- },
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-macro_rules! handle_err {
|
|
|
- ($result:expr, $on_err:expr, $control_flow:expr) => {
|
|
|
- match $result {
|
|
|
- Ok(inner) => inner,
|
|
|
- Err(err) => {
|
|
|
- $on_err(err);
|
|
|
- $control_flow;
|
|
|
- }
|
|
|
- }
|
|
|
- };
|
|
|
-}
|
|
|
-
|
|
|
-/// Unwraps the given result, or if the result is an error, returns from the enclosing function.
|
|
|
-macro_rules! unwrap_or_return {
|
|
|
- ($result:expr, $on_err:expr) => {
|
|
|
- handle_err!($result, $on_err, return)
|
|
|
- };
|
|
|
- ($result:expr) => {
|
|
|
- unwrap_or_return!($result, |err| error!("{err}"))
|
|
|
- };
|
|
|
-}
|
|
|
-
|
|
|
-/// Unwraps the given result, or if the result is an error, continues the enclosing loop.
|
|
|
-macro_rules! unwrap_or_continue {
|
|
|
- ($result:expr, $on_err:expr) => {
|
|
|
- handle_err!($result, $on_err, continue)
|
|
|
- };
|
|
|
- ($result:expr) => {
|
|
|
- unwrap_or_continue!($result, |err| error!("{err}"))
|
|
|
- };
|
|
|
-}
|
|
|
-
|
|
|
-/// Awaits its first argument, unless interrupted by its second argument, in which case the
|
|
|
-/// enclosing function returns. The second argument needs to be cancel safe, but the first
|
|
|
-/// need not be if it is discarded when the enclosing function returns (because losing messages
|
|
|
-/// from the first argument doesn't matter in this case).
|
|
|
-macro_rules! await_or_stop {
|
|
|
- ($future:expr, $stop_fut:expr) => {
|
|
|
- select! {
|
|
|
- Some(connecting) = $future => connecting,
|
|
|
- _ = $stop_fut => break,
|
|
|
- }
|
|
|
- };
|
|
|
-}
|
|
|
-
|
|
|
-struct QuicReceiver {
|
|
|
- recv_addr: Arc<BlockAddr>,
|
|
|
- stop_tx: broadcast::Sender<()>,
|
|
|
- endpoint: Endpoint,
|
|
|
- resolver: Arc<CertResolver>,
|
|
|
- join_handle: StdMutex<Option<JoinHandle<()>>>,
|
|
|
-}
|
|
|
-
|
|
|
-impl QuicReceiver {
|
|
|
- fn new<F: 'static + MsgCallback>(
|
|
|
- recv_addr: Arc<BlockAddr>,
|
|
|
- resolver: Arc<CertResolver>,
|
|
|
- callback: F,
|
|
|
- ) -> Result<Self> {
|
|
|
- log::info!("starting QuicReceiver with address {}", recv_addr);
|
|
|
- let socket_addr = recv_addr.socket_addr()?;
|
|
|
- let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
|
|
|
- let (stop_tx, stop_rx) = broadcast::channel(1);
|
|
|
- let join_handle = tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx));
|
|
|
- Ok(Self {
|
|
|
- recv_addr,
|
|
|
- stop_tx,
|
|
|
- endpoint,
|
|
|
- resolver,
|
|
|
- join_handle: StdMutex::new(Some(join_handle)),
|
|
|
- })
|
|
|
- }
|
|
|
-
|
|
|
- async fn server_loop<F: 'static + MsgCallback>(
|
|
|
- endpoint: Endpoint,
|
|
|
- callback: F,
|
|
|
- mut stop_rx: broadcast::Receiver<()>,
|
|
|
- ) {
|
|
|
- loop {
|
|
|
- let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv());
|
|
|
- let connection = unwrap_or_continue!(connecting.await, |err| error!(
|
|
|
- "error accepting QUIC connection: {err}"
|
|
|
- ));
|
|
|
- tokio::spawn(Self::handle_connection(
|
|
|
- connection,
|
|
|
- callback.clone(),
|
|
|
- stop_rx.resubscribe(),
|
|
|
- ));
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- async fn handle_connection<F: 'static + MsgCallback>(
|
|
|
- connection: Connection,
|
|
|
- callback: F,
|
|
|
- mut stop_rx: broadcast::Receiver<()>,
|
|
|
- ) {
|
|
|
- let client_path = unwrap_or_return!(
|
|
|
- Self::client_path(connection.peer_identity()),
|
|
|
- |err| error!("failed to get client path from peer identity: {err}")
|
|
|
- );
|
|
|
- loop {
|
|
|
- let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv());
|
|
|
- let (send_stream, recv_stream) = match result {
|
|
|
- Ok(pair) => pair,
|
|
|
- Err(err) => match err {
|
|
|
- ConnectionError::ApplicationClosed(app) => {
|
|
|
- debug!("connection closed: {app}");
|
|
|
- return;
|
|
|
- }
|
|
|
- _ => {
|
|
|
- error!("error accepting stream: {err}");
|
|
|
- continue;
|
|
|
- }
|
|
|
- },
|
|
|
- };
|
|
|
- let client_path = client_path.clone();
|
|
|
- let callback = callback.clone();
|
|
|
- tokio::task::spawn(Self::handle_message(
|
|
|
- client_path,
|
|
|
- send_stream,
|
|
|
- recv_stream,
|
|
|
- callback,
|
|
|
- ));
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- async fn handle_message<F: 'static + MsgCallback>(
|
|
|
- client_path: Arc<BlockPath>,
|
|
|
- send_stream: SendStream,
|
|
|
- recv_stream: RecvStream,
|
|
|
- callback: F,
|
|
|
- ) {
|
|
|
- let framed_msg = Arc::new(Mutex::new(FramedWrite::new(send_stream, MsgEncoder::new())));
|
|
|
- let callback =
|
|
|
- MsgRecvdCallback::new(client_path.clone(), framed_msg.clone(), callback.clone());
|
|
|
- let mut msg_stream = CallbackFramed::new(recv_stream);
|
|
|
- let result = msg_stream
|
|
|
- .next(callback)
|
|
|
- .await
|
|
|
- .ok_or_else(|| bterr!("client closed stream before sending a message"));
|
|
|
- match unwrap_or_return!(result) {
|
|
|
- Err(err) => error!("msg_stream produced an error: {err}"),
|
|
|
- Ok(result) => {
|
|
|
- if let Err(err) = result {
|
|
|
- error!("callback returned an error: {err}");
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- /// Returns the path the client is bound to.
|
|
|
- fn client_path(peer_identity: Option<Box<dyn Any>>) -> Result<Arc<BlockPath>> {
|
|
|
- let peer_identity =
|
|
|
- peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?;
|
|
|
- let client_certs = peer_identity
|
|
|
- .downcast::<Vec<rustls::Certificate>>()
|
|
|
- .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?;
|
|
|
- let first = client_certs
|
|
|
- .first()
|
|
|
- .ok_or_else(|| bterr!("no certificates were presented by the client"))?;
|
|
|
- let (writecap, ..) = Writecap::from_cert_chain(first, &client_certs[1..])?;
|
|
|
- Ok(Arc::new(writecap.bind_path()))
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-impl Drop for QuicReceiver {
|
|
|
- fn drop(&mut self) {
|
|
|
- // This result will be a failure if the tasks have already returned, which is not a
|
|
|
- // problem.
|
|
|
- let _ = self.stop_tx.send(());
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-impl Receiver for QuicReceiver {
|
|
|
- fn addr(&self) -> &Arc<BlockAddr> {
|
|
|
- &self.recv_addr
|
|
|
- }
|
|
|
-
|
|
|
- type Transmitter = QuicTransmitter;
|
|
|
- type TransmitterFut<'a> = Pin<Box<dyn 'a + Future<Output = Result<QuicTransmitter>> + Send>>;
|
|
|
-
|
|
|
- fn transmitter(&self, addr: Arc<BlockAddr>) -> Self::TransmitterFut<'_> {
|
|
|
- Box::pin(async {
|
|
|
- QuicTransmitter::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
|
|
|
- })
|
|
|
- }
|
|
|
-
|
|
|
- type CompleteErr = JoinError;
|
|
|
- type CompleteFut<'a> = JoinHandle<()>;
|
|
|
- fn complete(&self) -> Result<Self::CompleteFut<'_>> {
|
|
|
- let mut guard = self.join_handle.lock().display_err()?;
|
|
|
- let handle = guard
|
|
|
- .take()
|
|
|
- .ok_or_else(|| bterr!("join handle has already been taken"))?;
|
|
|
- Ok(handle)
|
|
|
- }
|
|
|
-
|
|
|
- type StopFut<'a> = Ready<Result<()>>;
|
|
|
- fn stop(&self) -> Self::StopFut<'_> {
|
|
|
- ready(self.stop_tx.send(()).map(|_| ()).map_err(|err| err.into()))
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-macro_rules! cleanup_on_err {
|
|
|
- ($result:expr, $guard:ident, $parts:ident) => {
|
|
|
- match $result {
|
|
|
- Ok(value) => value,
|
|
|
- Err(err) => {
|
|
|
- *$guard = Some($parts);
|
|
|
- return Err(err.into());
|
|
|
- }
|
|
|
- }
|
|
|
- };
|
|
|
-}
|
|
|
-
|
|
|
-struct QuicTransmitter {
|
|
|
- addr: Arc<BlockAddr>,
|
|
|
- connection: Connection,
|
|
|
- send_parts: Mutex<Option<FramedParts<SendStream, MsgEncoder>>>,
|
|
|
- recv_buf: Mutex<Option<BytesMut>>,
|
|
|
-}
|
|
|
-
|
|
|
-impl QuicTransmitter {
|
|
|
- async fn from_endpoint(
|
|
|
- endpoint: Endpoint,
|
|
|
- addr: Arc<BlockAddr>,
|
|
|
- resolver: Arc<CertResolver>,
|
|
|
- ) -> Result<Self> {
|
|
|
- let socket_addr = addr.socket_addr()?;
|
|
|
- let connecting = endpoint.connect_with(
|
|
|
- client_config(addr.path.clone(), resolver)?,
|
|
|
- socket_addr,
|
|
|
- // The ServerCertVerifier ensures we connect to the correct path.
|
|
|
- "UNIMPORTANT",
|
|
|
- )?;
|
|
|
- let connection = connecting.await?;
|
|
|
- let send_parts = Mutex::new(None);
|
|
|
- let recv_buf = Mutex::new(Some(BytesMut::new()));
|
|
|
- Ok(Self {
|
|
|
- addr,
|
|
|
- connection,
|
|
|
- send_parts,
|
|
|
- recv_buf,
|
|
|
- })
|
|
|
- }
|
|
|
-
|
|
|
- async fn transmit<T: Serialize>(&self, envelope: Envelope<T>) -> Result<RecvStream> {
|
|
|
- let mut guard = self.send_parts.lock().await;
|
|
|
- let (send_stream, recv_stream) = self.connection.open_bi().await?;
|
|
|
- let parts = match guard.take() {
|
|
|
- Some(mut parts) => {
|
|
|
- parts.io = send_stream;
|
|
|
- parts
|
|
|
- }
|
|
|
- None => FramedParts::new::<Envelope<T>>(send_stream, MsgEncoder::new()),
|
|
|
- };
|
|
|
- let mut sink = Framed::from_parts(parts);
|
|
|
- let result = sink.send(envelope).await;
|
|
|
- let parts = sink.into_parts();
|
|
|
- cleanup_on_err!(result, guard, parts);
|
|
|
- *guard = Some(parts);
|
|
|
- Ok(recv_stream)
|
|
|
- }
|
|
|
-
|
|
|
- async fn call<'ser, T, F>(&'ser self, msg: T, callback: F) -> Result<F::Return>
|
|
|
- where
|
|
|
- T: 'ser + CallMsg<'ser>,
|
|
|
- F: 'static + Send + DeserCallback,
|
|
|
- {
|
|
|
- let recv_stream = self.transmit(Envelope::call(msg)).await?;
|
|
|
- let mut guard = self.recv_buf.lock().await;
|
|
|
- let buffer = guard.take().unwrap();
|
|
|
- let mut callback_framed = CallbackFramed::from_parts(recv_stream, buffer);
|
|
|
- let result = callback_framed
|
|
|
- .next(ReplyCallback::new(callback))
|
|
|
- .await
|
|
|
- .ok_or_else(|| bterr!("server hung up before sending reply"));
|
|
|
- let (_, buffer) = callback_framed.into_parts();
|
|
|
- let output = cleanup_on_err!(result, guard, buffer);
|
|
|
- *guard = Some(buffer);
|
|
|
- output?
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-impl Transmitter for QuicTransmitter {
|
|
|
- fn addr(&self) -> &Arc<BlockAddr> {
|
|
|
- &self.addr
|
|
|
- }
|
|
|
-
|
|
|
- type SendFut<'ser, T> = impl 'ser + Future<Output = Result<()>> + Send
|
|
|
- where T: 'ser + SendMsg<'ser>;
|
|
|
-
|
|
|
- fn send<'ser, 'r, T: 'ser + SendMsg<'ser>>(&'ser self, msg: T) -> Self::SendFut<'ser, T> {
|
|
|
- self.transmit(Envelope::send(msg))
|
|
|
- .map(|result| result.map(|_| ()))
|
|
|
- }
|
|
|
-
|
|
|
- type CallFut<'ser, T, F> = impl 'ser + Future<Output = Result<F::Return>> + Send
|
|
|
- where
|
|
|
- Self: 'ser,
|
|
|
- T: 'ser + CallMsg<'ser>,
|
|
|
- F: 'static + Send + DeserCallback;
|
|
|
-
|
|
|
- fn call<'ser, T, F>(&'ser self, msg: T, callback: F) -> Self::CallFut<'ser, T, F>
|
|
|
- where
|
|
|
- T: 'ser + CallMsg<'ser>,
|
|
|
- F: 'static + Send + DeserCallback,
|
|
|
- {
|
|
|
- self.call(msg, callback)
|
|
|
- }
|
|
|
-}
|
|
|
+mod tls;
|
|
|
|
|
|
-struct ReplyCallback<F> {
|
|
|
- inner: F,
|
|
|
-}
|
|
|
+mod serialization;
|
|
|
+pub use serialization::DeserCallback;
|
|
|
|
|
|
-impl<F> ReplyCallback<F> {
|
|
|
- fn new(inner: F) -> Self {
|
|
|
- Self { inner }
|
|
|
- }
|
|
|
-}
|
|
|
+mod transmitter;
|
|
|
+pub use transmitter::Transmitter;
|
|
|
|
|
|
-impl<F: 'static + Send + DeserCallback> DeserCallback for ReplyCallback<F> {
|
|
|
- type Arg<'de> = ReplyEnvelope<F::Arg<'de>>;
|
|
|
- type Return = Result<F::Return>;
|
|
|
- type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send;
|
|
|
- fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
|
|
|
- async move {
|
|
|
- match arg {
|
|
|
- ReplyEnvelope::Ok(msg) => Ok(self.inner.call(msg).await),
|
|
|
- ReplyEnvelope::Err { message, os_code } => {
|
|
|
- if let Some(os_code) = os_code {
|
|
|
- let err = bterr!(io::Error::from_raw_os_error(os_code)).context(message);
|
|
|
- Err(err)
|
|
|
- } else {
|
|
|
- Err(bterr!(message))
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
+mod receiver;
|
|
|
+pub use receiver::{MsgCallback, MsgReceived, Receiver, Replier};
|