// SPDX-License-Identifier: AGPL-3.0-or-later //! Contains the [Transmitter] type. use std::{ future::{ready, Future, Ready}, io, marker::PhantomData, net::{IpAddr, Ipv6Addr, SocketAddr}, sync::Arc, }; use btlib::{bterr, crypto::Creds}; use bytes::BytesMut; use futures::SinkExt; use quinn::{Connection, Endpoint, RecvStream, SendStream}; use serde::{de::DeserializeOwned, Serialize}; use tokio::sync::Mutex; use tokio_util::codec::{Framed, FramedParts}; use crate::{ receiver::{Envelope, ReplyEnvelope}, serialization::{CallbackFramed, MsgEncoder}, tls::{client_config, CertResolver}, BlockAddr, CallMsg, DeserCallback, Result, SendMsg, }; /// A type which can be used to transmit messages over the network to a [crate::Receiver]. pub struct Transmitter { addr: Arc, connection: Connection, send_parts: Mutex>>, recv_buf: Mutex>, } macro_rules! cleanup_on_err { ($result:expr, $guard:ident, $parts:ident) => { match $result { Ok(value) => value, Err(err) => { *$guard = Some($parts); return Err(err.into()); } } }; } impl Transmitter { pub async fn new( addr: Arc, creds: Arc, ) -> Result { let resolver = Arc::new(CertResolver::new(creds)?); let endpoint = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))?; Transmitter::from_endpoint(endpoint, addr, resolver).await } pub(crate) async fn from_endpoint( endpoint: Endpoint, addr: Arc, resolver: Arc, ) -> Result { let socket_addr = addr.socket_addr()?; let connecting = endpoint.connect_with( client_config(addr.path().clone(), resolver)?, socket_addr, // The ServerCertVerifier ensures we connect to the correct path. "UNIMPORTANT", )?; let connection = connecting.await?; let send_parts = Mutex::new(None); let recv_buf = Mutex::new(Some(BytesMut::new())); Ok(Self { addr, connection, send_parts, recv_buf, }) } async fn transmit(&self, envelope: Envelope) -> Result { let mut guard = self.send_parts.lock().await; let (send_stream, recv_stream) = self.connection.open_bi().await?; let parts = match guard.take() { Some(mut parts) => { parts.io = send_stream; parts } None => FramedParts::new::>(send_stream, MsgEncoder::new()), }; let mut sink = Framed::from_parts(parts); let result = sink.send(envelope).await; let parts = sink.into_parts(); cleanup_on_err!(result, guard, parts); *guard = Some(parts); Ok(recv_stream) } /// Returns the address that this instance is transmitting to. pub fn addr(&self) -> &Arc { &self.addr } /// Transmit a message to the connected [crate::Receiver] without waiting for a reply. pub async fn send<'ser, 'de, T>(&self, msg: T) -> Result<()> where T: 'ser + SendMsg<'de>, { self.transmit(Envelope::send(msg)).await?; Ok(()) } /// Transmit a message to the connected [crate::Receiver], waits for a reply, then calls the given /// [DeserCallback] with the deserialized reply. /// /// ## WARNING /// The callback must be such that `F::Arg<'a> = T::Reply<'a>` for any `'a`. If this /// is violated, then a deserilization error will occur at runtime. /// /// ## TODO /// This issue needs to be fixed. Due to the fact that /// `F::Arg` is a Generic Associated Type (GAT) I have been unable to express this constraint in /// the where clause of this method. I'm not sure if the errors I've encountered are due to a /// lack of understanding on my part or due to the current limitations of the borrow checker in /// its handling of GATs. pub async fn call<'ser, 'de, T, F>(&self, msg: T, callback: F) -> Result where T: 'ser + CallMsg<'de>, F: 'static + Send + DeserCallback, { let recv_stream = self.transmit(Envelope::call(msg)).await?; let mut guard = self.recv_buf.lock().await; let buffer = guard.take().unwrap(); let mut callback_framed = CallbackFramed::from_parts(recv_stream, buffer); let result = callback_framed .next(ReplyCallback::new(callback)) .await .ok_or_else(|| bterr!("server hung up before sending reply")); let (_, buffer) = callback_framed.into_parts(); let output = cleanup_on_err!(result, guard, buffer); *guard = Some(buffer); output? } /// Transmits a message to the connected [crate::Receiver], waits for a reply, then passes back the /// the reply to the caller. This only works for messages whose reply doesn't borrow any data, /// otherwise the `call` method must be used. pub async fn call_through<'ser, T>(&self, msg: T) -> Result> where // TODO: CallMsg must take a static lifetime until this issue is resolved: // https://github.com/rust-lang/rust/issues/103532 T: 'ser + CallMsg<'static>, T::Reply<'static>: 'static + Send + Sync + DeserializeOwned, { self.call(msg, Passthrough::new()).await } } struct ReplyCallback { inner: F, } impl ReplyCallback { fn new(inner: F) -> Self { Self { inner } } } impl DeserCallback for ReplyCallback { type Arg<'de> = ReplyEnvelope>; type Return = Result; type CallFut<'de> = impl 'de + Future + Send; fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> { async move { match arg { ReplyEnvelope::Ok(msg) => Ok(self.inner.call(msg).await), ReplyEnvelope::Err { message, os_code } => { if let Some(os_code) = os_code { let err = bterr!(io::Error::from_raw_os_error(os_code)).context(message); Err(err) } else { Err(bterr!(message)) } } } } } } pub struct Passthrough { phantom: PhantomData, } impl Passthrough { pub fn new() -> Self { Self { phantom: PhantomData, } } } impl Default for Passthrough { fn default() -> Self { Self::new() } } impl Clone for Passthrough { fn clone(&self) -> Self { Self::new() } } impl DeserCallback for Passthrough { type Arg<'de> = T; type Return = T; type CallFut<'de> = Ready; fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> { ready(arg) } }