transmitter.rs 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. // SPDX-License-Identifier: AGPL-3.0-or-later
  2. //! Contains the [Transmitter] type.
  3. use std::{
  4. future::{ready, Future, Ready},
  5. io,
  6. marker::PhantomData,
  7. net::{IpAddr, Ipv6Addr, SocketAddr},
  8. sync::Arc,
  9. };
  10. use btlib::{bterr, crypto::Creds};
  11. use bytes::BytesMut;
  12. use futures::SinkExt;
  13. use quinn::{Connection, Endpoint, RecvStream, SendStream};
  14. use serde::{de::DeserializeOwned, Serialize};
  15. use tokio::sync::Mutex;
  16. use tokio_util::codec::{Framed, FramedParts};
  17. use crate::{
  18. receiver::{Envelope, ReplyEnvelope},
  19. serialization::{CallbackFramed, MsgEncoder},
  20. tls::{client_config, CertResolver},
  21. BlockAddr, CallMsg, DeserCallback, Result, SendMsg,
  22. };
  23. /// A type which can be used to transmit messages over the network to a [crate::Receiver].
  24. pub struct Transmitter {
  25. addr: Arc<BlockAddr>,
  26. connection: Connection,
  27. send_parts: Mutex<Option<FramedParts<SendStream, MsgEncoder>>>,
  28. recv_buf: Mutex<Option<BytesMut>>,
  29. }
  30. macro_rules! cleanup_on_err {
  31. ($result:expr, $guard:ident, $parts:ident) => {
  32. match $result {
  33. Ok(value) => value,
  34. Err(err) => {
  35. *$guard = Some($parts);
  36. return Err(err.into());
  37. }
  38. }
  39. };
  40. }
  41. impl Transmitter {
  42. pub async fn new<C: 'static + Creds + Send + Sync>(
  43. addr: Arc<BlockAddr>,
  44. creds: Arc<C>,
  45. ) -> Result<Transmitter> {
  46. let resolver = Arc::new(CertResolver::new(creds)?);
  47. let endpoint = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))?;
  48. Transmitter::from_endpoint(endpoint, addr, resolver).await
  49. }
  50. pub(crate) async fn from_endpoint(
  51. endpoint: Endpoint,
  52. addr: Arc<BlockAddr>,
  53. resolver: Arc<CertResolver>,
  54. ) -> Result<Self> {
  55. let socket_addr = addr.socket_addr()?;
  56. let connecting = endpoint.connect_with(
  57. client_config(addr.path().clone(), resolver)?,
  58. socket_addr,
  59. // The ServerCertVerifier ensures we connect to the correct path.
  60. "UNIMPORTANT",
  61. )?;
  62. let connection = connecting.await?;
  63. let send_parts = Mutex::new(None);
  64. let recv_buf = Mutex::new(Some(BytesMut::new()));
  65. Ok(Self {
  66. addr,
  67. connection,
  68. send_parts,
  69. recv_buf,
  70. })
  71. }
  72. async fn transmit<T: Serialize>(&self, envelope: Envelope<T>) -> Result<RecvStream> {
  73. let mut guard = self.send_parts.lock().await;
  74. let (send_stream, recv_stream) = self.connection.open_bi().await?;
  75. let parts = match guard.take() {
  76. Some(mut parts) => {
  77. parts.io = send_stream;
  78. parts
  79. }
  80. None => FramedParts::new::<Envelope<T>>(send_stream, MsgEncoder::new()),
  81. };
  82. let mut sink = Framed::from_parts(parts);
  83. let result = sink.send(envelope).await;
  84. let parts = sink.into_parts();
  85. cleanup_on_err!(result, guard, parts);
  86. *guard = Some(parts);
  87. Ok(recv_stream)
  88. }
  89. /// Returns the address that this instance is transmitting to.
  90. pub fn addr(&self) -> &Arc<BlockAddr> {
  91. &self.addr
  92. }
  93. /// Transmit a message to the connected [crate::Receiver] without waiting for a reply.
  94. pub async fn send<'ser, 'de, T>(&self, msg: T) -> Result<()>
  95. where
  96. T: 'ser + SendMsg<'de>,
  97. {
  98. self.transmit(Envelope::send(msg)).await?;
  99. Ok(())
  100. }
  101. /// Transmit a message to the connected [crate::Receiver], waits for a reply, then calls the given
  102. /// [DeserCallback] with the deserialized reply.
  103. ///
  104. /// ## WARNING
  105. /// The callback must be such that `F::Arg<'a> = T::Reply<'a>` for any `'a`. If this
  106. /// is violated, then a deserilization error will occur at runtime.
  107. ///
  108. /// ## TODO
  109. /// This issue needs to be fixed. Due to the fact that
  110. /// `F::Arg` is a Generic Associated Type (GAT) I have been unable to express this constraint in
  111. /// the where clause of this method. I'm not sure if the errors I've encountered are due to a
  112. /// lack of understanding on my part or due to the current limitations of the borrow checker in
  113. /// its handling of GATs.
  114. pub async fn call<'ser, 'de, T, F>(&self, msg: T, callback: F) -> Result<F::Return>
  115. where
  116. T: 'ser + CallMsg<'de>,
  117. F: 'static + Send + DeserCallback,
  118. {
  119. let recv_stream = self.transmit(Envelope::call(msg)).await?;
  120. let mut guard = self.recv_buf.lock().await;
  121. let buffer = guard.take().unwrap();
  122. let mut callback_framed = CallbackFramed::from_parts(recv_stream, buffer);
  123. let result = callback_framed
  124. .next(ReplyCallback::new(callback))
  125. .await
  126. .ok_or_else(|| bterr!("server hung up before sending reply"));
  127. let (_, buffer) = callback_framed.into_parts();
  128. let output = cleanup_on_err!(result, guard, buffer);
  129. *guard = Some(buffer);
  130. output?
  131. }
  132. /// Transmits a message to the connected [crate::Receiver], waits for a reply, then passes back the
  133. /// the reply to the caller. This only works for messages whose reply doesn't borrow any data,
  134. /// otherwise the `call` method must be used.
  135. pub async fn call_through<'ser, T>(&self, msg: T) -> Result<T::Reply<'static>>
  136. where
  137. // TODO: CallMsg must take a static lifetime until this issue is resolved:
  138. // https://github.com/rust-lang/rust/issues/103532
  139. T: 'ser + CallMsg<'static>,
  140. T::Reply<'static>: 'static + Send + Sync + DeserializeOwned,
  141. {
  142. self.call(msg, Passthrough::new()).await
  143. }
  144. }
  145. struct ReplyCallback<F> {
  146. inner: F,
  147. }
  148. impl<F> ReplyCallback<F> {
  149. fn new(inner: F) -> Self {
  150. Self { inner }
  151. }
  152. }
  153. impl<F: 'static + Send + DeserCallback> DeserCallback for ReplyCallback<F> {
  154. type Arg<'de> = ReplyEnvelope<F::Arg<'de>>;
  155. type Return = Result<F::Return>;
  156. type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send;
  157. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  158. async move {
  159. match arg {
  160. ReplyEnvelope::Ok(msg) => Ok(self.inner.call(msg).await),
  161. ReplyEnvelope::Err { message, os_code } => {
  162. if let Some(os_code) = os_code {
  163. let err = bterr!(io::Error::from_raw_os_error(os_code)).context(message);
  164. Err(err)
  165. } else {
  166. Err(bterr!(message))
  167. }
  168. }
  169. }
  170. }
  171. }
  172. }
  173. pub struct Passthrough<T> {
  174. phantom: PhantomData<T>,
  175. }
  176. impl<T> Passthrough<T> {
  177. pub fn new() -> Self {
  178. Self {
  179. phantom: PhantomData,
  180. }
  181. }
  182. }
  183. impl<T> Default for Passthrough<T> {
  184. fn default() -> Self {
  185. Self::new()
  186. }
  187. }
  188. impl<T> Clone for Passthrough<T> {
  189. fn clone(&self) -> Self {
  190. Self::new()
  191. }
  192. }
  193. impl<T: 'static + Send + DeserializeOwned> DeserCallback for Passthrough<T> {
  194. type Arg<'de> = T;
  195. type Return = T;
  196. type CallFut<'de> = Ready<T>;
  197. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  198. ready(arg)
  199. }
  200. }