|
@@ -7,7 +7,7 @@ mod callback_framed;
|
|
|
use callback_framed::CallbackFramed;
|
|
|
pub use callback_framed::DeserCallback;
|
|
|
|
|
|
-use btlib::{bterr, crypto::Creds, BlockPath, Result, Writecap};
|
|
|
+use btlib::{bterr, crypto::Creds, error::DisplayErr, BlockPath, Result, Writecap};
|
|
|
use btserde::write_to;
|
|
|
use bytes::{BufMut, BytesMut};
|
|
|
use core::{
|
|
@@ -24,17 +24,18 @@ use std::{
|
|
|
hash::Hash,
|
|
|
marker::PhantomData,
|
|
|
net::{IpAddr, Ipv6Addr, SocketAddr},
|
|
|
- sync::Arc,
|
|
|
+ result::Result as StdResult,
|
|
|
+ sync::{Arc, Mutex as StdMutex},
|
|
|
};
|
|
|
use tokio::{
|
|
|
- runtime::Handle,
|
|
|
select,
|
|
|
- sync::{broadcast, Mutex, OwnedSemaphorePermit, Semaphore},
|
|
|
+ sync::{broadcast, Mutex},
|
|
|
+ task::{JoinError, JoinHandle},
|
|
|
};
|
|
|
use tokio_util::codec::{Encoder, Framed, FramedParts, FramedWrite};
|
|
|
|
|
|
/// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of
|
|
|
-/// the given [Writecap] of the given credentials. The returned type can be used to make
|
|
|
+/// the [Writecap] in the given credentials. The returned type can be used to make
|
|
|
/// [Transmitter]s for any path.
|
|
|
pub fn receiver<C: 'static + Creds + Send + Sync, F: 'static + MsgCallback>(
|
|
|
ip_addr: IpAddr,
|
|
@@ -197,6 +198,13 @@ pub trait Receiver {
|
|
|
|
|
|
/// Creates a [Transmitter] which is connected to the given address.
|
|
|
fn transmitter(&self, addr: Arc<BlockAddr>) -> Self::TransmitterFut<'_>;
|
|
|
+
|
|
|
+ type CompleteErr: std::error::Error;
|
|
|
+ type CompleteFut<'a>: 'a + Future<Output = StdResult<(), Self::CompleteErr>> + Send
|
|
|
+ where
|
|
|
+ Self: 'a;
|
|
|
+ /// Returns a future which completes when this [Receiver] has completed (which may be never).
|
|
|
+ fn complete(&self) -> Result<Self::CompleteFut<'_>>;
|
|
|
}
|
|
|
|
|
|
/// A type which can be used to transmit messages.
|
|
@@ -411,12 +419,10 @@ struct QuicReceiver {
|
|
|
stop_tx: broadcast::Sender<()>,
|
|
|
endpoint: Endpoint,
|
|
|
resolver: Arc<CertResolver>,
|
|
|
+ join_handle: StdMutex<Option<JoinHandle<()>>>,
|
|
|
}
|
|
|
|
|
|
impl QuicReceiver {
|
|
|
- /// This defines the maximum number of blocking tasks which can be spawned at once.
|
|
|
- const BLOCKING_LIMIT: usize = 16;
|
|
|
-
|
|
|
fn new<F: 'static + MsgCallback>(
|
|
|
recv_addr: Arc<BlockAddr>,
|
|
|
resolver: Arc<CertResolver>,
|
|
@@ -425,12 +431,13 @@ impl QuicReceiver {
|
|
|
let socket_addr = recv_addr.socket_addr()?;
|
|
|
let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
|
|
|
let (stop_tx, stop_rx) = broadcast::channel(1);
|
|
|
- tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx));
|
|
|
+ let join_handle = tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx));
|
|
|
Ok(Self {
|
|
|
recv_addr,
|
|
|
stop_tx,
|
|
|
endpoint,
|
|
|
resolver,
|
|
|
+ join_handle: StdMutex::new(Some(join_handle)),
|
|
|
})
|
|
|
}
|
|
|
|
|
@@ -439,7 +446,6 @@ impl QuicReceiver {
|
|
|
callback: F,
|
|
|
mut stop_rx: broadcast::Receiver<()>,
|
|
|
) {
|
|
|
- let blocking_permits = Arc::new(Semaphore::new(Self::BLOCKING_LIMIT));
|
|
|
loop {
|
|
|
let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv());
|
|
|
let connection = unwrap_or_continue!(connecting.await, |err| error!(
|
|
@@ -449,7 +455,6 @@ impl QuicReceiver {
|
|
|
connection,
|
|
|
callback.clone(),
|
|
|
stop_rx.resubscribe(),
|
|
|
- blocking_permits.clone(),
|
|
|
));
|
|
|
}
|
|
|
}
|
|
@@ -458,7 +463,6 @@ impl QuicReceiver {
|
|
|
connection: Connection,
|
|
|
callback: F,
|
|
|
mut stop_rx: broadcast::Receiver<()>,
|
|
|
- blocking_permits: Arc<Semaphore>,
|
|
|
) {
|
|
|
let client_path = unwrap_or_return!(
|
|
|
Self::client_path(connection.peer_identity()),
|
|
@@ -468,20 +472,14 @@ impl QuicReceiver {
|
|
|
let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv());
|
|
|
let (send_stream, recv_stream) =
|
|
|
unwrap_or_continue!(result, |err| error!("error accepting stream: {err}"));
|
|
|
- let permit = unwrap_or_continue!(blocking_permits.clone().acquire_owned().await);
|
|
|
let client_path = client_path.clone();
|
|
|
let callback = callback.clone();
|
|
|
- // spawn_blocking is used to allow the user supplied callback to to block without
|
|
|
- // disrupting the main thread pool.
|
|
|
- tokio::task::spawn_blocking(move || {
|
|
|
- Handle::current().block_on(Self::handle_message(
|
|
|
- client_path,
|
|
|
- send_stream,
|
|
|
- recv_stream,
|
|
|
- permit,
|
|
|
- callback,
|
|
|
- ))
|
|
|
- });
|
|
|
+ tokio::task::spawn(Self::handle_message(
|
|
|
+ client_path,
|
|
|
+ send_stream,
|
|
|
+ recv_stream,
|
|
|
+ callback,
|
|
|
+ ));
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -489,8 +487,6 @@ impl QuicReceiver {
|
|
|
client_path: Arc<BlockPath>,
|
|
|
send_stream: SendStream,
|
|
|
recv_stream: RecvStream,
|
|
|
- // This argument must be kept alive until this method returns.
|
|
|
- _permit: OwnedSemaphorePermit,
|
|
|
callback: F,
|
|
|
) {
|
|
|
let framed_msg = FramedWrite::new(send_stream, MsgEncoder::new());
|
|
@@ -546,6 +542,16 @@ impl Receiver for QuicReceiver {
|
|
|
QuicTransmitter::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
|
|
|
})
|
|
|
}
|
|
|
+
|
|
|
+ type CompleteErr = JoinError;
|
|
|
+ type CompleteFut<'a> = JoinHandle<()>;
|
|
|
+ fn complete(&self) -> Result<Self::CompleteFut<'_>> {
|
|
|
+ let mut guard = self.join_handle.lock().display_err()?;
|
|
|
+ let handle = guard
|
|
|
+ .take()
|
|
|
+ .ok_or_else(|| bterr!("join handle has already been taken"))?;
|
|
|
+ Ok(handle)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
macro_rules! cleanup_on_err {
|