serialization.rs 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. // SPDX-License-Identifier: AGPL-3.0-or-later
  2. //! Types for serializing and deserializing messages.
  3. use btlib::error::BoxInIoErr;
  4. use btserde::{from_slice, read_from, write_to};
  5. use bytes::{BufMut, BytesMut};
  6. use futures::Future;
  7. use serde::{Deserialize, Serialize};
  8. use tokio::io::{AsyncRead, AsyncReadExt};
  9. use tokio_util::codec::Encoder;
  10. use crate::Result;
  11. pub(crate) struct CallbackFramed<I> {
  12. io: I,
  13. buffer: BytesMut,
  14. }
  15. impl<I> CallbackFramed<I> {
  16. const INIT_CAPACITY: usize = 4096;
  17. /// The number of bytes used to encode the length of each frame.
  18. const FRAME_LEN_SZ: usize = std::mem::size_of::<u64>();
  19. pub fn new(inner: I) -> Self {
  20. Self {
  21. io: inner,
  22. buffer: BytesMut::with_capacity(Self::INIT_CAPACITY),
  23. }
  24. }
  25. pub fn into_parts(self) -> (I, BytesMut) {
  26. (self.io, self.buffer)
  27. }
  28. pub fn from_parts(io: I, mut buffer: BytesMut) -> Self {
  29. if buffer.capacity() < Self::INIT_CAPACITY {
  30. buffer.reserve(Self::INIT_CAPACITY - buffer.capacity());
  31. }
  32. Self { io, buffer }
  33. }
  34. async fn decode(mut slice: &[u8]) -> Result<DecodeStatus> {
  35. let payload_len: u64 = match read_from(&mut slice) {
  36. Ok(payload_len) => payload_len,
  37. Err(err) => {
  38. return match err {
  39. btserde::Error::Eof => Ok(DecodeStatus::None),
  40. btserde::Error::Io(ref io_err) => match io_err.kind() {
  41. std::io::ErrorKind::UnexpectedEof => Ok(DecodeStatus::None),
  42. _ => Err(err.into()),
  43. },
  44. _ => Err(err.into()),
  45. }
  46. }
  47. };
  48. let payload_len: usize = payload_len.try_into().box_err()?;
  49. if slice.len() < payload_len {
  50. return Ok(DecodeStatus::Reserve(payload_len - slice.len()));
  51. }
  52. Ok(DecodeStatus::Consume(Self::FRAME_LEN_SZ + payload_len))
  53. }
  54. }
  55. macro_rules! attempt {
  56. ($result:expr) => {
  57. match $result {
  58. Ok(value) => value,
  59. Err(err) => return Some(Err(err.into())),
  60. }
  61. };
  62. }
  63. impl<S: AsyncRead + Unpin> CallbackFramed<S> {
  64. pub(crate) async fn next<F: DeserCallback>(
  65. &mut self,
  66. mut callback: F,
  67. ) -> Option<Result<F::Return>> {
  68. let mut total_read = 0;
  69. loop {
  70. if self.buffer.capacity() - self.buffer.len() == 0 {
  71. // If there is no space left in the buffer we reserve additional bytes to ensure
  72. // read_buf doesn't return 0 unless we're at EOF.
  73. self.buffer.reserve(1);
  74. }
  75. let read_ct = attempt!(self.io.read_buf(&mut self.buffer).await);
  76. if 0 == read_ct {
  77. return None;
  78. }
  79. total_read += read_ct;
  80. match attempt!(Self::decode(&self.buffer[..total_read]).await) {
  81. DecodeStatus::None => continue,
  82. DecodeStatus::Reserve(count) => {
  83. self.buffer.reserve(count);
  84. continue;
  85. }
  86. DecodeStatus::Consume(consume) => {
  87. let start = self.buffer.split_to(consume);
  88. let arg: F::Arg<'_> = attempt!(from_slice(&start[Self::FRAME_LEN_SZ..]));
  89. let returned = callback.call(arg).await;
  90. return Some(Ok(returned));
  91. }
  92. }
  93. }
  94. }
  95. }
  96. enum DecodeStatus {
  97. None,
  98. Reserve(usize),
  99. Consume(usize),
  100. }
  101. /// A trait for types which can be called to asynchronously handle deserialization. This trait is
  102. /// what enables zero-copy handling of messages which support borrowing data during deserialization.
  103. pub trait DeserCallback {
  104. type Arg<'de>: 'de + Deserialize<'de> + Send
  105. where
  106. Self: 'de;
  107. type Return;
  108. type CallFut<'de>: 'de + Future<Output = Self::Return> + Send
  109. where
  110. Self: 'de;
  111. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de>;
  112. }
  113. impl<'a, T: DeserCallback> DeserCallback for &'a mut T {
  114. type Arg<'de> = T::Arg<'de> where T: 'de, 'a: 'de;
  115. type Return = T::Return;
  116. type CallFut<'de> = T::CallFut<'de> where T: 'de, 'a: 'de;
  117. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  118. (*self).call(arg)
  119. }
  120. }
  121. /// Encodes messages using [btserde].
  122. #[derive(Debug)]
  123. pub(crate) struct MsgEncoder;
  124. impl MsgEncoder {
  125. pub(crate) fn new() -> Self {
  126. Self
  127. }
  128. }
  129. impl<T: Serialize> Encoder<T> for MsgEncoder {
  130. type Error = btlib::Error;
  131. fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<()> {
  132. const U64_LEN: usize = std::mem::size_of::<u64>();
  133. let payload = dst.split_off(U64_LEN);
  134. let mut writer = payload.writer();
  135. write_to(&item, &mut writer)?;
  136. let payload = writer.into_inner();
  137. let payload_len = payload.len() as u64;
  138. let mut writer = dst.writer();
  139. write_to(&payload_len, &mut writer)?;
  140. let dst = writer.into_inner();
  141. dst.unsplit(payload);
  142. Ok(())
  143. }
  144. }
  145. #[cfg(test)]
  146. mod tests {
  147. use super::*;
  148. use futures::{future::Ready, SinkExt};
  149. use serde::Serialize;
  150. use std::{
  151. future::ready,
  152. io::{Cursor, Seek},
  153. task::Poll,
  154. };
  155. use tokio_util::codec::FramedWrite;
  156. #[derive(Serialize, Deserialize)]
  157. struct Msg<'a>(&'a [u8]);
  158. #[tokio::test]
  159. async fn read_single_message() {
  160. macro_rules! test_data {
  161. () => {
  162. b"fulcrum"
  163. };
  164. }
  165. #[derive(Clone)]
  166. struct TestCb;
  167. impl DeserCallback for TestCb {
  168. type Arg<'de> = Msg<'de> where Self: 'de;
  169. type Return = bool;
  170. type CallFut<'de> = Ready<Self::Return>;
  171. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  172. futures::future::ready(arg.0 == test_data!())
  173. }
  174. }
  175. let mut write = FramedWrite::new(Cursor::new(Vec::<u8>::new()), MsgEncoder);
  176. write.send(Msg(test_data!())).await.unwrap();
  177. let mut io = write.into_inner();
  178. io.rewind().unwrap();
  179. let mut read = CallbackFramed::new(io);
  180. let matched = read.next(TestCb).await.unwrap().unwrap();
  181. assert!(matched)
  182. }
  183. struct WindowedCursor {
  184. window_sz: usize,
  185. pos: usize,
  186. buf: Vec<u8>,
  187. }
  188. impl WindowedCursor {
  189. fn new(data: Vec<u8>, window_sz: usize) -> Self {
  190. WindowedCursor {
  191. window_sz,
  192. pos: 0,
  193. buf: data,
  194. }
  195. }
  196. }
  197. impl AsyncRead for WindowedCursor {
  198. fn poll_read(
  199. mut self: std::pin::Pin<&mut Self>,
  200. _cx: &mut std::task::Context<'_>,
  201. buf: &mut tokio::io::ReadBuf<'_>,
  202. ) -> std::task::Poll<std::io::Result<()>> {
  203. let end = self.buf.len().min(self.pos + self.window_sz);
  204. let window = &self.buf[self.pos..end];
  205. buf.put_slice(window);
  206. self.as_mut().pos += window.len();
  207. Poll::Ready(Ok(()))
  208. }
  209. }
  210. struct CopyCallback;
  211. impl DeserCallback for CopyCallback {
  212. type Arg<'de> = Msg<'de>;
  213. type Return = Vec<u8>;
  214. type CallFut<'de> = std::future::Ready<Self::Return>;
  215. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  216. ready(arg.0.to_owned())
  217. }
  218. }
  219. #[tokio::test]
  220. async fn read_in_multiple_parts() {
  221. const EXPECTED: &[u8] = b"We live in the most interesting of times.";
  222. let mut write = FramedWrite::new(Cursor::new(Vec::<u8>::new()), MsgEncoder);
  223. write.send(Msg(EXPECTED)).await.unwrap();
  224. let data = write.into_inner().into_inner();
  225. // This will force the CallbackFramed to read the message in multiple iterations.
  226. let io = WindowedCursor::new(data, EXPECTED.len() / 2);
  227. let mut read = CallbackFramed::new(io);
  228. let actual = read.next(CopyCallback).await.unwrap().unwrap();
  229. assert_eq!(EXPECTED, &actual);
  230. }
  231. }