lib.rs 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765
  1. // SPDX-License-Identifier: AGPL-3.0-or-later
  2. //! Code which enables sending messages between processes in the blocktree system.
  3. #![feature(impl_trait_in_assoc_type)]
  4. mod tls;
  5. use tls::*;
  6. mod callback_framed;
  7. use callback_framed::CallbackFramed;
  8. pub use callback_framed::DeserCallback;
  9. use btlib::{bterr, crypto::Creds, error::DisplayErr, BlockPath, Result, Writecap};
  10. use btserde::{field_helpers::smart_ptr, write_to};
  11. use bytes::{BufMut, BytesMut};
  12. use core::{
  13. future::{ready, Future, Ready},
  14. marker::Send,
  15. pin::Pin,
  16. };
  17. use futures::{FutureExt, SinkExt};
  18. use log::{debug, error};
  19. use quinn::{Connection, ConnectionError, Endpoint, RecvStream, SendStream};
  20. use serde::{de::DeserializeOwned, Deserialize, Serialize};
  21. use std::{
  22. any::Any,
  23. fmt::Display,
  24. hash::Hash,
  25. io,
  26. marker::PhantomData,
  27. net::{IpAddr, Ipv6Addr, SocketAddr},
  28. result::Result as StdResult,
  29. sync::{Arc, Mutex as StdMutex},
  30. };
  31. use tokio::{
  32. select,
  33. sync::{broadcast, Mutex},
  34. task::{JoinError, JoinHandle},
  35. };
  36. use tokio_util::codec::{Encoder, Framed, FramedParts, FramedWrite};
  37. /// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of
  38. /// the [Writecap] in the given credentials. The returned type can be used to make
  39. /// [Transmitter]s for any path.
  40. pub fn receiver<C: 'static + Creds + Send + Sync, F: 'static + MsgCallback>(
  41. ip_addr: IpAddr,
  42. creds: Arc<C>,
  43. callback: F,
  44. ) -> Result<impl Receiver> {
  45. let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
  46. let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
  47. QuicReceiver::new(addr, Arc::new(CertResolver::new(creds)?), callback)
  48. }
  49. pub async fn transmitter<C: 'static + Creds + Send + Sync>(
  50. addr: Arc<BlockAddr>,
  51. creds: Arc<C>,
  52. ) -> Result<impl Transmitter> {
  53. let resolver = Arc::new(CertResolver::new(creds)?);
  54. let endpoint = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))?;
  55. QuicTransmitter::from_endpoint(endpoint, addr, resolver).await
  56. }
  57. pub trait MsgCallback: Clone + Send + Sync + Unpin {
  58. type Arg<'de>: CallMsg<'de>
  59. where
  60. Self: 'de;
  61. type CallFut<'de>: Future<Output = Result<()>> + Send
  62. where
  63. Self: 'de;
  64. fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de>;
  65. }
  66. impl<T: MsgCallback> MsgCallback for &T {
  67. type Arg<'de> = T::Arg<'de> where Self: 'de;
  68. type CallFut<'de> = T::CallFut<'de> where Self: 'de;
  69. fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
  70. (*self).call(arg)
  71. }
  72. }
  73. /// Trait for messages which can be transmitted using the call method.
  74. pub trait CallMsg<'de>: Serialize + Deserialize<'de> + Send + Sync {
  75. type Reply<'r>: Serialize + Deserialize<'r> + Send;
  76. }
  77. /// Trait for messages which can be transmitted using the send method.
  78. /// Types which implement this trait should specify `()` as their reply type.
  79. pub trait SendMsg<'de>: CallMsg<'de> {}
  80. /// An address which identifies a block on the network. An instance of this struct can be
  81. /// used to get a socket address for the block this address refers to.
  82. #[derive(PartialEq, Eq, Hash, Clone, Debug, Serialize, Deserialize)]
  83. pub struct BlockAddr {
  84. #[serde(rename = "ipaddr")]
  85. ip_addr: IpAddr,
  86. #[serde(with = "smart_ptr")]
  87. path: Arc<BlockPath>,
  88. }
  89. impl BlockAddr {
  90. pub fn new(ip_addr: IpAddr, path: Arc<BlockPath>) -> Self {
  91. Self { ip_addr, path }
  92. }
  93. pub fn ip_addr(&self) -> IpAddr {
  94. self.ip_addr
  95. }
  96. pub fn path(&self) -> &BlockPath {
  97. self.path.as_ref()
  98. }
  99. fn port(&self) -> Result<u16> {
  100. self.path.port()
  101. }
  102. /// Returns the socket address of the block this instance refers to.
  103. pub fn socket_addr(&self) -> Result<SocketAddr> {
  104. Ok(SocketAddr::new(self.ip_addr, self.port()?))
  105. }
  106. }
  107. impl Display for BlockAddr {
  108. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  109. write!(f, "{}@{}", self.path, self.ip_addr)
  110. }
  111. }
  112. /// Indicates whether a message was sent using `call` or `send`.
  113. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
  114. enum MsgKind {
  115. /// This message expects exactly one reply.
  116. Call,
  117. /// This message expects exactly zero replies.
  118. Send,
  119. }
  120. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
  121. struct Envelope<T> {
  122. kind: MsgKind,
  123. msg: T,
  124. }
  125. impl<T> Envelope<T> {
  126. fn send(msg: T) -> Self {
  127. Self {
  128. msg,
  129. kind: MsgKind::Send,
  130. }
  131. }
  132. fn call(msg: T) -> Self {
  133. Self {
  134. msg,
  135. kind: MsgKind::Call,
  136. }
  137. }
  138. fn msg(&self) -> &T {
  139. &self.msg
  140. }
  141. }
  142. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
  143. enum ReplyEnvelope<T> {
  144. Ok(T),
  145. Err {
  146. message: String,
  147. os_code: Option<i32>,
  148. },
  149. }
  150. impl<T> ReplyEnvelope<T> {
  151. fn err(message: String, os_code: Option<i32>) -> Self {
  152. Self::Err { message, os_code }
  153. }
  154. }
  155. /// A message tagged with the block path that it was sent from.
  156. pub struct MsgReceived<T> {
  157. from: Arc<BlockPath>,
  158. msg: Envelope<T>,
  159. replier: Option<Replier>,
  160. }
  161. impl<T> MsgReceived<T> {
  162. fn new(from: Arc<BlockPath>, msg: Envelope<T>, replier: Option<Replier>) -> Self {
  163. Self { from, msg, replier }
  164. }
  165. pub fn into_parts(self) -> (Arc<BlockPath>, T, Option<Replier>) {
  166. (self.from, self.msg.msg, self.replier)
  167. }
  168. /// The path from which this message was received.
  169. pub fn from(&self) -> &Arc<BlockPath> {
  170. &self.from
  171. }
  172. /// Payload contained in this message.
  173. pub fn body(&self) -> &T {
  174. self.msg.msg()
  175. }
  176. /// Returns true if and only if this messages needs to be replied to.
  177. pub fn needs_reply(&self) -> bool {
  178. self.replier.is_some()
  179. }
  180. /// Takes the replier out of this struct and returns it, if it has not previously been returned.
  181. pub fn take_replier(&mut self) -> Option<Replier> {
  182. self.replier.take()
  183. }
  184. }
  185. /// Trait for receiving messages and creating [Transmitter]s.
  186. pub trait Receiver {
  187. /// The address at which messages will be received.
  188. fn addr(&self) -> &Arc<BlockAddr>;
  189. type Transmitter: Transmitter + Send;
  190. type TransmitterFut<'a>: 'a + Future<Output = Result<Self::Transmitter>> + Send
  191. where
  192. Self: 'a;
  193. /// Creates a [Transmitter] which is connected to the given address.
  194. fn transmitter(&self, addr: Arc<BlockAddr>) -> Self::TransmitterFut<'_>;
  195. type CompleteErr: std::error::Error + Send;
  196. type CompleteFut<'a>: 'a + Future<Output = StdResult<(), Self::CompleteErr>> + Send
  197. where
  198. Self: 'a;
  199. /// Returns a future which completes when this [Receiver] has completed (which may be never).
  200. fn complete(&self) -> Result<Self::CompleteFut<'_>>;
  201. type StopFut<'a>: 'a + Future<Output = Result<()>> + Send
  202. where
  203. Self: 'a;
  204. fn stop(&self) -> Self::StopFut<'_>;
  205. }
  206. /// A type which can be used to transmit messages.
  207. pub trait Transmitter {
  208. type SendFut<'call, T>: 'call + Future<Output = Result<()>> + Send
  209. where
  210. Self: 'call,
  211. T: 'call + SendMsg<'call>;
  212. /// Transmit a message to the connected [Receiver] without waiting for a reply.
  213. fn send<'call, T: 'call + SendMsg<'call>>(&'call self, msg: T) -> Self::SendFut<'call, T>;
  214. type CallFut<'call, T, F>: 'call + Future<Output = Result<F::Return>> + Send
  215. where
  216. Self: 'call,
  217. T: 'call + CallMsg<'call>,
  218. F: 'static + Send + DeserCallback;
  219. /// Transmit a message to the connected [Receiver], waits for a reply, then calls the given
  220. /// [DeserCallback] with the deserialized reply.
  221. ///
  222. /// ## WARNING
  223. /// The callback must be such that `F::Arg<'a> = T::Reply<'a>` for any `'a`. If this
  224. /// is violated, then a deserilization error will occur at runtime.
  225. ///
  226. /// ## TODO
  227. /// This issue needs to be fixed. Due to the fact that
  228. /// `F::Arg` is a Generic Associated Type (GAT) I have been unable to express this constraint in
  229. /// the where clause of this method. I'm not sure if the errors I've encountered are due to a
  230. /// lack of understanding on my part or due to the current limitations of the borrow checker in
  231. /// its handling of GATs.
  232. fn call<'call, T, F>(&'call self, msg: T, callback: F) -> Self::CallFut<'call, T, F>
  233. where
  234. T: 'call + CallMsg<'call>,
  235. F: 'static + Send + DeserCallback;
  236. /// Transmits a message to the connected [Receiver], waits for a reply, then passes back the
  237. /// the reply to the caller.
  238. fn call_through<'call, T>(
  239. &'call self,
  240. msg: T,
  241. ) -> Self::CallFut<'call, T, Passthrough<T::Reply<'call>>>
  242. where
  243. T: 'call + CallMsg<'call>,
  244. T::Reply<'call>: 'static + Send + Sync + DeserializeOwned,
  245. {
  246. self.call(msg, Passthrough::new())
  247. }
  248. /// Returns the address that this instance is transmitting to.
  249. fn addr(&self) -> &Arc<BlockAddr>;
  250. }
  251. pub struct Passthrough<T> {
  252. phantom: PhantomData<T>,
  253. }
  254. impl<T> Passthrough<T> {
  255. pub fn new() -> Self {
  256. Self {
  257. phantom: PhantomData,
  258. }
  259. }
  260. }
  261. impl<T> Default for Passthrough<T> {
  262. fn default() -> Self {
  263. Self::new()
  264. }
  265. }
  266. impl<T> Clone for Passthrough<T> {
  267. fn clone(&self) -> Self {
  268. Self::new()
  269. }
  270. }
  271. impl<T: 'static + Send + DeserializeOwned> DeserCallback for Passthrough<T> {
  272. type Arg<'de> = T;
  273. type Return = T;
  274. type CallFut<'de> = Ready<T>;
  275. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  276. ready(arg)
  277. }
  278. }
  279. /// Encodes messages using [btserde].
  280. #[derive(Debug)]
  281. struct MsgEncoder;
  282. impl MsgEncoder {
  283. fn new() -> Self {
  284. Self
  285. }
  286. }
  287. impl<T: Serialize> Encoder<T> for MsgEncoder {
  288. type Error = btlib::Error;
  289. fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<()> {
  290. const U64_LEN: usize = std::mem::size_of::<u64>();
  291. let payload = dst.split_off(U64_LEN);
  292. let mut writer = payload.writer();
  293. write_to(&item, &mut writer)?;
  294. let payload = writer.into_inner();
  295. let payload_len = payload.len() as u64;
  296. let mut writer = dst.writer();
  297. write_to(&payload_len, &mut writer)?;
  298. let dst = writer.into_inner();
  299. dst.unsplit(payload);
  300. Ok(())
  301. }
  302. }
  303. type FramedMsg = FramedWrite<SendStream, MsgEncoder>;
  304. type ArcMutex<T> = Arc<Mutex<T>>;
  305. #[derive(Clone)]
  306. pub struct Replier {
  307. stream: ArcMutex<FramedMsg>,
  308. }
  309. impl Replier {
  310. fn new(stream: ArcMutex<FramedMsg>) -> Self {
  311. Self { stream }
  312. }
  313. pub async fn reply<T: Serialize + Send>(&mut self, reply: T) -> Result<()> {
  314. let mut guard = self.stream.lock().await;
  315. guard.send(ReplyEnvelope::Ok(reply)).await?;
  316. Ok(())
  317. }
  318. pub async fn reply_err(&mut self, err: String, os_code: Option<i32>) -> Result<()> {
  319. let mut guard = self.stream.lock().await;
  320. guard.send(ReplyEnvelope::<()>::err(err, os_code)).await?;
  321. Ok(())
  322. }
  323. }
  324. struct MsgRecvdCallback<F> {
  325. path: Arc<BlockPath>,
  326. replier: Replier,
  327. inner: F,
  328. }
  329. impl<F: MsgCallback> MsgRecvdCallback<F> {
  330. fn new(path: Arc<BlockPath>, framed_msg: Arc<Mutex<FramedMsg>>, inner: F) -> Self {
  331. Self {
  332. path,
  333. replier: Replier::new(framed_msg),
  334. inner,
  335. }
  336. }
  337. }
  338. impl<F: 'static + MsgCallback> DeserCallback for MsgRecvdCallback<F> {
  339. type Arg<'de> = Envelope<F::Arg<'de>> where Self: 'de;
  340. type Return = Result<()>;
  341. type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send where F: 'de, Self: 'de;
  342. fn call<'de>(&'de mut self, arg: Envelope<F::Arg<'de>>) -> Self::CallFut<'de> {
  343. let replier = match arg.kind {
  344. MsgKind::Call => Some(self.replier.clone()),
  345. MsgKind::Send => None,
  346. };
  347. async move {
  348. let result = self
  349. .inner
  350. .call(MsgReceived::new(self.path.clone(), arg, replier))
  351. .await;
  352. match result {
  353. Ok(value) => Ok(value),
  354. Err(err) => match err.downcast::<io::Error>() {
  355. Ok(err) => {
  356. self.replier
  357. .reply_err(err.to_string(), err.raw_os_error())
  358. .await
  359. }
  360. Err(err) => self.replier.reply_err(err.to_string(), None).await,
  361. },
  362. }
  363. }
  364. }
  365. }
  366. macro_rules! handle_err {
  367. ($result:expr, $on_err:expr, $control_flow:expr) => {
  368. match $result {
  369. Ok(inner) => inner,
  370. Err(err) => {
  371. $on_err(err);
  372. $control_flow;
  373. }
  374. }
  375. };
  376. }
  377. /// Unwraps the given result, or if the result is an error, returns from the enclosing function.
  378. macro_rules! unwrap_or_return {
  379. ($result:expr, $on_err:expr) => {
  380. handle_err!($result, $on_err, return)
  381. };
  382. ($result:expr) => {
  383. unwrap_or_return!($result, |err| error!("{err}"))
  384. };
  385. }
  386. /// Unwraps the given result, or if the result is an error, continues the enclosing loop.
  387. macro_rules! unwrap_or_continue {
  388. ($result:expr, $on_err:expr) => {
  389. handle_err!($result, $on_err, continue)
  390. };
  391. ($result:expr) => {
  392. unwrap_or_continue!($result, |err| error!("{err}"))
  393. };
  394. }
  395. /// Awaits its first argument, unless interrupted by its second argument, in which case the
  396. /// enclosing function returns. The second argument needs to be cancel safe, but the first
  397. /// need not be if it is discarded when the enclosing function returns (because losing messages
  398. /// from the first argument doesn't matter in this case).
  399. macro_rules! await_or_stop {
  400. ($future:expr, $stop_fut:expr) => {
  401. select! {
  402. Some(connecting) = $future => connecting,
  403. _ = $stop_fut => break,
  404. }
  405. };
  406. }
  407. struct QuicReceiver {
  408. recv_addr: Arc<BlockAddr>,
  409. stop_tx: broadcast::Sender<()>,
  410. endpoint: Endpoint,
  411. resolver: Arc<CertResolver>,
  412. join_handle: StdMutex<Option<JoinHandle<()>>>,
  413. }
  414. impl QuicReceiver {
  415. fn new<F: 'static + MsgCallback>(
  416. recv_addr: Arc<BlockAddr>,
  417. resolver: Arc<CertResolver>,
  418. callback: F,
  419. ) -> Result<Self> {
  420. log::info!("starting QuicReceiver with address {}", recv_addr);
  421. let socket_addr = recv_addr.socket_addr()?;
  422. let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
  423. let (stop_tx, stop_rx) = broadcast::channel(1);
  424. let join_handle = tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx));
  425. Ok(Self {
  426. recv_addr,
  427. stop_tx,
  428. endpoint,
  429. resolver,
  430. join_handle: StdMutex::new(Some(join_handle)),
  431. })
  432. }
  433. async fn server_loop<F: 'static + MsgCallback>(
  434. endpoint: Endpoint,
  435. callback: F,
  436. mut stop_rx: broadcast::Receiver<()>,
  437. ) {
  438. loop {
  439. let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv());
  440. let connection = unwrap_or_continue!(connecting.await, |err| error!(
  441. "error accepting QUIC connection: {err}"
  442. ));
  443. tokio::spawn(Self::handle_connection(
  444. connection,
  445. callback.clone(),
  446. stop_rx.resubscribe(),
  447. ));
  448. }
  449. }
  450. async fn handle_connection<F: 'static + MsgCallback>(
  451. connection: Connection,
  452. callback: F,
  453. mut stop_rx: broadcast::Receiver<()>,
  454. ) {
  455. let client_path = unwrap_or_return!(
  456. Self::client_path(connection.peer_identity()),
  457. |err| error!("failed to get client path from peer identity: {err}")
  458. );
  459. loop {
  460. let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv());
  461. let (send_stream, recv_stream) = match result {
  462. Ok(pair) => pair,
  463. Err(err) => match err {
  464. ConnectionError::ApplicationClosed(app) => {
  465. debug!("connection closed: {app}");
  466. return;
  467. }
  468. _ => {
  469. error!("error accepting stream: {err}");
  470. continue;
  471. }
  472. },
  473. };
  474. let client_path = client_path.clone();
  475. let callback = callback.clone();
  476. tokio::task::spawn(Self::handle_message(
  477. client_path,
  478. send_stream,
  479. recv_stream,
  480. callback,
  481. ));
  482. }
  483. }
  484. async fn handle_message<F: 'static + MsgCallback>(
  485. client_path: Arc<BlockPath>,
  486. send_stream: SendStream,
  487. recv_stream: RecvStream,
  488. callback: F,
  489. ) {
  490. let framed_msg = Arc::new(Mutex::new(FramedWrite::new(send_stream, MsgEncoder::new())));
  491. let callback =
  492. MsgRecvdCallback::new(client_path.clone(), framed_msg.clone(), callback.clone());
  493. let mut msg_stream = CallbackFramed::new(recv_stream);
  494. let result = msg_stream
  495. .next(callback)
  496. .await
  497. .ok_or_else(|| bterr!("client closed stream before sending a message"));
  498. match unwrap_or_return!(result) {
  499. Err(err) => error!("msg_stream produced an error: {err}"),
  500. Ok(result) => {
  501. if let Err(err) = result {
  502. error!("callback returned an error: {err}");
  503. }
  504. }
  505. }
  506. }
  507. /// Returns the path the client is bound to.
  508. fn client_path(peer_identity: Option<Box<dyn Any>>) -> Result<Arc<BlockPath>> {
  509. let peer_identity =
  510. peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?;
  511. let client_certs = peer_identity
  512. .downcast::<Vec<rustls::Certificate>>()
  513. .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?;
  514. let first = client_certs
  515. .first()
  516. .ok_or_else(|| bterr!("no certificates were presented by the client"))?;
  517. let (writecap, ..) = Writecap::from_cert_chain(first, &client_certs[1..])?;
  518. Ok(Arc::new(writecap.bind_path()))
  519. }
  520. }
  521. impl Drop for QuicReceiver {
  522. fn drop(&mut self) {
  523. // This result will be a failure if the tasks have already returned, which is not a
  524. // problem.
  525. let _ = self.stop_tx.send(());
  526. }
  527. }
  528. impl Receiver for QuicReceiver {
  529. fn addr(&self) -> &Arc<BlockAddr> {
  530. &self.recv_addr
  531. }
  532. type Transmitter = QuicTransmitter;
  533. type TransmitterFut<'a> = Pin<Box<dyn 'a + Future<Output = Result<QuicTransmitter>> + Send>>;
  534. fn transmitter(&self, addr: Arc<BlockAddr>) -> Self::TransmitterFut<'_> {
  535. Box::pin(async {
  536. QuicTransmitter::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
  537. })
  538. }
  539. type CompleteErr = JoinError;
  540. type CompleteFut<'a> = JoinHandle<()>;
  541. fn complete(&self) -> Result<Self::CompleteFut<'_>> {
  542. let mut guard = self.join_handle.lock().display_err()?;
  543. let handle = guard
  544. .take()
  545. .ok_or_else(|| bterr!("join handle has already been taken"))?;
  546. Ok(handle)
  547. }
  548. type StopFut<'a> = Ready<Result<()>>;
  549. fn stop(&self) -> Self::StopFut<'_> {
  550. ready(self.stop_tx.send(()).map(|_| ()).map_err(|err| err.into()))
  551. }
  552. }
  553. macro_rules! cleanup_on_err {
  554. ($result:expr, $guard:ident, $parts:ident) => {
  555. match $result {
  556. Ok(value) => value,
  557. Err(err) => {
  558. *$guard = Some($parts);
  559. return Err(err.into());
  560. }
  561. }
  562. };
  563. }
  564. struct QuicTransmitter {
  565. addr: Arc<BlockAddr>,
  566. connection: Connection,
  567. send_parts: Mutex<Option<FramedParts<SendStream, MsgEncoder>>>,
  568. recv_buf: Mutex<Option<BytesMut>>,
  569. }
  570. impl QuicTransmitter {
  571. async fn from_endpoint(
  572. endpoint: Endpoint,
  573. addr: Arc<BlockAddr>,
  574. resolver: Arc<CertResolver>,
  575. ) -> Result<Self> {
  576. let socket_addr = addr.socket_addr()?;
  577. let connecting = endpoint.connect_with(
  578. client_config(addr.path.clone(), resolver)?,
  579. socket_addr,
  580. // The ServerCertVerifier ensures we connect to the correct path.
  581. "UNIMPORTANT",
  582. )?;
  583. let connection = connecting.await?;
  584. let send_parts = Mutex::new(None);
  585. let recv_buf = Mutex::new(Some(BytesMut::new()));
  586. Ok(Self {
  587. addr,
  588. connection,
  589. send_parts,
  590. recv_buf,
  591. })
  592. }
  593. async fn transmit<T: Serialize>(&self, envelope: Envelope<T>) -> Result<RecvStream> {
  594. let mut guard = self.send_parts.lock().await;
  595. let (send_stream, recv_stream) = self.connection.open_bi().await?;
  596. let parts = match guard.take() {
  597. Some(mut parts) => {
  598. parts.io = send_stream;
  599. parts
  600. }
  601. None => FramedParts::new::<Envelope<T>>(send_stream, MsgEncoder::new()),
  602. };
  603. let mut sink = Framed::from_parts(parts);
  604. let result = sink.send(envelope).await;
  605. let parts = sink.into_parts();
  606. cleanup_on_err!(result, guard, parts);
  607. *guard = Some(parts);
  608. Ok(recv_stream)
  609. }
  610. async fn call<'ser, T, F>(&'ser self, msg: T, callback: F) -> Result<F::Return>
  611. where
  612. T: 'ser + CallMsg<'ser>,
  613. F: 'static + Send + DeserCallback,
  614. {
  615. let recv_stream = self.transmit(Envelope::call(msg)).await?;
  616. let mut guard = self.recv_buf.lock().await;
  617. let buffer = guard.take().unwrap();
  618. let mut callback_framed = CallbackFramed::from_parts(recv_stream, buffer);
  619. let result = callback_framed
  620. .next(ReplyCallback::new(callback))
  621. .await
  622. .ok_or_else(|| bterr!("server hung up before sending reply"));
  623. let (_, buffer) = callback_framed.into_parts();
  624. let output = cleanup_on_err!(result, guard, buffer);
  625. *guard = Some(buffer);
  626. output?
  627. }
  628. }
  629. impl Transmitter for QuicTransmitter {
  630. fn addr(&self) -> &Arc<BlockAddr> {
  631. &self.addr
  632. }
  633. type SendFut<'ser, T> = impl 'ser + Future<Output = Result<()>> + Send
  634. where T: 'ser + SendMsg<'ser>;
  635. fn send<'ser, 'r, T: 'ser + SendMsg<'ser>>(&'ser self, msg: T) -> Self::SendFut<'ser, T> {
  636. self.transmit(Envelope::send(msg))
  637. .map(|result| result.map(|_| ()))
  638. }
  639. type CallFut<'ser, T, F> = impl 'ser + Future<Output = Result<F::Return>> + Send
  640. where
  641. Self: 'ser,
  642. T: 'ser + CallMsg<'ser>,
  643. F: 'static + Send + DeserCallback;
  644. fn call<'ser, T, F>(&'ser self, msg: T, callback: F) -> Self::CallFut<'ser, T, F>
  645. where
  646. T: 'ser + CallMsg<'ser>,
  647. F: 'static + Send + DeserCallback,
  648. {
  649. self.call(msg, callback)
  650. }
  651. }
  652. struct ReplyCallback<F> {
  653. inner: F,
  654. }
  655. impl<F> ReplyCallback<F> {
  656. fn new(inner: F) -> Self {
  657. Self { inner }
  658. }
  659. }
  660. impl<F: 'static + Send + DeserCallback> DeserCallback for ReplyCallback<F> {
  661. type Arg<'de> = ReplyEnvelope<F::Arg<'de>>;
  662. type Return = Result<F::Return>;
  663. type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send;
  664. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  665. async move {
  666. match arg {
  667. ReplyEnvelope::Ok(msg) => Ok(self.inner.call(msg).await),
  668. ReplyEnvelope::Err { message, os_code } => {
  669. if let Some(os_code) = os_code {
  670. let err = bterr!(io::Error::from_raw_os_error(os_code)).context(message);
  671. Err(err)
  672. } else {
  673. Err(bterr!(message))
  674. }
  675. }
  676. }
  677. }
  678. }
  679. }