|
@@ -1,16 +1,15 @@
|
|
|
//! Code which enables sending messages between processes in the blocktree system.
|
|
|
-use btlib::{
|
|
|
- bterr,
|
|
|
- crypto::{rand_array, Creds, CredsPriv, HashKind, Scheme, Sign, Verifier},
|
|
|
- error::BoxInIoErr,
|
|
|
- BlockPath, Principal, Result, Writecap,
|
|
|
-};
|
|
|
+
|
|
|
+mod tls;
|
|
|
+use tls::*;
|
|
|
+
|
|
|
+use btlib::{bterr, crypto::Creds, error::BoxInIoErr, BlockPath, Result, Writecap};
|
|
|
use btserde::{read_from, write_to};
|
|
|
use bytes::{BufMut, BytesMut};
|
|
|
use core::{
|
|
|
future::Future,
|
|
|
marker::Send,
|
|
|
- ops::Deref,
|
|
|
+ ops::DerefMut,
|
|
|
pin::Pin,
|
|
|
task::{Context, Poll},
|
|
|
};
|
|
@@ -20,895 +19,602 @@ use futures::{
|
|
|
stream::Stream,
|
|
|
SinkExt, StreamExt,
|
|
|
};
|
|
|
-use lazy_static::lazy_static;
|
|
|
use log::error;
|
|
|
-use quinn::{ClientConfig, Endpoint, SendStream, ServerConfig};
|
|
|
-use rustls::{
|
|
|
- client::{HandshakeSignatureValid, ResolvesClientCert},
|
|
|
- internal::msgs::base::PayloadU16,
|
|
|
- server::{ClientCertVerified, ResolvesServerCert},
|
|
|
- sign::{CertifiedKey, SigningKey},
|
|
|
- Certificate, ConfigBuilder, ConfigSide, SignatureAlgorithm, SignatureScheme, WantsCipherSuites,
|
|
|
- WantsVerifier,
|
|
|
-};
|
|
|
+use quinn::{Connection, Endpoint, RecvStream, SendStream};
|
|
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
|
|
use std::{
|
|
|
+ any::Any,
|
|
|
collections::hash_map::DefaultHasher,
|
|
|
hash::{Hash, Hasher},
|
|
|
marker::PhantomData,
|
|
|
net::{IpAddr, SocketAddr},
|
|
|
- path::PathBuf,
|
|
|
sync::Arc,
|
|
|
};
|
|
|
-use tokio::sync::{
|
|
|
- broadcast::{self, error::TryRecvError},
|
|
|
- mpsc,
|
|
|
+use tokio::{
|
|
|
+ select,
|
|
|
+ sync::{broadcast, mpsc, Mutex},
|
|
|
};
|
|
|
use tokio_stream::wrappers::ReceiverStream;
|
|
|
-use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
|
|
|
-use zerocopy::FromBytes;
|
|
|
-
|
|
|
-pub use private::*;
|
|
|
-
|
|
|
-mod private {
|
|
|
- use super::*;
|
|
|
-
|
|
|
- /// Returns a [Router] which can be used to make a [Receiver] for the given path and
|
|
|
- /// [Sender] instances for any path.
|
|
|
- pub fn router<C: 'static + Creds + Send + Sync>(
|
|
|
- ip_addr: IpAddr,
|
|
|
- creds: Arc<C>,
|
|
|
- ) -> Result<impl Router> {
|
|
|
- let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
|
|
|
- let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
|
|
|
- QuicRouter::new(addr, Arc::new(CertResolver::new(creds)?))
|
|
|
- }
|
|
|
-
|
|
|
- lazy_static! {
|
|
|
- /// The default directory in which to place blocktree sockets.
|
|
|
- static ref SOCK_DIR: PathBuf = {
|
|
|
- let mut path: PathBuf = std::env::var("HOME").unwrap().into();
|
|
|
- path.push(".btmsg");
|
|
|
- path
|
|
|
- };
|
|
|
- }
|
|
|
+use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts, FramedRead, FramedWrite};
|
|
|
+
|
|
|
+/// Returns a [Router] which can be used to make a [Receiver] for the given path and
|
|
|
+/// [Sender] instances for any path.
|
|
|
+pub fn router<C: 'static + Creds + Send + Sync>(
|
|
|
+ ip_addr: IpAddr,
|
|
|
+ creds: Arc<C>,
|
|
|
+) -> Result<impl Router> {
|
|
|
+ let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
|
|
|
+ let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
|
|
|
+ QuicRouter::new(addr, Arc::new(CertResolver::new(creds)?))
|
|
|
+}
|
|
|
|
|
|
- fn common_config<Side: ConfigSide>(
|
|
|
- builder: ConfigBuilder<Side, WantsCipherSuites>,
|
|
|
- ) -> Result<ConfigBuilder<Side, WantsVerifier>> {
|
|
|
- builder
|
|
|
- .with_cipher_suites(&[rustls::cipher_suite::TLS13_AES_128_GCM_SHA256])
|
|
|
- .with_kx_groups(&[&rustls::kx_group::SECP256R1])
|
|
|
- .with_protocol_versions(&[&rustls::version::TLS13])
|
|
|
- .map_err(|err| err.into())
|
|
|
- }
|
|
|
+/// An address which identifies a block on the network. An instance of this struct can be
|
|
|
+/// used to get a socket address for the block this address refers to.
|
|
|
+#[derive(PartialEq, Eq, Hash, Clone, Debug)]
|
|
|
+pub struct BlockAddr {
|
|
|
+ ip_addr: IpAddr,
|
|
|
+ path: Arc<BlockPath>,
|
|
|
+}
|
|
|
|
|
|
- fn server_config(resolver: Arc<CertResolver>) -> Result<ServerConfig> {
|
|
|
- let server_config = common_config(rustls::ServerConfig::builder())?
|
|
|
- .with_client_cert_verifier(Arc::new(ClientCertVerifier))
|
|
|
- .with_cert_resolver(resolver);
|
|
|
- Ok(ServerConfig::with_crypto(Arc::new(server_config)))
|
|
|
+impl BlockAddr {
|
|
|
+ pub fn new(ip_addr: IpAddr, path: Arc<BlockPath>) -> Self {
|
|
|
+ Self { ip_addr, path }
|
|
|
}
|
|
|
|
|
|
- fn client_config(
|
|
|
- server_path: Arc<BlockPath>,
|
|
|
- resolver: Arc<CertResolver>,
|
|
|
- ) -> Result<ClientConfig> {
|
|
|
- let client_config = common_config(rustls::ClientConfig::builder())?
|
|
|
- .with_custom_certificate_verifier(Arc::new(ServerCertVerifier::new(server_path)))
|
|
|
- .with_client_cert_resolver(resolver);
|
|
|
- Ok(ClientConfig::new(Arc::new(client_config)))
|
|
|
- }
|
|
|
-
|
|
|
- fn to_cert_err(err: btlib::Error) -> rustls::Error {
|
|
|
- rustls::Error::InvalidCertificateData(err.to_string())
|
|
|
- }
|
|
|
-
|
|
|
- fn verify_tls13_signature(
|
|
|
- message: &[u8],
|
|
|
- cert: &Certificate,
|
|
|
- dss: &rustls::DigitallySignedStruct,
|
|
|
- ) -> std::result::Result<(), rustls::Error> {
|
|
|
- let (_, subject_key) = Writecap::from_cert_chain(cert, &[]).map_err(to_cert_err)?;
|
|
|
- subject_key
|
|
|
- .verify(std::iter::once(message), dss.signature())
|
|
|
- .map_err(|_| rustls::Error::InvalidCertificateSignature)?;
|
|
|
- Ok(())
|
|
|
+ pub fn ip_addr(&self) -> IpAddr {
|
|
|
+ self.ip_addr
|
|
|
}
|
|
|
|
|
|
- /// Verifier for the certificate chain presented by the server.
|
|
|
- struct ServerCertVerifier {
|
|
|
- server_path: Arc<BlockPath>,
|
|
|
+ pub fn path(&self) -> &BlockPath {
|
|
|
+ self.path.as_ref()
|
|
|
}
|
|
|
|
|
|
- impl ServerCertVerifier {
|
|
|
- fn new(server_path: Arc<BlockPath>) -> Self {
|
|
|
- Self { server_path }
|
|
|
- }
|
|
|
+ fn port(&self) -> u16 {
|
|
|
+ // TODO: We should probably choose a stable hasher, as the standard library devs could
|
|
|
+ // change `DefaultHasher` at any time.
|
|
|
+ let mut hasher = DefaultHasher::new();
|
|
|
+ self.path.hash(&mut hasher);
|
|
|
+ let hash = hasher.finish();
|
|
|
+ // We compute a port in the dynamic range [49152, 65535] as defined by RFC 6335.
|
|
|
+ const NUM_RES_PORTS: u16 = 49153;
|
|
|
+ const PORTS_AVAIL: u64 = (u16::MAX - NUM_RES_PORTS) as u64;
|
|
|
+ NUM_RES_PORTS + (hash % PORTS_AVAIL) as u16
|
|
|
}
|
|
|
|
|
|
- impl rustls::client::ServerCertVerifier for ServerCertVerifier {
|
|
|
- fn verify_server_cert(
|
|
|
- &self,
|
|
|
- end_entity: &Certificate,
|
|
|
- intermediates: &[Certificate],
|
|
|
- _server_name: &rustls::ServerName,
|
|
|
- _scts: &mut dyn Iterator<Item = &[u8]>,
|
|
|
- _ocsp_response: &[u8],
|
|
|
- _now: std::time::SystemTime,
|
|
|
- ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
|
|
|
- let (writecap, ..) =
|
|
|
- Writecap::from_cert_chain(end_entity, intermediates).map_err(to_cert_err)?;
|
|
|
- let path = writecap.bind_path();
|
|
|
- if &path != self.server_path.as_ref() {
|
|
|
- return Err(rustls::Error::InvalidCertificateData(format!(
|
|
|
- "expected writecap with path '{}' got writecap with path '{path}'",
|
|
|
- self.server_path
|
|
|
- )));
|
|
|
- }
|
|
|
- writecap.assert_valid_for(&path).map_err(to_cert_err)?;
|
|
|
- Ok(rustls::client::ServerCertVerified::assertion())
|
|
|
- }
|
|
|
-
|
|
|
- fn verify_tls13_signature(
|
|
|
- &self,
|
|
|
- message: &[u8],
|
|
|
- cert: &Certificate,
|
|
|
- dss: &rustls::DigitallySignedStruct,
|
|
|
- ) -> std::result::Result<rustls::client::HandshakeSignatureValid, rustls::Error> {
|
|
|
- verify_tls13_signature(message, cert, dss)?;
|
|
|
- Ok(HandshakeSignatureValid::assertion())
|
|
|
- }
|
|
|
+ /// Returns the socket address of the block this instance refers to.
|
|
|
+ pub fn socket_addr(&self) -> SocketAddr {
|
|
|
+ SocketAddr::new(self.ip_addr, self.port())
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- /// Verifier for the certificate chain presented by the client.
|
|
|
- struct ClientCertVerifier;
|
|
|
-
|
|
|
- impl rustls::server::ClientCertVerifier for ClientCertVerifier {
|
|
|
- fn verify_client_cert(
|
|
|
- &self,
|
|
|
- end_entity: &Certificate,
|
|
|
- intermediates: &[Certificate],
|
|
|
- _now: std::time::SystemTime,
|
|
|
- ) -> std::result::Result<rustls::server::ClientCertVerified, rustls::Error> {
|
|
|
- let (writecap, ..) =
|
|
|
- Writecap::from_cert_chain(end_entity, intermediates).map_err(to_cert_err)?;
|
|
|
- writecap
|
|
|
- .assert_valid_for(writecap.path())
|
|
|
- .map_err(to_cert_err)?;
|
|
|
- Ok(ClientCertVerified::assertion())
|
|
|
- }
|
|
|
-
|
|
|
- fn client_auth_root_subjects(&self) -> Option<rustls::DistinguishedNames> {
|
|
|
- let der = match Principal::default().to_name_der() {
|
|
|
- Ok(der) => der,
|
|
|
- Err(err) => {
|
|
|
- error!("failed to create distinguished name from root principal: {err}");
|
|
|
- return None;
|
|
|
- }
|
|
|
- };
|
|
|
- Some(vec![PayloadU16(der)])
|
|
|
- }
|
|
|
+/// Trait for messages which can be transmitted using the call method.
|
|
|
+pub trait CallTx: Serialize + Send {
|
|
|
+ type Reply: 'static + DeserializeOwned + Send;
|
|
|
+}
|
|
|
|
|
|
- fn verify_tls13_signature(
|
|
|
- &self,
|
|
|
- message: &[u8],
|
|
|
- cert: &Certificate,
|
|
|
- dss: &rustls::DigitallySignedStruct,
|
|
|
- ) -> std::result::Result<rustls::client::HandshakeSignatureValid, rustls::Error> {
|
|
|
- verify_tls13_signature(message, cert, dss)?;
|
|
|
- Ok(rustls::client::HandshakeSignatureValid::assertion())
|
|
|
- }
|
|
|
- }
|
|
|
+/// Trait for messages which can be transmitted using the send method.
|
|
|
+/// Types which implement this trait should choose `()` as their reply type.
|
|
|
+pub trait SendTx: CallTx {}
|
|
|
|
|
|
- struct CertResolver {
|
|
|
- cert_key: Arc<CertifiedKey>,
|
|
|
- }
|
|
|
+/// Trait for messages which are received from the call method.
|
|
|
+pub trait CallRx: 'static + DeserializeOwned + Send {
|
|
|
+ type Reply<'a>: 'a + Serialize + Send;
|
|
|
+}
|
|
|
|
|
|
- impl CertResolver {
|
|
|
- fn new<C: Creds + Send + Sync + 'static>(creds: Arc<C>) -> Result<Self> {
|
|
|
- let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
|
|
|
- let chain = writecap.to_cert_chain(creds.public_sign())?;
|
|
|
- let mut certs = Vec::with_capacity(chain.len());
|
|
|
- for cert in chain {
|
|
|
- certs.push(Certificate(cert))
|
|
|
- }
|
|
|
- let key = Arc::new(CredRef::new(creds));
|
|
|
- let cert_key = Arc::new(CertifiedKey {
|
|
|
- cert: certs,
|
|
|
- key,
|
|
|
- ocsp: None,
|
|
|
- sct_list: None,
|
|
|
- });
|
|
|
- Ok(Self { cert_key })
|
|
|
- }
|
|
|
- }
|
|
|
+/// Trait for messages which are received from the send method.
|
|
|
+/// Types which implement this trait should choose `()` as their reply type.
|
|
|
+pub trait SendRx: CallRx {}
|
|
|
+
|
|
|
+/// Indicates whether a message was sent using `call` or `send`.
|
|
|
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
|
|
|
+enum MsgKind {
|
|
|
+ /// This message expects exactly one reply.
|
|
|
+ Call,
|
|
|
+ /// This message expects exactly zero replies.
|
|
|
+ Send,
|
|
|
+}
|
|
|
|
|
|
- impl ResolvesClientCert for CertResolver {
|
|
|
- fn resolve(
|
|
|
- &self,
|
|
|
- _acceptable_issuers: &[&[u8]],
|
|
|
- _sigschemes: &[rustls::SignatureScheme],
|
|
|
- ) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
|
|
- Some(self.cert_key.clone())
|
|
|
- }
|
|
|
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
|
|
|
+struct Envelope<T> {
|
|
|
+ kind: MsgKind,
|
|
|
+ msg: T,
|
|
|
+}
|
|
|
|
|
|
- fn has_certs(&self) -> bool {
|
|
|
- true
|
|
|
+impl<T> Envelope<T> {
|
|
|
+ fn send(msg: T) -> Self {
|
|
|
+ Self {
|
|
|
+ msg,
|
|
|
+ kind: MsgKind::Send,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- impl ResolvesServerCert for CertResolver {
|
|
|
- fn resolve(
|
|
|
- &self,
|
|
|
- _client_hello: rustls::server::ClientHello,
|
|
|
- ) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
|
|
- Some(self.cert_key.clone())
|
|
|
+ fn call(msg: T) -> Self {
|
|
|
+ Self {
|
|
|
+ msg,
|
|
|
+ kind: MsgKind::Call,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- trait SignExt {
|
|
|
- fn as_signature_scheme(&self) -> rustls::SignatureScheme;
|
|
|
- fn as_signature_algorithm(&self) -> rustls::SignatureAlgorithm;
|
|
|
+ fn msg(&self) -> &T {
|
|
|
+ &self.msg
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- impl SignExt for Sign {
|
|
|
- fn as_signature_scheme(&self) -> SignatureScheme {
|
|
|
- match self {
|
|
|
- Self::RsaSsaPss(scheme) => match scheme.hash_kind() {
|
|
|
- HashKind::Sha2_256 => SignatureScheme::RSA_PSS_SHA256,
|
|
|
- HashKind::Sha2_512 => SignatureScheme::RSA_PSS_SHA512,
|
|
|
- },
|
|
|
- }
|
|
|
- }
|
|
|
+/// A message tagged with the block path that it was sent from.
|
|
|
+pub struct MsgReceived<T> {
|
|
|
+ from: Arc<BlockPath>,
|
|
|
+ msg: Envelope<T>,
|
|
|
+ replier: Replier,
|
|
|
+}
|
|
|
|
|
|
- fn as_signature_algorithm(&self) -> SignatureAlgorithm {
|
|
|
- match self {
|
|
|
- Self::RsaSsaPss(..) => SignatureAlgorithm::RSA,
|
|
|
- }
|
|
|
- }
|
|
|
+impl<T> MsgReceived<T> {
|
|
|
+ fn new(from: Arc<BlockPath>, msg: Envelope<T>, replier: Replier) -> Self {
|
|
|
+ Self { from, msg, replier }
|
|
|
}
|
|
|
|
|
|
- /// A new type around `Arc<C>` which allows rustls' traits to be implemented.
|
|
|
- struct CredRef<C> {
|
|
|
- creds: Arc<C>,
|
|
|
+ /// The path from which this message was received.
|
|
|
+ pub fn from(&self) -> &Arc<BlockPath> {
|
|
|
+ &self.from
|
|
|
}
|
|
|
|
|
|
- impl<C> CredRef<C> {
|
|
|
- fn new(creds: Arc<C>) -> Self {
|
|
|
- Self { creds }
|
|
|
- }
|
|
|
+ /// Payload contained in this message.
|
|
|
+ pub fn body(&self) -> &T {
|
|
|
+ self.msg.msg()
|
|
|
}
|
|
|
|
|
|
- impl<C> Deref for CredRef<C> {
|
|
|
- type Target = C;
|
|
|
- fn deref(&self) -> &Self::Target {
|
|
|
- &self.creds
|
|
|
- }
|
|
|
+ /// Returns true if and only if this messages needs to be replied to.
|
|
|
+ pub fn needs_reply(&self) -> bool {
|
|
|
+ self.replier.parts.is_some()
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- impl<C: CredsPriv + Send + Sync + 'static> SigningKey for CredRef<C> {
|
|
|
- fn choose_scheme(
|
|
|
- &self,
|
|
|
- offered: &[rustls::SignatureScheme],
|
|
|
- ) -> Option<Box<dyn rustls::sign::Signer>> {
|
|
|
- if offered.contains(&self.sign_kind().as_signature_scheme()) {
|
|
|
- Some(Box::new(Self::new(self.creds.clone())))
|
|
|
- } else {
|
|
|
- None
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- fn algorithm(&self) -> rustls::SignatureAlgorithm {
|
|
|
- self.sign_kind().as_signature_algorithm()
|
|
|
- }
|
|
|
+impl<T: CallRx> MsgReceived<T> {
|
|
|
+ /// Replies to this message. This method must be called exactly once for [CallRx] messages
|
|
|
+ /// and exactly zero times for [SendRx] messages. The `needs_reply` method can be called
|
|
|
+ /// on this instance to determine if a reply still needs to be sent.
|
|
|
+ pub async fn reply<'a>(&mut self, reply: T::Reply<'a>) -> Result<()> {
|
|
|
+ self.replier.reply(reply).await
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- impl<C: CredsPriv + Send + Sync> rustls::sign::Signer for CredRef<C> {
|
|
|
- fn sign(&self, message: &[u8]) -> std::result::Result<Vec<u8>, rustls::Error> {
|
|
|
- self.creds
|
|
|
- .sign(std::iter::once(message))
|
|
|
- .map(|sig| sig.take_data())
|
|
|
- .map_err(|err| rustls::Error::General(err.to_string()))
|
|
|
- }
|
|
|
+/// A type which can be used to receive messages.
|
|
|
+pub trait Receiver<T>: Stream<Item = Result<MsgReceived<T>>> {
|
|
|
+ /// The address at which messages will be received.
|
|
|
+ fn addr(&self) -> &BlockAddr;
|
|
|
+}
|
|
|
|
|
|
- fn scheme(&self) -> rustls::SignatureScheme {
|
|
|
- self.sign_kind().as_signature_scheme()
|
|
|
- }
|
|
|
- }
|
|
|
+/// A type which can be used to transmit messages.
|
|
|
+pub trait Transmitter {
|
|
|
+ type SendFut<'a, T>: 'a + Future<Output = Result<()>> + Send
|
|
|
+ where
|
|
|
+ Self: 'a,
|
|
|
+ T: 'a + Serialize + Send;
|
|
|
|
|
|
- #[derive(PartialEq, Eq, Hash, Clone, Debug)]
|
|
|
- pub struct BlockAddr {
|
|
|
- ip_addr: IpAddr,
|
|
|
- path: Arc<BlockPath>,
|
|
|
- }
|
|
|
+ /// Transmit a message to the connected [Receiver] without waiting for a reply.
|
|
|
+ fn send<'a, T: 'a + SendTx>(&'a mut self, msg: T) -> Self::SendFut<'a, T>;
|
|
|
|
|
|
- impl BlockAddr {
|
|
|
- pub fn new(ip_addr: IpAddr, path: Arc<BlockPath>) -> Self {
|
|
|
- Self { ip_addr, path }
|
|
|
- }
|
|
|
+ type CallFut<'a, T>: 'a + Future<Output = Result<T::Reply>> + Send
|
|
|
+ where
|
|
|
+ Self: 'a,
|
|
|
+ T: 'a + CallTx;
|
|
|
|
|
|
- pub fn ip_addr(&self) -> IpAddr {
|
|
|
- self.ip_addr
|
|
|
- }
|
|
|
+ /// Transmit a message to the connected [Receiver] and wait for a reply.
|
|
|
+ fn call<'a, T>(&'a mut self, msg: T) -> Self::CallFut<'a, T>
|
|
|
+ where
|
|
|
+ T: 'a + CallTx;
|
|
|
|
|
|
- pub fn path(&self) -> &BlockPath {
|
|
|
- self.path.as_ref()
|
|
|
- }
|
|
|
+ type FinishFut: Future<Output = Result<()>> + Send;
|
|
|
|
|
|
- fn port(&self) -> u16 {
|
|
|
- let mut hasher = DefaultHasher::new();
|
|
|
- self.path.hash(&mut hasher);
|
|
|
- let hash = hasher.finish();
|
|
|
- // We compute a port in the dynamic range [49152, 65535] as defined by RFC 6335.
|
|
|
- const NUM_RES_PORTS: u16 = 49153;
|
|
|
- const PORTS_AVAIL: u64 = (u16::MAX - NUM_RES_PORTS) as u64;
|
|
|
- NUM_RES_PORTS + (hash % PORTS_AVAIL) as u16
|
|
|
- }
|
|
|
+ /// Finish any ongoing transmissions and close the connection to the [Receiver].
|
|
|
+ fn finish(self) -> Self::FinishFut;
|
|
|
|
|
|
- pub fn socket_addr(&self) -> SocketAddr {
|
|
|
- let port = self.port();
|
|
|
- SocketAddr::new(self.ip_addr, port)
|
|
|
- }
|
|
|
- }
|
|
|
+ /// Returns the address that this instance is transmitting to.
|
|
|
+ fn addr(&self) -> &BlockAddr;
|
|
|
+}
|
|
|
|
|
|
- /// Generates and returns a new message ID.
|
|
|
- fn rand_msg_id() -> Result<u128> {
|
|
|
- const LEN: usize = std::mem::size_of::<u128>();
|
|
|
- let bytes = rand_array::<LEN>()?;
|
|
|
- let option = u128::read_from(bytes.as_slice());
|
|
|
- // Safety: because LEN == size_of::<u128>(), read_from should have returned Some.
|
|
|
- Ok(option.unwrap())
|
|
|
- }
|
|
|
+/// Trait for types which can create [Transmitter]s and [Receiver]s.
|
|
|
+pub trait Router {
|
|
|
+ type Transmitter: Transmitter + Send;
|
|
|
+ type TransmitterFut<'a>: 'a + Future<Output = Result<Self::Transmitter>> + Send
|
|
|
+ where
|
|
|
+ Self: 'a;
|
|
|
|
|
|
- #[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
|
|
|
- pub struct Msg<T> {
|
|
|
- pub id: u128,
|
|
|
- pub body: T,
|
|
|
- }
|
|
|
+ /// Creates a [Transmitter] which is connected to the given address.
|
|
|
+ fn transmitter(&self, addr: Arc<BlockAddr>) -> Self::TransmitterFut<'_>;
|
|
|
|
|
|
- impl<T> Msg<T> {
|
|
|
- pub fn new(id: u128, body: T) -> Self {
|
|
|
- Self { id, body }
|
|
|
- }
|
|
|
+ type Receiver<T: CallRx>: Receiver<T> + Send + Unpin;
|
|
|
+ type ReceiverFut<'a, T>: 'a + Future<Output = Result<Self::Receiver<T>>> + Send
|
|
|
+ where
|
|
|
+ T: CallRx,
|
|
|
+ Self: 'a;
|
|
|
|
|
|
- pub fn with_rand_id(body: T) -> Result<Self> {
|
|
|
- Ok(Self {
|
|
|
- id: rand_msg_id()?,
|
|
|
- body,
|
|
|
- })
|
|
|
- }
|
|
|
- }
|
|
|
+ /// Creates a [Receiver] which will receive message at the address of this [Router].
|
|
|
+ fn receiver<T: CallRx>(&self) -> Self::ReceiverFut<'_, T>;
|
|
|
+}
|
|
|
|
|
|
- /// A message tagged with the block path that it was sent from.
|
|
|
- pub struct MsgReceived<T> {
|
|
|
- pub from: Arc<BlockPath>,
|
|
|
- pub msg: Msg<T>,
|
|
|
- }
|
|
|
+/// Encodes messages using [btserde].
|
|
|
+struct MsgEncoder;
|
|
|
|
|
|
- impl<T> MsgReceived<T> {
|
|
|
- fn new(from: Arc<BlockPath>, msg: Msg<T>) -> Self {
|
|
|
- Self { from, msg }
|
|
|
- }
|
|
|
+impl MsgEncoder {
|
|
|
+ fn new() -> Self {
|
|
|
+ Self
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- /// A type which can be used to send messages.
|
|
|
- /// Once the "Permit impl Trait in type aliases" https://github.com/rust-lang/rust/issues/63063
|
|
|
- /// feature lands the future types in this trait should be rewritten to use it.
|
|
|
- pub trait Sender {
|
|
|
- type SendFut<'a, T>: 'a + Future<Output = Result<()>> + Send
|
|
|
- where
|
|
|
- Self: 'a,
|
|
|
- T: 'a + Serialize + Send;
|
|
|
-
|
|
|
- fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: Msg<T>) -> Self::SendFut<'a, T>;
|
|
|
-
|
|
|
- type FinishFut: Future<Output = Result<()>> + Send;
|
|
|
-
|
|
|
- fn finish(self) -> Self::FinishFut;
|
|
|
+impl<T: Serialize> Encoder<T> for MsgEncoder {
|
|
|
+ type Error = btlib::Error;
|
|
|
+
|
|
|
+ fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<()> {
|
|
|
+ const U64_LEN: usize = std::mem::size_of::<u64>();
|
|
|
+ let payload = dst.split_off(U64_LEN);
|
|
|
+ let mut writer = payload.writer();
|
|
|
+ write_to(&item, &mut writer)?;
|
|
|
+ let payload = writer.into_inner();
|
|
|
+ let payload_len = payload.len() as u64;
|
|
|
+ let mut writer = dst.writer();
|
|
|
+ write_to(&payload_len, &mut writer)?;
|
|
|
+ let dst = writer.into_inner();
|
|
|
+ dst.unsplit(payload);
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
- fn addr(&self) -> &BlockAddr;
|
|
|
+/// Decodes messages using [btserde].
|
|
|
+struct MsgDecoder<T>(PhantomData<T>);
|
|
|
|
|
|
- fn send_with_rand_id<'a, T: 'a + Serialize + Send>(
|
|
|
- &'a mut self,
|
|
|
- body: T,
|
|
|
- ) -> Self::SendFut<'a, T> {
|
|
|
- let msg = Msg::with_rand_id(body).unwrap();
|
|
|
- self.send(msg)
|
|
|
- }
|
|
|
+impl<T> MsgDecoder<T> {
|
|
|
+ fn new() -> Self {
|
|
|
+ Self(PhantomData)
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- /// A type which can be used to receive messages.
|
|
|
- pub trait Receiver<T>: Stream<Item = Result<MsgReceived<T>>> {
|
|
|
- fn addr(&self) -> &BlockAddr;
|
|
|
+impl<T: DeserializeOwned> Decoder for MsgDecoder<T> {
|
|
|
+ type Item = T;
|
|
|
+ type Error = btlib::Error;
|
|
|
+
|
|
|
+ fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
|
|
|
+ let mut slice: &[u8] = src.as_ref();
|
|
|
+ let payload_len: u64 = match read_from(&mut slice) {
|
|
|
+ Ok(payload_len) => payload_len,
|
|
|
+ Err(err) => {
|
|
|
+ return match err {
|
|
|
+ btserde::Error::Eof => Ok(None),
|
|
|
+ btserde::Error::Io(ref io_err) => match io_err.kind() {
|
|
|
+ std::io::ErrorKind::UnexpectedEof => Ok(None),
|
|
|
+ _ => Err(err.into()),
|
|
|
+ },
|
|
|
+ _ => Err(err.into()),
|
|
|
+ }
|
|
|
+ }
|
|
|
+ };
|
|
|
+ let payload_len: usize = payload_len.try_into().box_err()?;
|
|
|
+ if slice.len() < payload_len {
|
|
|
+ src.reserve(payload_len - slice.len());
|
|
|
+ return Ok(None);
|
|
|
+ }
|
|
|
+ let msg = read_from(&mut slice)?;
|
|
|
+ // Consume all the bytes that have been read out of the buffer.
|
|
|
+ let _ = src.split_to(std::mem::size_of::<u64>() + payload_len);
|
|
|
+ Ok(Some(msg))
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- pub trait Router {
|
|
|
- type Sender: Sender + Send;
|
|
|
- type SenderFut<'a>: 'a + Future<Output = Result<Self::Sender>> + Send
|
|
|
- where
|
|
|
- Self: 'a;
|
|
|
+struct QuicRouter {
|
|
|
+ recv_addr: Arc<BlockAddr>,
|
|
|
+ resolver: Arc<CertResolver>,
|
|
|
+ endpoint: Endpoint,
|
|
|
+}
|
|
|
|
|
|
- fn sender(&self, addr: Arc<BlockAddr>) -> Self::SenderFut<'_>;
|
|
|
+impl QuicRouter {
|
|
|
+ fn new(recv_addr: Arc<BlockAddr>, resolver: Arc<CertResolver>) -> Result<Self> {
|
|
|
+ let socket_addr = recv_addr.socket_addr();
|
|
|
+ let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
|
|
|
+ Ok(Self {
|
|
|
+ endpoint,
|
|
|
+ resolver,
|
|
|
+ recv_addr,
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
- type Receiver<T: 'static + DeserializeOwned + Send>: Receiver<T> + Send + Unpin;
|
|
|
- type ReceiverFut<'a, T>: 'a + Future<Output = Result<Self::Receiver<T>>> + Send
|
|
|
- where
|
|
|
- T: 'static + DeserializeOwned + Send + Unpin,
|
|
|
- Self: 'a;
|
|
|
+impl Router for QuicRouter {
|
|
|
+ type Receiver<T: CallRx> = QuicReceiver<T>;
|
|
|
+ type ReceiverFut<'a, T: CallRx> = Ready<Result<QuicReceiver<T>>>;
|
|
|
|
|
|
- fn receiver<T: 'static + DeserializeOwned + Send + Unpin>(
|
|
|
- &self,
|
|
|
- ) -> Self::ReceiverFut<'_, T>;
|
|
|
+ fn receiver<T: CallRx>(&self) -> Self::ReceiverFut<'_, T> {
|
|
|
+ ready(QuicReceiver::new(
|
|
|
+ self.endpoint.clone(),
|
|
|
+ self.recv_addr.clone(),
|
|
|
+ ))
|
|
|
}
|
|
|
|
|
|
- /// Encodes and decodes messages using [btserde].
|
|
|
- struct MsgEncoder;
|
|
|
+ type Transmitter = QuicSender;
|
|
|
+ type TransmitterFut<'a> = Pin<Box<dyn 'a + Future<Output = Result<QuicSender>> + Send>>;
|
|
|
|
|
|
- impl MsgEncoder {
|
|
|
- fn new() -> Self {
|
|
|
- Self
|
|
|
- }
|
|
|
+ fn transmitter(&self, addr: Arc<BlockAddr>) -> Self::TransmitterFut<'_> {
|
|
|
+ Box::pin(async {
|
|
|
+ QuicSender::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
|
|
|
+ })
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- impl<T: Serialize> Encoder<Msg<T>> for MsgEncoder {
|
|
|
- type Error = btlib::Error;
|
|
|
-
|
|
|
- fn encode(&mut self, item: Msg<T>, dst: &mut BytesMut) -> Result<()> {
|
|
|
- const U64_LEN: usize = std::mem::size_of::<u64>();
|
|
|
- let payload = dst.split_off(U64_LEN);
|
|
|
- let mut writer = payload.writer();
|
|
|
- write_to(&item, &mut writer)?;
|
|
|
- let payload = writer.into_inner();
|
|
|
- let payload_len = payload.len() as u64;
|
|
|
- let mut writer = dst.writer();
|
|
|
- write_to(&payload_len, &mut writer)?;
|
|
|
- let dst = writer.into_inner();
|
|
|
- dst.unsplit(payload);
|
|
|
- Ok(())
|
|
|
- }
|
|
|
- }
|
|
|
+type SharedFrameParts = Arc<Mutex<Option<FramedParts<SendStream, MsgEncoder>>>>;
|
|
|
|
|
|
- struct MsgDecoder<T>(PhantomData<T>);
|
|
|
+#[derive(Clone)]
|
|
|
+struct Replier {
|
|
|
+ parts: Option<SharedFrameParts>,
|
|
|
+}
|
|
|
|
|
|
- impl<T> MsgDecoder<T> {
|
|
|
- fn new() -> Self {
|
|
|
- Self(PhantomData)
|
|
|
+impl Replier {
|
|
|
+ fn new(send_stream: SendStream) -> Self {
|
|
|
+ let parts = FramedParts::new::<()>(send_stream, MsgEncoder::new());
|
|
|
+ let parts = Some(Arc::new(Mutex::new(Some(parts))));
|
|
|
+ Self { parts }
|
|
|
+ }
|
|
|
+
|
|
|
+ fn empty() -> Self {
|
|
|
+ Self { parts: None }
|
|
|
+ }
|
|
|
+
|
|
|
+ async fn reply<T: Serialize + Send>(&mut self, reply: T) -> Result<()> {
|
|
|
+ let parts = self
|
|
|
+ .parts
|
|
|
+ .take()
|
|
|
+ .ok_or_else(|| bterr!("reply has already been sent"))?;
|
|
|
+ let result = {
|
|
|
+ let mut guard = parts.lock().await;
|
|
|
+ // We must ensure the parts are put back before we leave this block.
|
|
|
+ let parts = guard.take().unwrap();
|
|
|
+ let mut stream = Framed::from_parts(parts);
|
|
|
+ let result = stream.send(reply).await;
|
|
|
+ *guard = Some(stream.into_parts());
|
|
|
+ result
|
|
|
+ };
|
|
|
+ if result.is_err() {
|
|
|
+ // If the result is an error put back the parts so the caller may try again.
|
|
|
+ self.parts = Some(parts);
|
|
|
}
|
|
|
+ result
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- impl<T: DeserializeOwned> Decoder for MsgDecoder<T> {
|
|
|
- type Item = Msg<T>;
|
|
|
- type Error = btlib::Error;
|
|
|
-
|
|
|
- fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
|
|
|
- let mut slice: &[u8] = src.as_ref();
|
|
|
- let payload_len: u64 = match read_from(&mut slice) {
|
|
|
- Ok(payload_len) => payload_len,
|
|
|
- Err(err) => {
|
|
|
- if let btserde::Error::Eof = err {
|
|
|
- return Ok(None);
|
|
|
- }
|
|
|
- return Err(err.into());
|
|
|
- }
|
|
|
- };
|
|
|
- let payload_len: usize = payload_len.try_into().box_err()?;
|
|
|
- if slice.len() < payload_len {
|
|
|
- src.reserve(payload_len - slice.len());
|
|
|
- return Ok(None);
|
|
|
+macro_rules! handle_err {
|
|
|
+ ($result:expr, $on_err:expr, $control_flow:expr) => {
|
|
|
+ match $result {
|
|
|
+ Ok(inner) => inner,
|
|
|
+ Err(err) => {
|
|
|
+ $on_err(err);
|
|
|
+ $control_flow;
|
|
|
}
|
|
|
- let msg = read_from(&mut slice)?;
|
|
|
- // Consume all the bytes that have been read out of the buffer.
|
|
|
- let _ = src.split_to(std::mem::size_of::<u64>() + payload_len);
|
|
|
- Ok(Some(msg))
|
|
|
}
|
|
|
- }
|
|
|
-
|
|
|
- struct QuicRouter {
|
|
|
- recv_addr: Arc<BlockAddr>,
|
|
|
- resolver: Arc<CertResolver>,
|
|
|
- endpoint: Endpoint,
|
|
|
- }
|
|
|
+ };
|
|
|
+}
|
|
|
|
|
|
- impl QuicRouter {
|
|
|
- fn new(recv_addr: Arc<BlockAddr>, resolver: Arc<CertResolver>) -> Result<Self> {
|
|
|
- let socket_addr = recv_addr.socket_addr();
|
|
|
- let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
|
|
|
- Ok(Self {
|
|
|
- endpoint,
|
|
|
- resolver,
|
|
|
- recv_addr,
|
|
|
- })
|
|
|
- }
|
|
|
- }
|
|
|
+/// Unwraps the given result, or if the result is an error, returns from the enclosing function.
|
|
|
+macro_rules! unwrap_or_return {
|
|
|
+ ($result:expr, $on_err:expr) => {
|
|
|
+ handle_err!($result, $on_err, return)
|
|
|
+ };
|
|
|
+ ($result:expr) => {
|
|
|
+ unwrap_or_return!($result, |err| error!("{err}"))
|
|
|
+ };
|
|
|
+}
|
|
|
|
|
|
- impl Router for QuicRouter {
|
|
|
- type Receiver<T: 'static + DeserializeOwned + Send> = QuicReceiver<T>;
|
|
|
- type ReceiverFut<'a, T: 'static + DeserializeOwned + Send + Unpin> =
|
|
|
- Ready<Result<QuicReceiver<T>>>;
|
|
|
+/// Unwraps the given result, or if the result is an error, continues the enclosing loop.
|
|
|
+macro_rules! unwrap_or_continue {
|
|
|
+ ($result:expr, $on_err:expr) => {
|
|
|
+ handle_err!($result, $on_err, continue)
|
|
|
+ };
|
|
|
+ ($result:expr) => {
|
|
|
+ unwrap_or_continue!($result, |err| error!("{err}"))
|
|
|
+ };
|
|
|
+}
|
|
|
|
|
|
- fn receiver<T: 'static + DeserializeOwned + Send + Unpin>(
|
|
|
- &self,
|
|
|
- ) -> Self::ReceiverFut<'_, T> {
|
|
|
- ready(QuicReceiver::new(
|
|
|
- self.endpoint.clone(),
|
|
|
- self.recv_addr.clone(),
|
|
|
- ))
|
|
|
- }
|
|
|
+/// Awaits its first argument, unless interrupted by its second argument, in which case the
|
|
|
+/// enclosing function returns. The second argument needs to be cancel safe, but the first
|
|
|
+/// need not be if it is discarded when the enclosing function returns (because losing messages
|
|
|
+/// from the first argument doesn't matter in this case).
|
|
|
+macro_rules! await_or_stop {
|
|
|
+ ($future:expr, $stop_fut:expr) => {
|
|
|
+ select! {
|
|
|
+ Some(connecting) = $future => connecting,
|
|
|
+ _ = $stop_fut => return,
|
|
|
+ }
|
|
|
+ };
|
|
|
+}
|
|
|
|
|
|
- type Sender = QuicSender;
|
|
|
- type SenderFut<'a> = Pin<Box<dyn 'a + Future<Output = Result<QuicSender>> + Send>>;
|
|
|
+struct QuicReceiver<T> {
|
|
|
+ recv_addr: Arc<BlockAddr>,
|
|
|
+ stop_tx: broadcast::Sender<()>,
|
|
|
+ stream: ReceiverStream<Result<MsgReceived<T>>>,
|
|
|
+}
|
|
|
|
|
|
- fn sender(&self, addr: Arc<BlockAddr>) -> Self::SenderFut<'_> {
|
|
|
- Box::pin(async {
|
|
|
- QuicSender::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
|
|
|
- })
|
|
|
- }
|
|
|
- }
|
|
|
+impl<T: CallRx> QuicReceiver<T> {
|
|
|
+ /// The size of the buffer to store received messages in.
|
|
|
+ const MSG_BUF_SZ: usize = 64;
|
|
|
|
|
|
- struct QuicReceiver<T> {
|
|
|
- recv_addr: Arc<BlockAddr>,
|
|
|
- stop_tx: broadcast::Sender<()>,
|
|
|
- stream: ReceiverStream<Result<MsgReceived<T>>>,
|
|
|
+ fn new(endpoint: Endpoint, recv_addr: Arc<BlockAddr>) -> Result<Self> {
|
|
|
+ let (stop_tx, stop_rx) = broadcast::channel(1);
|
|
|
+ let (msg_tx, msg_rx) = mpsc::channel(Self::MSG_BUF_SZ);
|
|
|
+ tokio::spawn(Self::server_loop(endpoint, msg_tx, stop_rx));
|
|
|
+ Ok(Self {
|
|
|
+ recv_addr,
|
|
|
+ stop_tx,
|
|
|
+ stream: ReceiverStream::new(msg_rx),
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
- /// Causes the current function to return if the given `rx` has received a stop signal.
|
|
|
- macro_rules! check_stop {
|
|
|
- ($rx:expr) => {
|
|
|
- match $rx.try_recv() {
|
|
|
- Ok(_) => return,
|
|
|
- Err(err) => {
|
|
|
- if let TryRecvError::Closed = err {
|
|
|
- return;
|
|
|
- }
|
|
|
- }
|
|
|
+ async fn server_loop(
|
|
|
+ endpoint: Endpoint,
|
|
|
+ msg_tx: mpsc::Sender<Result<MsgReceived<T>>>,
|
|
|
+ mut stop_rx: broadcast::Receiver<()>,
|
|
|
+ ) {
|
|
|
+ loop {
|
|
|
+ let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv());
|
|
|
+ let connection = unwrap_or_continue!(connecting.await, |err| error!(
|
|
|
+ "error accepting QUIC connection: {err}"
|
|
|
+ ));
|
|
|
+ tokio::spawn(Self::handle_connection(
|
|
|
+ connection,
|
|
|
+ msg_tx.clone(),
|
|
|
+ stop_rx.resubscribe(),
|
|
|
+ ));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Returns the path the client is bound to.
|
|
|
+ fn client_path(peer_identity: Option<Box<dyn Any>>) -> Result<Arc<BlockPath>> {
|
|
|
+ let peer_identity =
|
|
|
+ peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?;
|
|
|
+ let client_certs = peer_identity
|
|
|
+ .downcast::<Vec<rustls::Certificate>>()
|
|
|
+ .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?;
|
|
|
+ let first = client_certs
|
|
|
+ .first()
|
|
|
+ .ok_or_else(|| bterr!("no certificates were presented by the client"))?;
|
|
|
+ let (writecap, ..) = Writecap::from_cert_chain(first, &client_certs[1..])?;
|
|
|
+ Ok(Arc::new(writecap.bind_path()))
|
|
|
+ }
|
|
|
+
|
|
|
+ async fn handle_connection(
|
|
|
+ connection: Connection,
|
|
|
+ msg_tx: mpsc::Sender<Result<MsgReceived<T>>>,
|
|
|
+ mut stop_rx: broadcast::Receiver<()>,
|
|
|
+ ) {
|
|
|
+ let client_path = unwrap_or_return!(
|
|
|
+ Self::client_path(connection.peer_identity()),
|
|
|
+ |err| error!("failed to get client path from peer identity: {err}")
|
|
|
+ );
|
|
|
+ let (send_stream, recv_stream) = unwrap_or_return!(
|
|
|
+ connection.accept_bi().await,
|
|
|
+ |err| error!("error accepting receive stream: {err}")
|
|
|
+ );
|
|
|
+ let replier = Replier::new(send_stream);
|
|
|
+ let mut msg_stream = FramedRead::new(recv_stream, MsgDecoder::<Envelope<T>>::new());
|
|
|
+ loop {
|
|
|
+ let decode_result = await_or_stop!(msg_stream.next(), stop_rx.recv());
|
|
|
+ if let Err(ref err) = decode_result {
|
|
|
+ error!("msg_stream produced an error: {err}");
|
|
|
}
|
|
|
- };
|
|
|
- }
|
|
|
-
|
|
|
- impl<T: DeserializeOwned + Send + 'static> QuicReceiver<T> {
|
|
|
- /// The size of the buffer to store received messages in.
|
|
|
- const MSG_BUF_SZ: usize = 64;
|
|
|
-
|
|
|
- fn new(endpoint: Endpoint, recv_addr: Arc<BlockAddr>) -> Result<Self> {
|
|
|
- let (stop_tx, mut stop_rx) = broadcast::channel(1);
|
|
|
- let (msg_tx, msg_rx) = mpsc::channel(Self::MSG_BUF_SZ);
|
|
|
- tokio::spawn(async move {
|
|
|
- loop {
|
|
|
- check_stop!(stop_rx);
|
|
|
- let connecting = match endpoint.accept().await {
|
|
|
- Some(connection) => connection,
|
|
|
- None => break,
|
|
|
- };
|
|
|
- let connection = match connecting.await {
|
|
|
- Ok(connection) => connection,
|
|
|
- Err(err) => {
|
|
|
- error!("error accepting QUIC connection: {err}");
|
|
|
- continue;
|
|
|
- }
|
|
|
- };
|
|
|
- let conn_msg_tx = msg_tx.clone();
|
|
|
- let mut conn_stop_rx = stop_rx.resubscribe();
|
|
|
- let client_certs = match connection.peer_identity() {
|
|
|
- Some(peer_certs) => {
|
|
|
- match peer_certs.downcast::<Vec<rustls::Certificate>>() {
|
|
|
- Ok(peer_certs) => peer_certs,
|
|
|
- Err(err) => {
|
|
|
- error!("failed to downcast peer certificate chain: {:?}", err);
|
|
|
- continue;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- None => {
|
|
|
- error!("connection did not contain a peer identity");
|
|
|
- continue;
|
|
|
- }
|
|
|
- };
|
|
|
- tokio::spawn(async move {
|
|
|
- let client_path = {
|
|
|
- // There must be at least one certificate because the handshake was
|
|
|
- // successful.
|
|
|
- let first = client_certs.first().unwrap();
|
|
|
- let (writecap, ..) =
|
|
|
- match Writecap::from_cert_chain(first, &client_certs[1..]) {
|
|
|
- Ok(pair) => pair,
|
|
|
- Err(err) => {
|
|
|
- error!(
|
|
|
- "failed to create writecap from certificate chain: {err}"
|
|
|
- );
|
|
|
- return;
|
|
|
- }
|
|
|
- };
|
|
|
- drop(client_certs);
|
|
|
- Arc::new(writecap.bind_path())
|
|
|
- };
|
|
|
- let recv_stream = match connection.accept_uni().await {
|
|
|
- Ok(recv_stream) => recv_stream,
|
|
|
- Err(err) => {
|
|
|
- error!("error accepting receive stream: {err}");
|
|
|
- return;
|
|
|
- }
|
|
|
- };
|
|
|
- let mut msg_stream = FramedRead::new(recv_stream, MsgDecoder::new());
|
|
|
- loop {
|
|
|
- check_stop!(conn_stop_rx);
|
|
|
- let result = match msg_stream.next().await {
|
|
|
- Some(result) => result,
|
|
|
- None => return,
|
|
|
- };
|
|
|
- if let Err(err) = conn_msg_tx
|
|
|
- .send(result.map(|e| MsgReceived::new(client_path.clone(), e)))
|
|
|
- .await
|
|
|
- {
|
|
|
- error!("error sending message to mpsc queue: {err}");
|
|
|
- }
|
|
|
- }
|
|
|
- });
|
|
|
- }
|
|
|
+ let msg_received = decode_result.map(|envelope| {
|
|
|
+ let replier = match envelope.kind {
|
|
|
+ MsgKind::Call => replier.clone(),
|
|
|
+ MsgKind::Send => Replier::empty(),
|
|
|
+ };
|
|
|
+ MsgReceived::new(client_path.clone(), envelope, replier)
|
|
|
});
|
|
|
- Ok(Self {
|
|
|
- recv_addr,
|
|
|
- stop_tx,
|
|
|
- stream: ReceiverStream::new(msg_rx),
|
|
|
- })
|
|
|
+ let send_result = msg_tx.send(msg_received).await;
|
|
|
+ if let Err(err) = send_result {
|
|
|
+ error!("error sending message to mpsc queue: {err}");
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- impl<T> Drop for QuicReceiver<T> {
|
|
|
- fn drop(&mut self) {
|
|
|
- // This result will be a failure if the tasks have already returned, which is not a
|
|
|
- // problem.
|
|
|
- let _ = self.stop_tx.send(());
|
|
|
- }
|
|
|
+impl<T> Drop for QuicReceiver<T> {
|
|
|
+ fn drop(&mut self) {
|
|
|
+ // This result will be a failure if the tasks have already returned, which is not a
|
|
|
+ // problem.
|
|
|
+ let _ = self.stop_tx.send(());
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- impl<T: DeserializeOwned + Send + 'static> Stream for QuicReceiver<T> {
|
|
|
- type Item = Result<MsgReceived<T>>;
|
|
|
+impl<T: CallRx> Stream for QuicReceiver<T> {
|
|
|
+ type Item = Result<MsgReceived<T>>;
|
|
|
|
|
|
- fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
|
- self.stream.poll_next_unpin(cx)
|
|
|
- }
|
|
|
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
|
+ self.stream.poll_next_unpin(cx)
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- impl<T: DeserializeOwned + Send + 'static> Receiver<T> for QuicReceiver<T> {
|
|
|
- fn addr(&self) -> &BlockAddr {
|
|
|
- &self.recv_addr
|
|
|
- }
|
|
|
+impl<T: CallRx> Receiver<T> for QuicReceiver<T> {
|
|
|
+ fn addr(&self) -> &BlockAddr {
|
|
|
+ &self.recv_addr
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- struct QuicSender {
|
|
|
+struct QuicSender {
|
|
|
+ addr: Arc<BlockAddr>,
|
|
|
+ sink: FramedWrite<SendStream, MsgEncoder>,
|
|
|
+ recv_stream: Mutex<RecvStream>,
|
|
|
+}
|
|
|
+
|
|
|
+impl QuicSender {
|
|
|
+ async fn from_endpoint(
|
|
|
+ endpoint: Endpoint,
|
|
|
addr: Arc<BlockAddr>,
|
|
|
- sink: FramedWrite<SendStream, MsgEncoder>,
|
|
|
- }
|
|
|
-
|
|
|
- impl QuicSender {
|
|
|
- async fn from_endpoint(
|
|
|
- endpoint: Endpoint,
|
|
|
- addr: Arc<BlockAddr>,
|
|
|
- resolver: Arc<CertResolver>,
|
|
|
- ) -> Result<Self> {
|
|
|
- let socket_addr = addr.socket_addr();
|
|
|
- let connecting = endpoint.connect_with(
|
|
|
- client_config(addr.path.clone(), resolver)?,
|
|
|
- socket_addr,
|
|
|
- // The ServerCertVerifier ensures we connect to the correct path.
|
|
|
- "UNIMPORTANT",
|
|
|
- )?;
|
|
|
- let connection = connecting.await?;
|
|
|
- let send_stream = connection.open_uni().await?;
|
|
|
- let sink = FramedWrite::new(send_stream, MsgEncoder::new());
|
|
|
- Ok(Self { addr, sink })
|
|
|
- }
|
|
|
+ resolver: Arc<CertResolver>,
|
|
|
+ ) -> Result<Self> {
|
|
|
+ let socket_addr = addr.socket_addr();
|
|
|
+ let connecting = endpoint.connect_with(
|
|
|
+ client_config(addr.path.clone(), resolver)?,
|
|
|
+ socket_addr,
|
|
|
+ // The ServerCertVerifier ensures we connect to the correct path.
|
|
|
+ "UNIMPORTANT",
|
|
|
+ )?;
|
|
|
+ let connection = connecting.await?;
|
|
|
+ let (send_stream, recv_stream) = connection.open_bi().await?;
|
|
|
+ let sink = FramedWrite::new(send_stream, MsgEncoder::new());
|
|
|
+ Ok(Self {
|
|
|
+ addr,
|
|
|
+ sink,
|
|
|
+ recv_stream: Mutex::new(recv_stream),
|
|
|
+ })
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- impl Sender for QuicSender {
|
|
|
- fn addr(&self) -> &BlockAddr {
|
|
|
- &self.addr
|
|
|
- }
|
|
|
-
|
|
|
- type SendFut<'a, T> = SendFut<'a, FramedWrite<SendStream, MsgEncoder>, Msg<T>>
|
|
|
- where T: 'a + Serialize + Send;
|
|
|
-
|
|
|
- fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: Msg<T>) -> Self::SendFut<'a, T> {
|
|
|
- self.sink.send(msg)
|
|
|
- }
|
|
|
-
|
|
|
- type FinishFut = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
|
|
|
-
|
|
|
- fn finish(mut self) -> Self::FinishFut {
|
|
|
- Box::pin(async move {
|
|
|
- let steam: &mut SendStream = self.sink.get_mut();
|
|
|
- steam.finish().await.map_err(|err| bterr!(err))
|
|
|
- })
|
|
|
- }
|
|
|
+/// TODO: Once the "Permit impl Trait in type aliases"
|
|
|
+/// https://github.com/rust-lang/rust/issues/63063
|
|
|
+/// feature lands the future types in this implementation should be rewritten to
|
|
|
+/// use it.
|
|
|
+impl Transmitter for QuicSender {
|
|
|
+ fn addr(&self) -> &BlockAddr {
|
|
|
+ &self.addr
|
|
|
}
|
|
|
|
|
|
- /// This is an identify function which allows you to specify a type parameter for the output
|
|
|
- /// of a future.
|
|
|
- /// This was needed to work around a failure in type inference for types with higher-rank
|
|
|
- /// lifetimes. Once this issue is resolved this can be removed:
|
|
|
- /// https://github.com/rust-lang/rust/issues/102211
|
|
|
- pub fn assert_send<'a, T>(
|
|
|
- fut: impl 'a + Future<Output = T> + Send,
|
|
|
- ) -> impl 'a + Future<Output = T> + Send {
|
|
|
- fut
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-#[cfg(test)]
|
|
|
-mod tests {
|
|
|
- use super::*;
|
|
|
-
|
|
|
- use btlib::{crypto::ConcreteCreds, Epoch, Principal, Principaled};
|
|
|
- use ctor::ctor;
|
|
|
- use std::{net::Ipv6Addr, time::Duration};
|
|
|
-
|
|
|
- #[ctor]
|
|
|
- fn setup_logging() {
|
|
|
- env_logger::init();
|
|
|
- }
|
|
|
-
|
|
|
- lazy_static! {
|
|
|
- static ref ROOT_CREDS: ConcreteCreds = ConcreteCreds::generate().unwrap();
|
|
|
- static ref NODE_CREDS: ConcreteCreds = {
|
|
|
- let mut creds = ConcreteCreds::generate().unwrap();
|
|
|
- let root_creds = &ROOT_CREDS;
|
|
|
- let writecap = root_creds
|
|
|
- .issue_writecap(
|
|
|
- creds.principal(),
|
|
|
- vec![],
|
|
|
- Epoch::now() + Duration::from_secs(3600),
|
|
|
- )
|
|
|
- .unwrap();
|
|
|
- creds.set_writecap(writecap);
|
|
|
- creds
|
|
|
- };
|
|
|
- static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
|
|
|
- }
|
|
|
+ type SendFut<'a, T> = SendFut<'a, FramedWrite<SendStream, MsgEncoder>, Envelope<T>>
|
|
|
+ where T: 'a + Serialize + Send;
|
|
|
|
|
|
- #[derive(Serialize, Deserialize)]
|
|
|
- enum MsgError {
|
|
|
- Unknown,
|
|
|
+ fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: T) -> Self::SendFut<'a, T> {
|
|
|
+ self.sink.send(Envelope::send(msg))
|
|
|
}
|
|
|
|
|
|
- #[derive(Deserialize)]
|
|
|
- enum BodyOwned {
|
|
|
- Ping,
|
|
|
- Success,
|
|
|
- Fail(MsgError),
|
|
|
- Read { offset: u64, size: u64 },
|
|
|
- Write { offset: u64, buf: Vec<u8> },
|
|
|
- }
|
|
|
+ type FinishFut = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
|
|
|
|
|
|
- #[derive(Serialize)]
|
|
|
- enum BodyRef<'a> {
|
|
|
- Ping,
|
|
|
- Success,
|
|
|
- Fail(MsgError),
|
|
|
- Read { offset: u64, size: u64 },
|
|
|
- Write { offset: u64, buf: &'a [u8] },
|
|
|
+ fn finish(mut self) -> Self::FinishFut {
|
|
|
+ Box::pin(async move {
|
|
|
+ let steam: &mut SendStream = self.sink.get_mut();
|
|
|
+ steam.finish().await.map_err(|err| bterr!(err))
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
- struct TestCase;
|
|
|
-
|
|
|
- impl TestCase {
|
|
|
- fn new() -> TestCase {
|
|
|
- Self
|
|
|
- }
|
|
|
+ type CallFut<'a, T> = Pin<Box<dyn 'a + Future<Output = Result<T::Reply>> + Send>>
|
|
|
+ where
|
|
|
+ T: 'a + CallTx;
|
|
|
|
|
|
- /// Returns a ([Sender], [Receiver]) pair for a process identified by the given integer.
|
|
|
- async fn new_process(&self) -> (impl Sender, impl Receiver<BodyOwned>) {
|
|
|
- let ip_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
|
|
|
- let mut creds = ConcreteCreds::generate().unwrap();
|
|
|
- let writecap = NODE_CREDS
|
|
|
- .issue_writecap(
|
|
|
- creds.principal(),
|
|
|
- vec![],
|
|
|
- Epoch::now() + Duration::from_secs(3600),
|
|
|
- )
|
|
|
- .unwrap();
|
|
|
- let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
|
|
|
- creds.set_writecap(writecap);
|
|
|
- let router = router(ip_addr, Arc::new(creds)).unwrap();
|
|
|
- let receiver = router.receiver::<BodyOwned>().await.unwrap();
|
|
|
- let sender = router.sender(addr).await.unwrap();
|
|
|
- (sender, receiver)
|
|
|
- }
|
|
|
+ fn call<'a, T>(&'a mut self, msg: T) -> Self::CallFut<'a, T>
|
|
|
+ where
|
|
|
+ T: 'a + CallTx,
|
|
|
+ {
|
|
|
+ Box::pin(async move {
|
|
|
+ self.sink.send(Envelope::call(msg)).await?;
|
|
|
+ let mut guard = self.recv_stream.lock().await;
|
|
|
+ let mut source = FramedRead::new(guard.deref_mut(), MsgDecoder::<T::Reply>::new());
|
|
|
+ source
|
|
|
+ .next()
|
|
|
+ .await
|
|
|
+ .ok_or_else(|| bterr!("server hung up before sending reply"))?
|
|
|
+ })
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- #[tokio::test]
|
|
|
- async fn message_received_is_message_sent() {
|
|
|
- let case = TestCase::new();
|
|
|
- let (mut sender, mut receiver) = case.new_process().await;
|
|
|
-
|
|
|
- sender.send_with_rand_id(BodyRef::Ping).await.unwrap();
|
|
|
- let actual = receiver.next().await.unwrap().unwrap();
|
|
|
-
|
|
|
- let matched = if let BodyOwned::Ping = actual.msg.body {
|
|
|
- true
|
|
|
- } else {
|
|
|
- false
|
|
|
- };
|
|
|
- assert!(matched);
|
|
|
- }
|
|
|
-
|
|
|
- #[tokio::test]
|
|
|
- async fn message_received_from_path_is_correct() {
|
|
|
- let case = TestCase::new();
|
|
|
- let (mut sender, mut receiver) = case.new_process().await;
|
|
|
-
|
|
|
- sender.send_with_rand_id(BodyRef::Ping).await.unwrap();
|
|
|
- let actual = receiver.next().await.unwrap().unwrap();
|
|
|
-
|
|
|
- assert_eq!(receiver.addr().path(), actual.from.as_ref());
|
|
|
- }
|
|
|
-
|
|
|
- #[tokio::test]
|
|
|
- async fn ping_pong() {
|
|
|
- let case = TestCase::new();
|
|
|
- let (mut sender_one, mut receiver_one) = case.new_process().await;
|
|
|
- let (mut sender_two, mut receiver_two) = case.new_process().await;
|
|
|
-
|
|
|
- tokio::spawn(async move {
|
|
|
- let received = receiver_one.next().await.unwrap().unwrap();
|
|
|
- let reply_body = if let BodyOwned::Ping = received.msg.body {
|
|
|
- BodyRef::Success
|
|
|
- } else {
|
|
|
- BodyRef::Fail(MsgError::Unknown)
|
|
|
- };
|
|
|
- let fut = assert_send::<'_, Result<()>>(sender_two.send_with_rand_id(reply_body));
|
|
|
- fut.await.unwrap();
|
|
|
- sender_two.finish().await.unwrap();
|
|
|
- });
|
|
|
-
|
|
|
- sender_one.send_with_rand_id(BodyRef::Ping).await.unwrap();
|
|
|
- let reply = receiver_two.next().await.unwrap().unwrap();
|
|
|
- let matched = if let BodyOwned::Success = reply.msg.body {
|
|
|
- true
|
|
|
- } else {
|
|
|
- false
|
|
|
- };
|
|
|
- assert!(matched);
|
|
|
- assert_eq!(receiver_two.addr().path(), reply.from.as_ref());
|
|
|
- }
|
|
|
-
|
|
|
- #[tokio::test]
|
|
|
- async fn read_write() {
|
|
|
- let case = TestCase::new();
|
|
|
- let (mut sender_one, mut receiver_one) = case.new_process().await;
|
|
|
- let (mut sender_two, mut receiver_two) = case.new_process().await;
|
|
|
-
|
|
|
- let handle = tokio::spawn(async move {
|
|
|
- let data: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
|
|
|
- let received = receiver_one.next().await.unwrap().unwrap();
|
|
|
- let reply_body = if let BodyOwned::Read { offset, size } = received.msg.body {
|
|
|
- let offset: usize = offset.try_into().unwrap();
|
|
|
- let size: usize = size.try_into().unwrap();
|
|
|
- let end: usize = offset + size;
|
|
|
- BodyRef::Write {
|
|
|
- offset: offset as u64,
|
|
|
- buf: &data[offset..end],
|
|
|
- }
|
|
|
- } else {
|
|
|
- BodyRef::Fail(MsgError::Unknown)
|
|
|
- };
|
|
|
- let msg = Msg::new(received.msg.id, reply_body);
|
|
|
- let fut = assert_send::<'_, Result<()>>(sender_two.send(msg));
|
|
|
- fut.await.unwrap();
|
|
|
- sender_two.finish().await.unwrap();
|
|
|
- });
|
|
|
-
|
|
|
- sender_one
|
|
|
- .send_with_rand_id(BodyRef::Read { offset: 2, size: 2 })
|
|
|
- .await
|
|
|
- .unwrap();
|
|
|
- handle.await.unwrap();
|
|
|
- let reply = receiver_two.next().await.unwrap().unwrap();
|
|
|
- if let BodyOwned::Write { offset, buf } = reply.msg.body {
|
|
|
- assert_eq!(2, offset);
|
|
|
- assert_eq!([2, 3].as_slice(), buf.as_slice());
|
|
|
- } else {
|
|
|
- panic!("reply was not the right type");
|
|
|
- };
|
|
|
- }
|
|
|
+/// This is an identify function which allows you to specify a type parameter for the output
|
|
|
+/// of a future.
|
|
|
+/// TODO: This was needed to work around a failure in type inference for types with higher-rank
|
|
|
+/// lifetimes. Once this issue is resolved this can be removed:
|
|
|
+/// https://github.com/rust-lang/rust/issues/102211
|
|
|
+pub fn assert_send<'a, T>(
|
|
|
+ fut: impl 'a + Future<Output = T> + Send,
|
|
|
+) -> impl 'a + Future<Output = T> + Send {
|
|
|
+ fut
|
|
|
}
|