receiver.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. // SPDX-License-Identifier: AGPL-3.0-or-later
  2. //! Types used for receiving messages over the network. Chiefly, the [Receiver] type.
  3. use std::{
  4. any::Any,
  5. future::Future,
  6. io,
  7. net::IpAddr,
  8. sync::{Arc, Mutex as StdMutex},
  9. };
  10. use btlib::{bterr, crypto::Creds, error::DisplayErr, BlockPath, Writecap};
  11. use futures::{FutureExt, SinkExt};
  12. use log::{debug, error};
  13. use quinn::{Connection, ConnectionError, Endpoint, RecvStream, SendStream};
  14. use serde::{Deserialize, Serialize};
  15. use tokio::{
  16. select,
  17. sync::{broadcast, Mutex},
  18. task::JoinHandle,
  19. };
  20. use tokio_util::codec::FramedWrite;
  21. use crate::{
  22. serialization::{CallbackFramed, MsgEncoder},
  23. tls::{server_config, CertResolver},
  24. BlockAddr, CallMsg, DeserCallback, Result, Transmitter,
  25. };
  /// Evaluates `$result`. An `Ok` yields its contained value; an `Err` is first passed to
  /// `$on_err`, after which `$control_flow` (e.g. `return` or `continue`) is evaluated. The
  /// control-flow expression is expected to diverge so the `Err` arm never produces a value.
  macro_rules! handle_err {
      ($result:expr, $on_err:expr, $control_flow:expr) => {
          match $result {
              Err(failure) => {
                  $on_err(failure);
                  $control_flow;
              }
              Ok(value) => value,
          }
      };
  }
  /// Unwraps the given result, or if the result is an error, returns from the enclosing function.
  ///
  /// With a single argument the error is logged via `error!`; with two, the second argument is
  /// called with the error before returning.
  macro_rules! unwrap_or_return {
      ($result:expr, $on_err:expr) => {
          handle_err!($result, $on_err, return)
      };
      ($result:expr) => {
          handle_err!($result, |err| error!("{err}"), return)
      };
  }
  /// Unwraps the given result, or if the result is an error, continues the enclosing loop.
  ///
  /// With a single argument the error is logged via `error!`; with two, the second argument is
  /// called with the error before continuing.
  macro_rules! unwrap_or_continue {
      ($result:expr, $on_err:expr) => {
          handle_err!($result, $on_err, continue)
      };
      ($result:expr) => {
          handle_err!($result, |err| error!("{err}"), continue)
      };
  }
  /// Awaits its first argument, unless interrupted by its second argument, in which case the
  /// enclosing loop is exited with `break` (so this macro may only be used inside a loop).
  ///
  /// The first argument must resolve to an [Option]. `Some(value)` makes the macro evaluate to
  /// `value`. `None` means there is nothing more to wait for (e.g. `Endpoint::accept` resolves
  /// to `None` once the endpoint is closed), so it also breaks out of the loop rather than
  /// idling until the stop signal arrives.
  ///
  /// The second argument needs to be cancel safe, because it is dropped whenever the first
  /// argument wins. The first need not be: it is only dropped when the stop signal wins and the
  /// loop is being exited anyway, so losing its output doesn't matter in that case.
  macro_rules! await_or_stop {
      ($future:expr, $stop_fut:expr) => {
          select! {
              next = $future => match next {
                  Some(value) => value,
                  None => break,
              },
              _ = $stop_fut => break,
          }
      };
  }
  67. /// Type which receives messages sent over the network sent by a [Transmitter].
  68. pub struct Receiver {
  69. recv_addr: Arc<BlockAddr>,
  70. stop_tx: broadcast::Sender<()>,
  71. endpoint: Endpoint,
  72. resolver: Arc<CertResolver>,
  73. join_handle: StdMutex<Option<JoinHandle<()>>>,
  74. }
  75. impl Receiver {
  76. /// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of
  77. /// the [Writecap] in the given credentials. The returned type can be used to make
  78. /// [Transmitter]s for any path.
  79. pub fn new<C: 'static + Creds + Send + Sync, F: 'static + MsgCallback>(
  80. ip_addr: IpAddr,
  81. creds: Arc<C>,
  82. callback: F,
  83. ) -> Result<Receiver> {
  84. let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
  85. let recv_addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
  86. log::info!("starting Receiver with address {}", recv_addr);
  87. let socket_addr = recv_addr.socket_addr()?;
  88. let resolver = Arc::new(CertResolver::new(creds)?);
  89. let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
  90. let (stop_tx, stop_rx) = broadcast::channel(1);
  91. let join_handle = tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx));
  92. Ok(Self {
  93. recv_addr,
  94. stop_tx,
  95. endpoint,
  96. resolver,
  97. join_handle: StdMutex::new(Some(join_handle)),
  98. })
  99. }
  100. async fn server_loop<F: 'static + MsgCallback>(
  101. endpoint: Endpoint,
  102. callback: F,
  103. mut stop_rx: broadcast::Receiver<()>,
  104. ) {
  105. loop {
  106. let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv());
  107. let connection = unwrap_or_continue!(connecting.await, |err| error!(
  108. "error accepting QUIC connection: {err}"
  109. ));
  110. tokio::spawn(Self::handle_connection(
  111. connection,
  112. callback.clone(),
  113. stop_rx.resubscribe(),
  114. ));
  115. }
  116. }
  117. async fn handle_connection<F: 'static + MsgCallback>(
  118. connection: Connection,
  119. callback: F,
  120. mut stop_rx: broadcast::Receiver<()>,
  121. ) {
  122. let client_path = unwrap_or_return!(
  123. Self::client_path(connection.peer_identity()),
  124. |err| error!("failed to get client path from peer identity: {err}")
  125. );
  126. loop {
  127. let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv());
  128. let (send_stream, recv_stream) = match result {
  129. Ok(pair) => pair,
  130. Err(err) => match err {
  131. ConnectionError::ApplicationClosed(app) => {
  132. debug!("connection closed: {app}");
  133. return;
  134. }
  135. _ => {
  136. error!("error accepting stream: {err}");
  137. continue;
  138. }
  139. },
  140. };
  141. let client_path = client_path.clone();
  142. let callback = callback.clone();
  143. tokio::task::spawn(Self::handle_message(
  144. client_path,
  145. send_stream,
  146. recv_stream,
  147. callback,
  148. ));
  149. }
  150. }
  151. async fn handle_message<F: 'static + MsgCallback>(
  152. client_path: Arc<BlockPath>,
  153. send_stream: SendStream,
  154. recv_stream: RecvStream,
  155. callback: F,
  156. ) {
  157. let framed_msg = Arc::new(Mutex::new(FramedWrite::new(send_stream, MsgEncoder::new())));
  158. let callback =
  159. MsgRecvdCallback::new(client_path.clone(), framed_msg.clone(), callback.clone());
  160. let mut msg_stream = CallbackFramed::new(recv_stream);
  161. let result = msg_stream
  162. .next(callback)
  163. .await
  164. .ok_or_else(|| bterr!("client closed stream before sending a message"));
  165. match unwrap_or_return!(result) {
  166. Err(err) => error!("msg_stream produced an error: {err}"),
  167. Ok(result) => {
  168. if let Err(err) = result {
  169. error!("callback returned an error: {err}");
  170. }
  171. }
  172. }
  173. }
  174. /// Returns the path the client is bound to.
  175. fn client_path(peer_identity: Option<Box<dyn Any>>) -> Result<Arc<BlockPath>> {
  176. let peer_identity =
  177. peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?;
  178. let client_certs = peer_identity
  179. .downcast::<Vec<rustls::Certificate>>()
  180. .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?;
  181. let first = client_certs
  182. .first()
  183. .ok_or_else(|| bterr!("no certificates were presented by the client"))?;
  184. let (writecap, ..) = Writecap::from_cert_chain(first, &client_certs[1..])?;
  185. Ok(Arc::new(writecap.bind_path()))
  186. }
  187. /// The address at which messages will be received.
  188. pub fn addr(&self) -> &Arc<BlockAddr> {
  189. &self.recv_addr
  190. }
  191. /// Creates a [Transmitter] which is connected to the given address.
  192. pub async fn transmitter(&self, addr: Arc<BlockAddr>) -> Result<Transmitter> {
  193. Transmitter::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
  194. }
  195. /// Returns a future which completes when this [Receiver] has completed
  196. /// (which may be never).
  197. pub fn complete(&self) -> Result<JoinHandle<()>> {
  198. let mut guard = self.join_handle.lock().display_err()?;
  199. let handle = guard
  200. .take()
  201. .ok_or_else(|| bterr!("join handle has already been taken"))?;
  202. Ok(handle)
  203. }
  204. /// Sends a signal indicating that the task running the server loop should return.
  205. pub fn stop(&self) -> Result<()> {
  206. self.stop_tx.send(()).map(|_| ()).map_err(|err| err.into())
  207. }
  208. }
  209. impl Drop for Receiver {
  210. fn drop(&mut self) {
  211. // This result will be a failure if the tasks have already returned, which is not a
  212. // problem.
  213. let _ = self.stop_tx.send(());
  214. }
  215. }
  /// Trait for types which can be called to handle messages received over the network. The
  /// server loop in [Receiver] uses a type that implements this trait to react to messages it
  /// receives.
  ///
  /// The receiver clones the callback for every accepted connection and again for every
  /// stream on that connection, so implementations are presumably expected to be cheap to
  /// clone (e.g. wrap shared state in an `Arc`).
  pub trait MsgCallback: Clone + Send + Sync + Unpin {
      /// The message type handled by this callback. It is parameterized by `'de`, presumably so
      /// it can borrow from the deserialization buffer (see [CallMsg]) — confirm.
      type Arg<'de>: CallMsg<'de>
      where
          Self: 'de;
      /// Future returned by [MsgCallback::call]. It must be `Send` because it is awaited inside
      /// a task spawned by the receiver.
      type CallFut<'de>: Future<Output = Result<()>> + Send
      where
          Self: 'de;
      /// Handles one received message. When the message was sent with `call`, `arg` carries a
      /// [Replier] that should be used to send the reply.
      fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de>;
  }
  228. impl<T: MsgCallback> MsgCallback for &T {
  229. type Arg<'de> = T::Arg<'de> where Self: 'de;
  230. type CallFut<'de> = T::CallFut<'de> where Self: 'de;
  231. fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
  232. (*self).call(arg)
  233. }
  234. }
  235. struct MsgRecvdCallback<F> {
  236. path: Arc<BlockPath>,
  237. replier: Replier,
  238. inner: F,
  239. }
  240. impl<F: MsgCallback> MsgRecvdCallback<F> {
  241. fn new(path: Arc<BlockPath>, framed_msg: Arc<Mutex<FramedMsg>>, inner: F) -> Self {
  242. Self {
  243. path,
  244. replier: Replier::new(framed_msg),
  245. inner,
  246. }
  247. }
  248. }
  249. impl<F: 'static + MsgCallback> DeserCallback for MsgRecvdCallback<F> {
  250. type Arg<'de> = Envelope<F::Arg<'de>> where Self: 'de;
  251. type Return = Result<()>;
  252. type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send where F: 'de, Self: 'de;
  253. fn call<'de>(&'de mut self, arg: Envelope<F::Arg<'de>>) -> Self::CallFut<'de> {
  254. let replier = match arg.kind {
  255. MsgKind::Call => Some(self.replier.clone()),
  256. MsgKind::Send => None,
  257. };
  258. async move {
  259. let result = self
  260. .inner
  261. .call(MsgReceived::new(self.path.clone(), arg, replier))
  262. .await;
  263. match result {
  264. Ok(value) => Ok(value),
  265. Err(err) => match err.downcast::<io::Error>() {
  266. Ok(err) => {
  267. self.replier
  268. .reply_err(err.to_string(), err.raw_os_error())
  269. .await
  270. }
  271. Err(err) => self.replier.reply_err(err.to_string(), None).await,
  272. },
  273. }
  274. }
  275. }
  276. }
  /// Indicates whether a message was sent using `call` or `send`.
  ///
  /// Carried in every [Envelope]; the receiver gives the callback a [Replier] only for
  /// [MsgKind::Call] messages.
  // NOTE(review): this enum is serialized. If the wire format encodes variants by index,
  // reordering them would break compatibility with existing peers — confirm before changing.
  #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
  enum MsgKind {
      /// This message expects exactly one reply.
      Call,
      /// This message expects exactly zero replies.
      Send,
  }
  285. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
  286. pub(crate) struct Envelope<T> {
  287. kind: MsgKind,
  288. msg: T,
  289. }
  290. impl<T> Envelope<T> {
  291. pub(crate) fn send(msg: T) -> Self {
  292. Self {
  293. msg,
  294. kind: MsgKind::Send,
  295. }
  296. }
  297. pub(crate) fn call(msg: T) -> Self {
  298. Self {
  299. msg,
  300. kind: MsgKind::Call,
  301. }
  302. }
  303. fn msg(&self) -> &T {
  304. &self.msg
  305. }
  306. }
  307. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
  308. pub(crate) enum ReplyEnvelope<T> {
  309. Ok(T),
  310. Err {
  311. message: String,
  312. os_code: Option<i32>,
  313. },
  314. }
  315. impl<T> ReplyEnvelope<T> {
  316. fn err(message: String, os_code: Option<i32>) -> Self {
  317. Self::Err { message, os_code }
  318. }
  319. }
  320. /// A message tagged with the block path that it was sent from.
  321. pub struct MsgReceived<T> {
  322. from: Arc<BlockPath>,
  323. msg: Envelope<T>,
  324. replier: Option<Replier>,
  325. }
  326. impl<T> MsgReceived<T> {
  327. fn new(from: Arc<BlockPath>, msg: Envelope<T>, replier: Option<Replier>) -> Self {
  328. Self { from, msg, replier }
  329. }
  330. pub fn into_parts(self) -> (Arc<BlockPath>, T, Option<Replier>) {
  331. (self.from, self.msg.msg, self.replier)
  332. }
  333. /// The path from which this message was received.
  334. pub fn from(&self) -> &Arc<BlockPath> {
  335. &self.from
  336. }
  337. /// Payload contained in this message.
  338. pub fn body(&self) -> &T {
  339. self.msg.msg()
  340. }
  341. /// Returns true if and only if this messages needs to be replied to.
  342. pub fn needs_reply(&self) -> bool {
  343. self.replier.is_some()
  344. }
  345. /// Takes the replier out of this struct and returns it, if it has not previously been returned.
  346. pub fn take_replier(&mut self) -> Option<Replier> {
  347. self.replier.take()
  348. }
  349. }
  /// Write half of a message stream: a QUIC [SendStream] framed with [MsgEncoder].
  type FramedMsg = FramedWrite<SendStream, MsgEncoder>;
  /// Shorthand for a reference-counted value guarded by Tokio's async [Mutex].
  type ArcMutex<T> = Arc<Mutex<T>>;
  352. /// A type for sending a reply to a message. Replies are sent over their own streams, so no two
  353. /// messages can interfere with one another.
  354. #[derive(Clone)]
  355. pub struct Replier {
  356. stream: ArcMutex<FramedMsg>,
  357. }
  358. impl Replier {
  359. fn new(stream: ArcMutex<FramedMsg>) -> Self {
  360. Self { stream }
  361. }
  362. pub async fn reply<T: Serialize + Send>(&mut self, reply: T) -> Result<()> {
  363. let mut guard = self.stream.lock().await;
  364. guard.send(ReplyEnvelope::Ok(reply)).await?;
  365. Ok(())
  366. }
  367. pub async fn reply_err(&mut self, err: String, os_code: Option<i32>) -> Result<()> {
  368. let mut guard = self.stream.lock().await;
  369. guard.send(ReplyEnvelope::<()>::err(err, os_code)).await?;
  370. Ok(())
  371. }
  372. }