// SPDX-License-Identifier: AGPL-3.0-or-later use btlib::{error::BoxInIoErr, Result}; use btserde::{from_slice, read_from}; use bytes::BytesMut; use futures::Future; use serde::Deserialize; use tokio::io::{AsyncRead, AsyncReadExt}; pub struct CallbackFramed { io: I, buffer: BytesMut, } impl CallbackFramed { const INIT_CAPACITY: usize = 4096; /// The number of bytes used to encode the length of each frame. const FRAME_LEN_SZ: usize = std::mem::size_of::(); pub fn new(inner: I) -> Self { Self { io: inner, buffer: BytesMut::with_capacity(Self::INIT_CAPACITY), } } pub fn into_parts(self) -> (I, BytesMut) { (self.io, self.buffer) } pub fn from_parts(io: I, mut buffer: BytesMut) -> Self { if buffer.capacity() < Self::INIT_CAPACITY { buffer.reserve(Self::INIT_CAPACITY - buffer.capacity()); } Self { io, buffer } } async fn decode(mut slice: &[u8]) -> Result { let payload_len: u64 = match read_from(&mut slice) { Ok(payload_len) => payload_len, Err(err) => { return match err { btserde::Error::Eof => Ok(DecodeStatus::None), btserde::Error::Io(ref io_err) => match io_err.kind() { std::io::ErrorKind::UnexpectedEof => Ok(DecodeStatus::None), _ => Err(err.into()), }, _ => Err(err.into()), } } }; let payload_len: usize = payload_len.try_into().box_err()?; if slice.len() < payload_len { return Ok(DecodeStatus::Reserve(payload_len - slice.len())); } Ok(DecodeStatus::Consume(Self::FRAME_LEN_SZ + payload_len)) } } macro_rules! attempt { ($result:expr) => { match $result { Ok(value) => value, Err(err) => return Some(Err(err.into())), } }; } impl CallbackFramed { pub async fn next(&mut self, mut callback: F) -> Option> { let mut total_read = 0; loop { if self.buffer.capacity() - self.buffer.len() == 0 { // If there is no space left in the buffer we reserve additional bytes to ensure // read_buf doesn't return 0 unless we're at EOF. self.buffer.reserve(1); } let read_ct = attempt!(self.io.read_buf(&mut self.buffer).await); if 0 == read_ct { return None; } total_read += read_ct; match attempt!(Self::decode(&self.buffer[..total_read]).await) { DecodeStatus::None => continue, DecodeStatus::Reserve(count) => { self.buffer.reserve(count); continue; } DecodeStatus::Consume(consume) => { let start = self.buffer.split_to(consume); let arg: F::Arg<'_> = attempt!(from_slice(&start[Self::FRAME_LEN_SZ..])); let returned = callback.call(arg).await; return Some(Ok(returned)); } } } } } enum DecodeStatus { None, Reserve(usize), Consume(usize), } pub trait DeserCallback { type Arg<'de>: 'de + Deserialize<'de> + Send where Self: 'de; type Return; type CallFut<'de>: 'de + Future + Send where Self: 'de; fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de>; } impl<'a, T: DeserCallback> DeserCallback for &'a mut T { type Arg<'de> = T::Arg<'de> where T: 'de, 'a: 'de; type Return = T::Return; type CallFut<'de> = T::CallFut<'de> where T: 'de, 'a: 'de; fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> { (*self).call(arg) } } #[cfg(test)] mod tests { use super::*; use crate::MsgEncoder; use futures::{future::Ready, SinkExt}; use serde::Serialize; use std::{ future::ready, io::{Cursor, Seek}, task::Poll, }; use tokio_util::codec::FramedWrite; #[derive(Serialize, Deserialize)] struct Msg<'a>(&'a [u8]); #[tokio::test] async fn read_single_message() { macro_rules! test_data { () => { b"fulcrum" }; } #[derive(Clone)] struct TestCb; impl DeserCallback for TestCb { type Arg<'de> = Msg<'de> where Self: 'de; type Return = bool; type CallFut<'de> = Ready; fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> { futures::future::ready(arg.0 == test_data!()) } } let mut write = FramedWrite::new(Cursor::new(Vec::::new()), MsgEncoder); write.send(Msg(test_data!())).await.unwrap(); let mut io = write.into_inner(); io.rewind().unwrap(); let mut read = CallbackFramed::new(io); let matched = read.next(TestCb).await.unwrap().unwrap(); assert!(matched) } struct WindowedCursor { window_sz: usize, pos: usize, buf: Vec, } impl WindowedCursor { fn new(data: Vec, window_sz: usize) -> Self { WindowedCursor { window_sz, pos: 0, buf: data, } } } impl AsyncRead for WindowedCursor { fn poll_read( mut self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { let end = self.buf.len().min(self.pos + self.window_sz); let window = &self.buf[self.pos..end]; buf.put_slice(window); self.as_mut().pos += window.len(); Poll::Ready(Ok(())) } } struct CopyCallback; impl DeserCallback for CopyCallback { type Arg<'de> = Msg<'de>; type Return = Vec; type CallFut<'de> = std::future::Ready; fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> { ready(arg.0.to_owned()) } } #[tokio::test] async fn read_in_multiple_parts() { const EXPECTED: &[u8] = b"We live in the most interesting of times."; let mut write = FramedWrite::new(Cursor::new(Vec::::new()), MsgEncoder); write.send(Msg(EXPECTED)).await.unwrap(); let data = write.into_inner().into_inner(); // This will force the CallbackFramed to read the message in multiple iterations. let io = WindowedCursor::new(data, EXPECTED.len() / 2); let mut read = CallbackFramed::new(io); let actual = read.next(CopyCallback).await.unwrap().unwrap(); assert_eq!(EXPECTED, &actual); } }