callback_framed.rs 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. // SPDX-License-Identifier: AGPL-3.0-or-later
  2. use btlib::{error::BoxInIoErr, Result};
  3. use btserde::{from_slice, read_from};
  4. use bytes::BytesMut;
  5. use futures::Future;
  6. use serde::Deserialize;
  7. use tokio::io::{AsyncRead, AsyncReadExt};
  8. pub struct CallbackFramed<I> {
  9. io: I,
  10. buffer: BytesMut,
  11. }
  12. impl<I> CallbackFramed<I> {
  13. const INIT_CAPACITY: usize = 4096;
  14. /// The number of bytes used to encode the length of each frame.
  15. const FRAME_LEN_SZ: usize = std::mem::size_of::<u64>();
  16. pub fn new(inner: I) -> Self {
  17. Self {
  18. io: inner,
  19. buffer: BytesMut::with_capacity(Self::INIT_CAPACITY),
  20. }
  21. }
  22. pub fn into_parts(self) -> (I, BytesMut) {
  23. (self.io, self.buffer)
  24. }
  25. pub fn from_parts(io: I, mut buffer: BytesMut) -> Self {
  26. if buffer.capacity() < Self::INIT_CAPACITY {
  27. buffer.reserve(Self::INIT_CAPACITY - buffer.capacity());
  28. }
  29. Self { io, buffer }
  30. }
  31. async fn decode(mut slice: &[u8]) -> Result<DecodeStatus> {
  32. let payload_len: u64 = match read_from(&mut slice) {
  33. Ok(payload_len) => payload_len,
  34. Err(err) => {
  35. return match err {
  36. btserde::Error::Eof => Ok(DecodeStatus::None),
  37. btserde::Error::Io(ref io_err) => match io_err.kind() {
  38. std::io::ErrorKind::UnexpectedEof => Ok(DecodeStatus::None),
  39. _ => Err(err.into()),
  40. },
  41. _ => Err(err.into()),
  42. }
  43. }
  44. };
  45. let payload_len: usize = payload_len.try_into().box_err()?;
  46. if slice.len() < payload_len {
  47. return Ok(DecodeStatus::Reserve(payload_len - slice.len()));
  48. }
  49. Ok(DecodeStatus::Consume(Self::FRAME_LEN_SZ + payload_len))
  50. }
  51. }
  /// Evaluates to the `Ok` value of a `Result`; on `Err`, returns `Some(Err(..))` from the
  /// enclosing function after converting the error with `Into`.
  ///
  /// Meant for functions whose return type is `Option<Result<..>>`.
  macro_rules! attempt {
      ($result:expr) => {
          match $result {
              Err(error) => return Some(Err(error.into())),
              Ok(inner) => inner,
          }
      };
  }
  60. impl<S: AsyncRead + Unpin> CallbackFramed<S> {
  61. pub async fn next<F: DeserCallback>(&mut self, mut callback: F) -> Option<Result<F::Return>> {
  62. let mut total_read = 0;
  63. loop {
  64. if self.buffer.capacity() - self.buffer.len() == 0 {
  65. // If there is no space left in the buffer we reserve additional bytes to ensure
  66. // read_buf doesn't return 0 unless we're at EOF.
  67. self.buffer.reserve(1);
  68. }
  69. let read_ct = attempt!(self.io.read_buf(&mut self.buffer).await);
  70. if 0 == read_ct {
  71. return None;
  72. }
  73. total_read += read_ct;
  74. match attempt!(Self::decode(&self.buffer[..total_read]).await) {
  75. DecodeStatus::None => continue,
  76. DecodeStatus::Reserve(count) => {
  77. self.buffer.reserve(count);
  78. continue;
  79. }
  80. DecodeStatus::Consume(consume) => {
  81. let start = self.buffer.split_to(consume);
  82. let arg: F::Arg<'_> = attempt!(from_slice(&start[Self::FRAME_LEN_SZ..]));
  83. let returned = callback.call(arg).await;
  84. return Some(Ok(returned));
  85. }
  86. }
  87. }
  88. }
  89. }
  /// Outcome of checking the start of the read buffer for a complete frame.
  enum DecodeStatus {
      /// Even the length prefix is not fully buffered; more bytes must be read.
      None,
      /// The length prefix was read, but this many payload bytes are still missing.
      Reserve(usize),
      /// A whole frame is buffered; it spans this many bytes (prefix plus payload).
      Consume(usize),
  }
  95. pub trait DeserCallback {
  96. type Arg<'de>: 'de + Deserialize<'de> + Send
  97. where
  98. Self: 'de;
  99. type Return;
  100. type CallFut<'de>: 'de + Future<Output = Self::Return> + Send
  101. where
  102. Self: 'de;
  103. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de>;
  104. }
  105. impl<'a, T: DeserCallback> DeserCallback for &'a mut T {
  106. type Arg<'de> = T::Arg<'de> where T: 'de, 'a: 'de;
  107. type Return = T::Return;
  108. type CallFut<'de> = T::CallFut<'de> where T: 'de, 'a: 'de;
  109. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  110. (*self).call(arg)
  111. }
  112. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::MsgEncoder;
      use futures::{future::Ready, SinkExt};
      use serde::Serialize;
      use std::{
          future::ready,
          io::{Cursor, Seek},
          task::Poll,
      };
      use tokio_util::codec::FramedWrite;

      /// A test message whose payload borrows from the buffer it was deserialized from.
      #[derive(Serialize, Deserialize)]
      struct Msg<'a>(&'a [u8]);

      #[tokio::test]
      async fn read_single_message() {
          macro_rules! test_data {
              () => {
                  b"fulcrum"
              };
          }
          #[derive(Clone)]
          struct TestCb;
          impl DeserCallback for TestCb {
              type Arg<'de> = Msg<'de> where Self: 'de;
              type Return = bool;
              type CallFut<'de> = Ready<Self::Return>;
              fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
                  futures::future::ready(arg.0 == test_data!())
              }
          }
          let mut write = FramedWrite::new(Cursor::new(Vec::<u8>::new()), MsgEncoder);
          write.send(Msg(test_data!())).await.unwrap();
          let mut io = write.into_inner();
          io.rewind().unwrap();
          let mut read = CallbackFramed::new(io);
          let matched = read.next(TestCb).await.unwrap().unwrap();
          assert!(matched);
      }

      /// An in-memory reader that yields at most `window_sz` bytes per read, used to force a
      /// frame to arrive over several reads.
      struct WindowedCursor {
          window_sz: usize,
          pos: usize,
          buf: Vec<u8>,
      }

      impl WindowedCursor {
          fn new(data: Vec<u8>, window_sz: usize) -> Self {
              WindowedCursor {
                  window_sz,
                  pos: 0,
                  buf: data,
              }
          }
      }

      impl AsyncRead for WindowedCursor {
          fn poll_read(
              mut self: std::pin::Pin<&mut Self>,
              _cx: &mut std::task::Context<'_>,
              buf: &mut tokio::io::ReadBuf<'_>,
          ) -> std::task::Poll<std::io::Result<()>> {
              // `WindowedCursor` is `Unpin`, so the pin can be dereferenced mutably.
              let this = &mut *self;
              // Clamp to the caller's free space as well as the window: `ReadBuf::put_slice`
              // panics if handed more bytes than `remaining()`.
              let len = this.window_sz.min(buf.remaining());
              let end = this.buf.len().min(this.pos + len);
              buf.put_slice(&this.buf[this.pos..end]);
              this.pos = end;
              Poll::Ready(Ok(()))
          }
      }

      /// Returns an owned copy of the message payload.
      struct CopyCallback;

      impl DeserCallback for CopyCallback {
          type Arg<'de> = Msg<'de>;
          type Return = Vec<u8>;
          type CallFut<'de> = std::future::Ready<Self::Return>;
          fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
              ready(arg.0.to_owned())
          }
      }

      #[tokio::test]
      async fn read_in_multiple_parts() {
          const EXPECTED: &[u8] = b"We live in the most interesting of times.";
          let mut write = FramedWrite::new(Cursor::new(Vec::<u8>::new()), MsgEncoder);
          write.send(Msg(EXPECTED)).await.unwrap();
          let data = write.into_inner().into_inner();
          // This will force the CallbackFramed to read the message in multiple iterations.
          let io = WindowedCursor::new(data, EXPECTED.len() / 2);
          let mut read = CallbackFramed::new(io);
          let actual = read.next(CopyCallback).await.unwrap().unwrap();
          assert_eq!(EXPECTED, &actual);
      }
  }