lib.rs 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831
  1. //! Code which enables sending messages between processes in the blocktree system.
  2. use btlib::{
  3. bterr,
  4. crypto::{rand_array, ConcreteCreds, CredsPriv, CredsPub},
  5. error::BoxInIoErr,
  6. BlockPath, Result,
  7. };
  8. use btserde::{read_from, write_to};
  9. use bytes::{BufMut, BytesMut};
  10. use core::{
  11. future::Future,
  12. marker::Send,
  13. pin::Pin,
  14. task::{Context, Poll},
  15. };
  16. use futures::{
  17. future::{ready, Ready},
  18. sink::{Close, Send as SendFut, Sink},
  19. stream::Stream,
  20. SinkExt, StreamExt,
  21. };
  22. use lazy_static::lazy_static;
  23. use log::error;
  24. use quinn::{ClientConfig, Endpoint, SendStream, ServerConfig};
  25. use rustls::{
  26. Certificate, ConfigBuilder, ConfigSide, PrivateKey, WantsCipherSuites, WantsVerifier,
  27. };
  28. use serde::{de::DeserializeOwned, Deserialize, Serialize};
  29. use std::{
  30. collections::hash_map::DefaultHasher,
  31. hash::{Hash, Hasher},
  32. io,
  33. marker::PhantomData,
  34. net::{IpAddr, Shutdown, SocketAddr},
  35. path::PathBuf,
  36. sync::Arc,
  37. };
  38. use tokio::{
  39. io::{AsyncRead, AsyncWrite, ReadBuf},
  40. net::UnixDatagram,
  41. sync::{
  42. broadcast::{self, error::TryRecvError},
  43. mpsc,
  44. },
  45. };
  46. use tokio_stream::wrappers::ReceiverStream;
  47. use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
  48. use zerocopy::FromBytes;
  49. pub use private::*;
  50. mod private {
  51. use super::*;
  52. /// Returns a [Router] which can be used to make a [Receiver] for the given path and
  53. /// [Sender] instances for any path.
  54. pub fn router(addr: Arc<BlockAddr>, creds: &ConcreteCreds) -> Result<impl Router> {
  55. QuicRouter::new(addr, creds)
  56. }
  57. lazy_static! {
  58. /// The default directory in which to place blocktree sockets.
  59. static ref SOCK_DIR: PathBuf = {
  60. let mut path: PathBuf = std::env::var("HOME").unwrap().into();
  61. path.push(".btmsg");
  62. path
  63. };
  64. }
  65. /// Appends the given Blocktree path to the path of the given directory.
  66. #[allow(dead_code)]
  67. fn socket_path(fs_path: &mut PathBuf, addr: &BlockAddr) {
  68. fs_path.push(addr.path.to_string());
  69. }
  70. fn common_config<Side: ConfigSide>(
  71. builder: ConfigBuilder<Side, WantsCipherSuites>,
  72. ) -> Result<ConfigBuilder<Side, WantsVerifier>> {
  73. builder
  74. .with_cipher_suites(&[rustls::cipher_suite::TLS13_AES_128_GCM_SHA256])
  75. .with_kx_groups(&[&rustls::kx_group::SECP256R1])
  76. .with_protocol_versions(&[&rustls::version::TLS13])
  77. .map_err(|err| err.into())
  78. }
  79. fn server_config(creds: &ConcreteCreds) -> Result<ServerConfig> {
  80. let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
  81. let chain = writecap.to_cert_chain(creds.public_sign())?;
  82. let mut cert_chain = Vec::with_capacity(chain.len());
  83. for cert in chain {
  84. cert_chain.push(Certificate(cert))
  85. }
  86. let key = PrivateKey(creds.private_sign().to_der()?);
  87. let server_config = common_config(rustls::ServerConfig::builder())?
  88. .with_no_client_auth()
  89. .with_single_cert(cert_chain, key)?;
  90. Ok(ServerConfig::with_crypto(Arc::new(server_config)))
  91. }
  92. fn client_config() -> Result<ClientConfig> {
  93. let client_config = common_config(rustls::ClientConfig::builder())?
  94. .with_custom_certificate_verifier(CertVerifier::new())
  95. .with_no_client_auth();
  96. Ok(ClientConfig::new(Arc::new(client_config)))
  97. }
  98. /// An identifier for a block. Persistent blocks (files, directories, and servers) are
  99. /// identified by the `Inode` variant and transient blocks (processes) are identified by the
  100. /// PID variant.
  101. #[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Debug)]
  102. pub enum BlockNum {
  103. Inode(u64),
  104. Pid(u64),
  105. }
  106. impl BlockNum {
  107. pub fn value(&self) -> u64 {
  108. match self {
  109. BlockNum::Inode(value) => *value,
  110. BlockNum::Pid(value) => *value,
  111. }
  112. }
  113. }
  114. impl From<BlockNum> for u64 {
  115. fn from(value: BlockNum) -> Self {
  116. value.value()
  117. }
  118. }
  119. #[derive(PartialEq, Eq, Hash, Clone, Debug)]
  120. pub struct BlockAddr {
  121. ip_addr: IpAddr,
  122. path: Arc<BlockPath>,
  123. }
  124. impl BlockAddr {
  125. pub fn new(ip_addr: IpAddr, path: Arc<BlockPath>) -> Self {
  126. Self { ip_addr, path }
  127. }
  128. fn port(&self) -> u16 {
  129. let mut hasher = DefaultHasher::new();
  130. self.path.hash(&mut hasher);
  131. let hash = hasher.finish();
  132. // We compute a port in the dynamic range [49152, 65535] as defined by RFC 6335.
  133. const NUM_RES_PORTS: u16 = 49153;
  134. const PORTS_AVAIL: u64 = (u16::MAX - NUM_RES_PORTS) as u64;
  135. NUM_RES_PORTS + (hash % PORTS_AVAIL) as u16
  136. }
  137. pub fn socket_addr(&self) -> SocketAddr {
  138. let port = self.port();
  139. SocketAddr::new(self.ip_addr, port)
  140. }
  141. }
  142. /// Generates and returns a new message ID.
  143. fn rand_msg_id() -> Result<u128> {
  144. const LEN: usize = std::mem::size_of::<u128>();
  145. let bytes = rand_array::<LEN>()?;
  146. let option = u128::read_from(bytes.as_slice());
  147. // Safety: because LEN == size_of::<u128>(), read_from should have returned Some.
  148. Ok(option.unwrap())
  149. }
  150. #[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
  151. pub struct Msg<T> {
  152. pub id: u128,
  153. pub body: T,
  154. }
  155. impl<T> Msg<T> {
  156. pub fn new(id: u128, body: T) -> Self {
  157. Self { id, body }
  158. }
  159. pub fn with_rand_id(body: T) -> Result<Self> {
  160. Ok(Self {
  161. id: rand_msg_id()?,
  162. body,
  163. })
  164. }
  165. }
  166. /// A type which can be used to send messages.
  167. /// Once the "Permit impl Trait in type aliases" https://github.com/rust-lang/rust/issues/63063
  168. /// feature lands the future types in this trait should be rewritten to use it.
  169. pub trait Sender {
  170. type SendFut<'a, T>: 'a + Future<Output = Result<()>> + Send
  171. where
  172. Self: 'a,
  173. T: 'a + Serialize + Send;
  174. fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: Msg<T>) -> Self::SendFut<'a, T>;
  175. type FinishFut: Future<Output = Result<()>> + Send;
  176. fn finish(self) -> Self::FinishFut;
  177. fn addr(&self) -> &BlockAddr;
  178. fn send_with_rand_id<'a, T: 'a + Serialize + Send>(
  179. &'a mut self,
  180. body: T,
  181. ) -> Self::SendFut<'a, T> {
  182. let msg = Msg::with_rand_id(body).unwrap();
  183. self.send(msg)
  184. }
  185. }
  186. /// A type which can be used to receive messages.
  187. pub trait Receiver<T>: Stream<Item = Result<Msg<T>>> {
  188. fn addr(&self) -> &BlockAddr;
  189. }
  190. pub trait Router {
  191. type Sender: Sender + Send;
  192. type SenderFut<'a>: 'a + Future<Output = Result<Self::Sender>> + Send
  193. where
  194. Self: 'a;
  195. fn sender(&self, addr: Arc<BlockAddr>) -> Self::SenderFut<'_>;
  196. type Receiver<T: 'static + DeserializeOwned + Send>: Receiver<T> + Send + Unpin;
  197. type ReceiverFut<'a, T>: 'a + Future<Output = Result<Self::Receiver<T>>> + Send
  198. where
  199. T: 'static + DeserializeOwned + Send + Unpin,
  200. Self: 'a;
  201. fn receiver<T: 'static + DeserializeOwned + Send + Unpin>(
  202. &self,
  203. ) -> Self::ReceiverFut<'_, T>;
  204. }
  205. /// Encodes and decodes messages using [btserde].
  206. struct MsgEncoder;
  207. impl MsgEncoder {
  208. fn new() -> Self {
  209. Self
  210. }
  211. }
  212. impl<T: Serialize> Encoder<Msg<T>> for MsgEncoder {
  213. type Error = btlib::Error;
  214. fn encode(&mut self, item: Msg<T>, dst: &mut BytesMut) -> Result<()> {
  215. const U64_LEN: usize = std::mem::size_of::<u64>();
  216. let payload = dst.split_off(U64_LEN);
  217. let mut writer = payload.writer();
  218. write_to(&item, &mut writer)?;
  219. let payload = writer.into_inner();
  220. let payload_len = payload.len() as u64;
  221. let mut writer = dst.writer();
  222. write_to(&payload_len, &mut writer)?;
  223. let dst = writer.into_inner();
  224. dst.unsplit(payload);
  225. Ok(())
  226. }
  227. }
  228. struct MsgDecoder<T>(PhantomData<T>);
  229. impl<T> MsgDecoder<T> {
  230. fn new() -> Self {
  231. Self(PhantomData)
  232. }
  233. }
  234. impl<T: DeserializeOwned> Decoder for MsgDecoder<T> {
  235. type Item = Msg<T>;
  236. type Error = btlib::Error;
  237. fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
  238. let mut slice: &[u8] = src.as_ref();
  239. let payload_len: u64 = match read_from(&mut slice) {
  240. Ok(payload_len) => payload_len,
  241. Err(err) => {
  242. if let btserde::Error::Eof = err {
  243. return Ok(None);
  244. }
  245. return Err(err.into());
  246. }
  247. };
  248. let payload_len: usize = payload_len.try_into().box_err()?;
  249. if slice.len() < payload_len {
  250. src.reserve(payload_len - slice.len());
  251. return Ok(None);
  252. }
  253. let msg = read_from(&mut slice)?;
  254. // Consume all the bytes that have been read out of the buffer.
  255. let _ = src.split_to(std::mem::size_of::<u64>() + payload_len);
  256. Ok(Some(msg))
  257. }
  258. }
  259. /// Wraps a [UnixDatagram] and implements [AsyncRead] and [AsyncWrite] for it. Read operations
  260. /// are translated to calls to `recv_from` and write operations are translated to `send`. Note
  261. /// that this means that writes will fail unless the wrapped socket is connected to a peer.
  262. struct DatagramAdapter {
  263. socket: UnixDatagram,
  264. }
  265. impl DatagramAdapter {
  266. #[allow(dead_code)]
  267. fn new(socket: UnixDatagram) -> Self {
  268. Self { socket }
  269. }
  270. fn get_ref(&self) -> &UnixDatagram {
  271. &self.socket
  272. }
  273. fn get_mut(&mut self) -> &mut UnixDatagram {
  274. &mut self.socket
  275. }
  276. }
  277. impl AsRef<UnixDatagram> for DatagramAdapter {
  278. fn as_ref(&self) -> &UnixDatagram {
  279. self.get_ref()
  280. }
  281. }
  282. impl AsMut<UnixDatagram> for DatagramAdapter {
  283. fn as_mut(&mut self) -> &mut UnixDatagram {
  284. self.get_mut()
  285. }
  286. }
  287. impl AsyncRead for DatagramAdapter {
  288. fn poll_read(
  289. self: Pin<&mut Self>,
  290. cx: &mut Context<'_>,
  291. buf: &mut ReadBuf<'_>,
  292. ) -> Poll<io::Result<()>> {
  293. self.socket.poll_recv(cx, buf)
  294. }
  295. }
  296. impl AsyncWrite for DatagramAdapter {
  297. fn poll_write(
  298. self: Pin<&mut Self>,
  299. cx: &mut Context<'_>,
  300. buf: &[u8],
  301. ) -> Poll<io::Result<usize>> {
  302. self.socket.poll_send(cx, buf)
  303. }
  304. fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
  305. Poll::Ready(self.socket.shutdown(Shutdown::Write))
  306. }
  307. fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
  308. Poll::Ready(Ok(()))
  309. }
  310. }
  311. /// An implementation of [Receiver] which uses a Unix datagram socket for receiving messages.
  312. struct UnixReceiver<T> {
  313. addr: BlockAddr,
  314. socket: FramedRead<DatagramAdapter, MsgDecoder<T>>,
  315. }
  316. impl<T: DeserializeOwned> UnixReceiver<T> {
  317. #[allow(dead_code)]
  318. fn new(mut fs_path: PathBuf, addr: BlockAddr) -> Result<Self> {
  319. socket_path(&mut fs_path, &addr);
  320. let socket = DatagramAdapter::new(UnixDatagram::bind(fs_path)?);
  321. let socket = FramedRead::new(socket, MsgDecoder(PhantomData));
  322. Ok(Self { addr, socket })
  323. }
  324. }
  325. impl<T: DeserializeOwned> Stream for UnixReceiver<T> {
  326. type Item = Result<Msg<T>>;
  327. fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
  328. self.socket.poll_next_unpin(cx)
  329. }
  330. }
  331. impl<T: DeserializeOwned> Receiver<T> for UnixReceiver<T> {
  332. fn addr(&self) -> &BlockAddr {
  333. &self.addr
  334. }
  335. }
  336. /// An implementation of [Sender] which uses a Unix datagram socket to send messages.
  337. struct UnixSender {
  338. addr: BlockAddr,
  339. socket: FramedWrite<DatagramAdapter, MsgEncoder>,
  340. }
  341. impl UnixSender {
  342. #[allow(dead_code)]
  343. fn new(mut fs_path: PathBuf, addr: BlockAddr) -> Result<Self> {
  344. let socket = UnixDatagram::unbound()?;
  345. socket_path(&mut fs_path, &addr);
  346. socket.connect(fs_path)?;
  347. let socket = FramedWrite::new(DatagramAdapter::new(socket), MsgEncoder);
  348. Ok(Self { addr, socket })
  349. }
  350. }
  351. impl<T: Serialize> Sink<Msg<T>> for UnixSender {
  352. type Error = btlib::Error;
  353. fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
  354. <tokio_util::codec::FramedWrite<DatagramAdapter, MsgEncoder> as futures::SinkExt<
  355. Msg<T>,
  356. >>::poll_ready_unpin(&mut self.socket, cx)
  357. }
  358. fn start_send(mut self: Pin<&mut Self>, item: Msg<T>) -> Result<()> {
  359. self.socket.start_send_unpin(item)
  360. }
  361. fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
  362. <tokio_util::codec::FramedWrite<DatagramAdapter, MsgEncoder> as futures::SinkExt<
  363. Msg<T>,
  364. >>::poll_flush_unpin(&mut self.socket, cx)
  365. }
  366. fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
  367. <tokio_util::codec::FramedWrite<DatagramAdapter, MsgEncoder> as futures::SinkExt<
  368. Msg<T>,
  369. >>::poll_close_unpin(&mut self.socket, cx)
  370. }
  371. }
  372. impl Sender for UnixSender {
  373. fn addr(&self) -> &BlockAddr {
  374. &self.addr
  375. }
  376. type SendFut<'a, T>
  377. = SendFut<'a, FramedWrite<DatagramAdapter, MsgEncoder>, Msg<T>>
  378. where T: 'a + Serialize + Send;
  379. fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: Msg<T>) -> Self::SendFut<'a, T> {
  380. self.socket.send(msg)
  381. }
  382. type FinishFut = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
  383. fn finish(mut self) -> Self::FinishFut {
  384. Box::pin(async move {
  385. let fut: Close<'_, _, Msg<()>> = self.socket.close();
  386. fut.await
  387. })
  388. }
  389. }
  390. /// Causes the current function to return if the given `rx` has received a stop signal.
  391. macro_rules! check_stop {
  392. ($rx:expr) => {
  393. match $rx.try_recv() {
  394. Ok(_) => return,
  395. Err(err) => {
  396. if let TryRecvError::Closed = err {
  397. return;
  398. }
  399. }
  400. }
  401. };
  402. }
  403. struct QuicRouter {
  404. recv_addr: Arc<BlockAddr>,
  405. endpoint: Endpoint,
  406. }
  407. impl QuicRouter {
  408. fn new(recv_addr: Arc<BlockAddr>, creds: &ConcreteCreds) -> Result<Self> {
  409. let socket_addr = recv_addr.socket_addr();
  410. let endpoint = Endpoint::server(server_config(creds)?, socket_addr)?;
  411. Ok(Self {
  412. endpoint,
  413. recv_addr,
  414. })
  415. }
  416. }
  417. impl Router for QuicRouter {
  418. type Receiver<T: 'static + DeserializeOwned + Send> = QuicReceiver<T>;
  419. type ReceiverFut<'a, T: 'static + DeserializeOwned + Send + Unpin> =
  420. Ready<Result<QuicReceiver<T>>>;
  421. fn receiver<T: 'static + DeserializeOwned + Send + Unpin>(
  422. &self,
  423. ) -> Self::ReceiverFut<'_, T> {
  424. ready(QuicReceiver::new(
  425. self.endpoint.clone(),
  426. self.recv_addr.clone(),
  427. ))
  428. }
  429. type Sender = QuicSender;
  430. type SenderFut<'a> = Pin<Box<dyn 'a + Future<Output = Result<QuicSender>> + Send>>;
  431. fn sender(&self, addr: Arc<BlockAddr>) -> Self::SenderFut<'_> {
  432. Box::pin(async { QuicSender::from_endpoint(self.endpoint.clone(), addr).await })
  433. }
  434. }
  435. struct QuicReceiver<T> {
  436. recv_addr: Arc<BlockAddr>,
  437. stop_tx: broadcast::Sender<()>,
  438. stream: ReceiverStream<Result<Msg<T>>>,
  439. }
  440. impl<T: DeserializeOwned + Send + 'static> QuicReceiver<T> {
  441. /// The size of the buffer to store received messages in.
  442. const MSG_BUF_SZ: usize = 64;
  443. fn new(endpoint: Endpoint, recv_addr: Arc<BlockAddr>) -> Result<Self> {
  444. let (stop_tx, mut stop_rx) = broadcast::channel(1);
  445. let (msg_tx, msg_rx) = mpsc::channel(Self::MSG_BUF_SZ);
  446. tokio::spawn(async move {
  447. loop {
  448. check_stop!(stop_rx);
  449. let connecting = match endpoint.accept().await {
  450. Some(connection) => connection,
  451. None => break,
  452. };
  453. let connection = match connecting.await {
  454. Ok(connection) => connection,
  455. Err(err) => {
  456. error!("error accepting QUIC connection: {err}");
  457. continue;
  458. }
  459. };
  460. let conn_msg_tx = msg_tx.clone();
  461. let mut conn_stop_rx = stop_rx.resubscribe();
  462. tokio::spawn(async move {
  463. let recv_stream = match connection.accept_uni().await {
  464. Ok(recv_stream) => recv_stream,
  465. Err(err) => {
  466. error!("error accepting receive stream: {err}");
  467. return;
  468. }
  469. };
  470. let mut msg_stream = FramedRead::new(recv_stream, MsgDecoder::new());
  471. loop {
  472. check_stop!(conn_stop_rx);
  473. let result = match msg_stream.next().await {
  474. Some(result) => result,
  475. None => return,
  476. };
  477. if let Err(err) = conn_msg_tx.send(result).await {
  478. error!("error sending message to mpsc queue: {err}");
  479. }
  480. }
  481. });
  482. }
  483. });
  484. Ok(Self {
  485. recv_addr,
  486. stop_tx,
  487. stream: ReceiverStream::new(msg_rx),
  488. })
  489. }
  490. }
  491. impl<T> Drop for QuicReceiver<T> {
  492. fn drop(&mut self) {
  493. // This result will be a failure if the tasks have already returned, which is not a
  494. // problem.
  495. let _ = self.stop_tx.send(());
  496. }
  497. }
  498. impl<T: DeserializeOwned + Send + 'static> Stream for QuicReceiver<T> {
  499. type Item = Result<Msg<T>>;
  500. fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
  501. self.stream.poll_next_unpin(cx)
  502. }
  503. }
  504. impl<T: DeserializeOwned + Send + 'static> Receiver<T> for QuicReceiver<T> {
  505. fn addr(&self) -> &BlockAddr {
  506. &self.recv_addr
  507. }
  508. }
  509. struct CertVerifier;
  510. impl CertVerifier {
  511. fn new() -> Arc<Self> {
  512. Arc::new(Self)
  513. }
  514. }
  515. impl rustls::client::ServerCertVerifier for CertVerifier {
  516. fn verify_server_cert(
  517. &self,
  518. _end_entity: &Certificate,
  519. _intermediates: &[Certificate],
  520. _server_name: &rustls::ServerName,
  521. _scts: &mut dyn Iterator<Item = &[u8]>,
  522. _ocsp_response: &[u8],
  523. _now: std::time::SystemTime,
  524. ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
  525. // TODO: Implement certificate verification.
  526. Ok(rustls::client::ServerCertVerified::assertion())
  527. }
  528. }
  529. struct QuicSender {
  530. addr: Arc<BlockAddr>,
  531. sink: FramedWrite<SendStream, MsgEncoder>,
  532. }
  533. impl QuicSender {
  534. async fn from_endpoint(endpoint: Endpoint, addr: Arc<BlockAddr>) -> Result<Self> {
  535. let socket_addr = addr.socket_addr();
  536. let connecting = endpoint.connect_with(client_config()?, socket_addr, "localhost")?;
  537. let connection = connecting.await?;
  538. let send_stream = connection.open_uni().await?;
  539. let sink = FramedWrite::new(send_stream, MsgEncoder::new());
  540. Ok(Self { addr, sink })
  541. }
  542. }
  543. impl Sender for QuicSender {
  544. fn addr(&self) -> &BlockAddr {
  545. &self.addr
  546. }
  547. type SendFut<'a, T> = SendFut<'a, FramedWrite<SendStream, MsgEncoder>, Msg<T>>
  548. where T: 'a + Serialize + Send;
  549. fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: Msg<T>) -> Self::SendFut<'a, T> {
  550. self.sink.send(msg)
  551. }
  552. type FinishFut = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
  553. fn finish(mut self) -> Self::FinishFut {
  554. Box::pin(async move {
  555. let steam: &mut SendStream = self.sink.get_mut();
  556. steam.finish().await.map_err(|err| bterr!(err))
  557. })
  558. }
  559. }
  560. /// This is an identify function which allows you to specify a type parameter for the output
  561. /// of a future.
  562. /// This was needed to work around a failure in type inference for types with higher-rank
  563. /// lifetimes. Once this issue is resolved this can be removed:
  564. /// https://github.com/rust-lang/rust/issues/102211
  565. pub fn assert_send<'a, T>(
  566. fut: impl 'a + Future<Output = T> + Send,
  567. ) -> impl 'a + Future<Output = T> + Send {
  568. fut
  569. }
  570. }
  571. #[cfg(test)]
  572. mod tests {
  573. use super::*;
  574. use btlib::{crypto::Creds, Epoch, Principal, Principaled};
  575. use ctor::ctor;
  576. use std::{
  577. net::Ipv6Addr,
  578. sync::atomic::{AtomicU64, Ordering},
  579. time::Duration,
  580. };
  581. #[ctor]
  582. fn setup_logging() {
  583. env_logger::init();
  584. }
  585. lazy_static! {
  586. static ref ROOT_CREDS: ConcreteCreds = ConcreteCreds::generate().unwrap();
  587. static ref NODE_CREDS: ConcreteCreds = {
  588. let mut creds = ConcreteCreds::generate().unwrap();
  589. let root_creds = &ROOT_CREDS;
  590. let writecap = root_creds
  591. .issue_writecap(
  592. creds.principal(),
  593. vec![],
  594. Epoch::now() + Duration::from_secs(3600),
  595. )
  596. .unwrap();
  597. creds.set_writecap(writecap);
  598. creds
  599. };
  600. static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
  601. }
  602. fn block_addr<'a, I: Iterator<Item = S>, S: ToString>(components: I) -> BlockAddr {
  603. let components = components.map(|e| e.to_string()).collect();
  604. let path = BlockPath::new(ROOT_CREDS.principal(), components);
  605. let path = Arc::new(path);
  606. BlockAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), path)
  607. }
  608. #[derive(Serialize, Deserialize)]
  609. enum MsgError {
  610. Unknown,
  611. }
  612. #[derive(Deserialize)]
  613. enum BodyOwned {
  614. Ping,
  615. Success,
  616. Fail(MsgError),
  617. Read { offset: u64, size: u64 },
  618. Write { offset: u64, buf: Vec<u8> },
  619. }
  620. #[derive(Serialize)]
  621. enum BodyRef<'a> {
  622. Ping,
  623. Success,
  624. Fail(MsgError),
  625. Read { offset: u64, size: u64 },
  626. Write { offset: u64, buf: &'a [u8] },
  627. }
  628. struct TestCase {
  629. instance_num: u64,
  630. }
  631. impl TestCase {
  632. fn new() -> TestCase {
  633. static INSTANCE_NUM: AtomicU64 = AtomicU64::new(0);
  634. let instance_num = INSTANCE_NUM.fetch_add(1, Ordering::SeqCst);
  635. Self { instance_num }
  636. }
  637. async fn endpoint(&self, inode: u64) -> (impl Sender, impl Receiver<BodyOwned>) {
  638. let addr = Arc::new(block_addr([self.instance_num, inode].iter()));
  639. let router = router(addr.clone(), &NODE_CREDS).unwrap();
  640. let receiver = router.receiver::<BodyOwned>().await.unwrap();
  641. let sender = router.sender(addr).await.unwrap();
  642. (sender, receiver)
  643. }
  644. }
  645. #[tokio::test]
  646. async fn message_received_is_message_sent() {
  647. let case = TestCase::new();
  648. let (mut sender, mut receiver) = case.endpoint(1).await;
  649. sender.send_with_rand_id(BodyRef::Ping).await.unwrap();
  650. let actual = receiver.next().await.unwrap().unwrap();
  651. let matched = if let BodyOwned::Ping = actual.body {
  652. true
  653. } else {
  654. false
  655. };
  656. assert!(matched);
  657. }
  658. #[tokio::test]
  659. async fn ping_pong() {
  660. let case = TestCase::new();
  661. let (mut sender_one, mut receiver_one) = case.endpoint(1).await;
  662. let (mut sender_two, mut receiver_two) = case.endpoint(2).await;
  663. tokio::spawn(async move {
  664. let msg = receiver_one.next().await.unwrap().unwrap();
  665. let reply_body = if let BodyOwned::Ping = msg.body {
  666. BodyRef::Success
  667. } else {
  668. BodyRef::Fail(MsgError::Unknown)
  669. };
  670. let fut = assert_send::<'_, Result<()>>(sender_two.send_with_rand_id(reply_body));
  671. fut.await.unwrap();
  672. sender_two.finish().await.unwrap();
  673. });
  674. sender_one.send_with_rand_id(BodyRef::Ping).await.unwrap();
  675. let reply = receiver_two.next().await.unwrap().unwrap();
  676. let matched = if let BodyOwned::Success = reply.body {
  677. true
  678. } else {
  679. false
  680. };
  681. assert!(matched)
  682. }
  683. #[tokio::test]
  684. async fn read_write() {
  685. let case = TestCase::new();
  686. let (mut sender_one, mut receiver_one) = case.endpoint(1).await;
  687. let (mut sender_two, mut receiver_two) = case.endpoint(2).await;
  688. let handle = tokio::spawn(async move {
  689. let data: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
  690. let msg = receiver_one.next().await.unwrap().unwrap();
  691. let reply_body = if let BodyOwned::Read { offset, size } = msg.body {
  692. let offset: usize = offset.try_into().unwrap();
  693. let size: usize = size.try_into().unwrap();
  694. let end: usize = offset + size;
  695. BodyRef::Write {
  696. offset: offset as u64,
  697. buf: &data[offset..end],
  698. }
  699. } else {
  700. BodyRef::Fail(MsgError::Unknown)
  701. };
  702. let msg = Msg::new(msg.id, reply_body);
  703. let fut = assert_send::<'_, Result<()>>(sender_two.send(msg));
  704. fut.await.unwrap();
  705. sender_two.finish().await.unwrap();
  706. });
  707. sender_one
  708. .send_with_rand_id(BodyRef::Read { offset: 2, size: 2 })
  709. .await
  710. .unwrap();
  711. handle.await.unwrap();
  712. let reply = receiver_two.next().await.unwrap().unwrap();
  713. if let BodyOwned::Write { offset, buf } = reply.body {
  714. assert_eq!(2, offset);
  715. assert_eq!([2, 3].as_slice(), buf.as_slice());
  716. } else {
  717. panic!("reply was not the right type");
  718. };
  719. }
  720. }