123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- // SPDX-License-Identifier: AGPL-3.0-or-later
- //! Contains the [Transmitter] type.
- use std::{
- future::{ready, Future, Ready},
- io,
- marker::PhantomData,
- net::{IpAddr, Ipv6Addr, SocketAddr},
- sync::Arc,
- };
- use btlib::{bterr, crypto::Creds};
- use bytes::BytesMut;
- use futures::SinkExt;
- use quinn::{Connection, Endpoint, RecvStream, SendStream};
- use serde::{de::DeserializeOwned, Serialize};
- use tokio::sync::Mutex;
- use tokio_util::codec::{Framed, FramedParts};
- use crate::{
- receiver::{Envelope, ReplyEnvelope},
- serialization::{CallbackFramed, MsgEncoder},
- tls::{client_config, CertResolver},
- BlockAddr, CallMsg, DeserCallback, Result, SendMsg,
- };
/// A type which can be used to transmit messages over the network to a [crate::Receiver].
///
/// Every message is written to a newly opened bidirectional stream of a single [Connection].
pub struct Transmitter {
    /// The address this instance transmits to.
    addr: Arc<BlockAddr>,
    /// The connection over which a new bidirectional stream is opened for each message.
    connection: Connection,
    /// Framing state (encoder and buffers) reused from one send to the next. It is `None` until
    /// the first message has been sent.
    send_parts: Mutex<Option<FramedParts<SendStream, MsgEncoder>>>,
    /// Buffer reused when reading replies to calls. It is taken out while a reply is being read
    /// and put back afterwards.
    recv_buf: Mutex<Option<BytesMut>>,
}
/// Evaluates to the `Ok` value of the given result.
///
/// If the result is an `Err`, the macro first stores `Some($state)` into `*$slot`. This puts
/// back state that was taken out of `$slot` before the failing operation. It then returns the
/// error, converted with [Into::into], from the enclosing function.
macro_rules! cleanup_on_err {
    ($outcome:expr, $slot:ident, $state:ident) => {
        match $outcome {
            Err(error) => {
                // Restore the checked-out state before bailing so later callers can reuse it.
                *$slot = Some($state);
                return Err(error.into());
            }
            Ok(inner) => inner,
        }
    };
}
- impl Transmitter {
- pub async fn new<C: 'static + Creds + Send + Sync>(
- addr: Arc<BlockAddr>,
- creds: Arc<C>,
- ) -> Result<Transmitter> {
- let resolver = Arc::new(CertResolver::new(creds)?);
- let endpoint = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))?;
- Transmitter::from_endpoint(endpoint, addr, resolver).await
- }
- pub(crate) async fn from_endpoint(
- endpoint: Endpoint,
- addr: Arc<BlockAddr>,
- resolver: Arc<CertResolver>,
- ) -> Result<Self> {
- let socket_addr = addr.socket_addr()?;
- let connecting = endpoint.connect_with(
- client_config(addr.path().clone(), resolver)?,
- socket_addr,
- // The ServerCertVerifier ensures we connect to the correct path.
- "UNIMPORTANT",
- )?;
- let connection = connecting.await?;
- let send_parts = Mutex::new(None);
- let recv_buf = Mutex::new(Some(BytesMut::new()));
- Ok(Self {
- addr,
- connection,
- send_parts,
- recv_buf,
- })
- }
- async fn transmit<T: Serialize>(&self, envelope: Envelope<T>) -> Result<RecvStream> {
- let mut guard = self.send_parts.lock().await;
- let (send_stream, recv_stream) = self.connection.open_bi().await?;
- let parts = match guard.take() {
- Some(mut parts) => {
- parts.io = send_stream;
- parts
- }
- None => FramedParts::new::<Envelope<T>>(send_stream, MsgEncoder::new()),
- };
- let mut sink = Framed::from_parts(parts);
- let result = sink.send(envelope).await;
- let parts = sink.into_parts();
- cleanup_on_err!(result, guard, parts);
- *guard = Some(parts);
- Ok(recv_stream)
- }
- /// Returns the address that this instance is transmitting to.
- pub fn addr(&self) -> &Arc<BlockAddr> {
- &self.addr
- }
- /// Transmit a message to the connected [crate::Receiver] without waiting for a reply.
- pub async fn send<'ser, 'de, T>(&self, msg: T) -> Result<()>
- where
- T: 'ser + SendMsg<'de>,
- {
- self.transmit(Envelope::send(msg)).await?;
- Ok(())
- }
- /// Transmit a message to the connected [crate::Receiver], waits for a reply, then calls the given
- /// [DeserCallback] with the deserialized reply.
- ///
- /// ## WARNING
- /// The callback must be such that `F::Arg<'a> = T::Reply<'a>` for any `'a`. If this
- /// is violated, then a deserilization error will occur at runtime.
- ///
- /// ## TODO
- /// This issue needs to be fixed. Due to the fact that
- /// `F::Arg` is a Generic Associated Type (GAT) I have been unable to express this constraint in
- /// the where clause of this method. I'm not sure if the errors I've encountered are due to a
- /// lack of understanding on my part or due to the current limitations of the borrow checker in
- /// its handling of GATs.
- pub async fn call<'ser, 'de, T, F>(&self, msg: T, callback: F) -> Result<F::Return>
- where
- T: 'ser + CallMsg<'de>,
- F: 'static + Send + DeserCallback,
- {
- let recv_stream = self.transmit(Envelope::call(msg)).await?;
- let mut guard = self.recv_buf.lock().await;
- let buffer = guard.take().unwrap();
- let mut callback_framed = CallbackFramed::from_parts(recv_stream, buffer);
- let result = callback_framed
- .next(ReplyCallback::new(callback))
- .await
- .ok_or_else(|| bterr!("server hung up before sending reply"));
- let (_, buffer) = callback_framed.into_parts();
- let output = cleanup_on_err!(result, guard, buffer);
- *guard = Some(buffer);
- output?
- }
- /// Transmits a message to the connected [crate::Receiver], waits for a reply, then passes back the
- /// the reply to the caller. This only works for messages whose reply doesn't borrow any data,
- /// otherwise the `call` method must be used.
- pub async fn call_through<'ser, T>(&self, msg: T) -> Result<T::Reply<'static>>
- where
- // TODO: CallMsg must take a static lifetime until this issue is resolved:
- // https://github.com/rust-lang/rust/issues/103532
- T: 'ser + CallMsg<'static>,
- T::Reply<'static>: 'static + Send + Sync + DeserializeOwned,
- {
- self.call(msg, Passthrough::new()).await
- }
- }
/// Wraps a [DeserCallback] so that it can be applied to a [ReplyEnvelope] instead of to the bare
/// reply.
struct ReplyCallback<F> {
    /// The callback that receives the reply when the envelope holds a successful response.
    inner: F,
}

impl<F> ReplyCallback<F> {
    /// Wraps the given callback.
    fn new(inner: F) -> Self {
        ReplyCallback { inner }
    }
}
- impl<F: 'static + Send + DeserCallback> DeserCallback for ReplyCallback<F> {
- type Arg<'de> = ReplyEnvelope<F::Arg<'de>>;
- type Return = Result<F::Return>;
- type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send;
- fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
- async move {
- match arg {
- ReplyEnvelope::Ok(msg) => Ok(self.inner.call(msg).await),
- ReplyEnvelope::Err { message, os_code } => {
- if let Some(os_code) = os_code {
- let err = bterr!(io::Error::from_raw_os_error(os_code)).context(message);
- Err(err)
- } else {
- Err(bterr!(message))
- }
- }
- }
- }
- }
- }
/// A [DeserCallback] which hands the deserialized reply back to the caller unchanged.
///
/// `T` is only a type marker; no value of type `T` is stored.
pub struct Passthrough<T> {
    phantom: PhantomData<T>,
}

impl<T> Passthrough<T> {
    /// Creates a new [Passthrough].
    pub fn new() -> Self {
        Self::default()
    }
}

// Written by hand rather than derived so that `T` is not required to implement `Default`.
impl<T> Default for Passthrough<T> {
    fn default() -> Self {
        Passthrough {
            phantom: PhantomData,
        }
    }
}

// Written by hand rather than derived so that `T` is not required to implement `Clone`.
impl<T> Clone for Passthrough<T> {
    fn clone(&self) -> Self {
        Passthrough::default()
    }
}
- impl<T: 'static + Send + DeserializeOwned> DeserCallback for Passthrough<T> {
- type Arg<'de> = T;
- type Return = T;
- type CallFut<'de> = Ready<T>;
- fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
- ready(arg)
- }
- }
|