123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230 |
- // SPDX-License-Identifier: AGPL-3.0-or-later
- use btlib::{error::BoxInIoErr, Result};
- use btserde::{from_slice, read_from};
- use bytes::BytesMut;
- use futures::Future;
- use serde::Deserialize;
- use tokio::io::{AsyncRead, AsyncReadExt};
- pub struct CallbackFramed<I> {
- io: I,
- buffer: BytesMut,
- }
- impl<I> CallbackFramed<I> {
- const INIT_CAPACITY: usize = 4096;
- /// The number of bytes used to encode the length of each frame.
- const FRAME_LEN_SZ: usize = std::mem::size_of::<u64>();
- pub fn new(inner: I) -> Self {
- Self {
- io: inner,
- buffer: BytesMut::with_capacity(Self::INIT_CAPACITY),
- }
- }
- pub fn into_parts(self) -> (I, BytesMut) {
- (self.io, self.buffer)
- }
- pub fn from_parts(io: I, mut buffer: BytesMut) -> Self {
- if buffer.capacity() < Self::INIT_CAPACITY {
- buffer.reserve(Self::INIT_CAPACITY - buffer.capacity());
- }
- Self { io, buffer }
- }
- async fn decode(mut slice: &[u8]) -> Result<DecodeStatus> {
- let payload_len: u64 = match read_from(&mut slice) {
- Ok(payload_len) => payload_len,
- Err(err) => {
- return match err {
- btserde::Error::Eof => Ok(DecodeStatus::None),
- btserde::Error::Io(ref io_err) => match io_err.kind() {
- std::io::ErrorKind::UnexpectedEof => Ok(DecodeStatus::None),
- _ => Err(err.into()),
- },
- _ => Err(err.into()),
- }
- }
- };
- let payload_len: usize = payload_len.try_into().box_err()?;
- if slice.len() < payload_len {
- return Ok(DecodeStatus::Reserve(payload_len - slice.len()));
- }
- Ok(DecodeStatus::Consume(Self::FRAME_LEN_SZ + payload_len))
- }
- }
/// Unwraps a `Result`, or returns `Some(Err(..))` from the enclosing function with the error
/// converted via `Into`. Meant for use inside functions that return `Option<Result<..>>`.
macro_rules! attempt {
    ($fallible:expr) => {
        match $fallible {
            Err(error) => return Some(Err(error.into())),
            Ok(inner) => inner,
        }
    };
}
- impl<S: AsyncRead + Unpin> CallbackFramed<S> {
- pub async fn next<F: DeserCallback>(&mut self, mut callback: F) -> Option<Result<F::Return>> {
- let mut total_read = 0;
- loop {
- if self.buffer.capacity() - self.buffer.len() == 0 {
- // If there is no space left in the buffer we reserve additional bytes to ensure
- // read_buf doesn't return 0 unless we're at EOF.
- self.buffer.reserve(1);
- }
- let read_ct = attempt!(self.io.read_buf(&mut self.buffer).await);
- if 0 == read_ct {
- return None;
- }
- total_read += read_ct;
- match attempt!(Self::decode(&self.buffer[..total_read]).await) {
- DecodeStatus::None => continue,
- DecodeStatus::Reserve(count) => {
- self.buffer.reserve(count);
- continue;
- }
- DecodeStatus::Consume(consume) => {
- let start = self.buffer.split_to(consume);
- let arg: F::Arg<'_> = attempt!(from_slice(&start[Self::FRAME_LEN_SZ..]));
- let returned = callback.call(arg).await;
- return Some(Ok(returned));
- }
- }
- }
- }
- }
/// Outcome of looking for a complete frame at the start of the read buffer.
enum DecodeStatus {
    /// The length prefix could not be read yet; more input is needed.
    None,
    /// The length prefix was read, but the payload is short by this many bytes.
    Reserve(usize),
    /// A complete frame is present and occupies this many bytes (prefix plus payload).
    Consume(usize),
}
- pub trait DeserCallback {
- type Arg<'de>: 'de + Deserialize<'de> + Send
- where
- Self: 'de;
- type Return;
- type CallFut<'de>: 'de + Future<Output = Self::Return> + Send
- where
- Self: 'de;
- fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de>;
- }
- impl<'a, T: DeserCallback> DeserCallback for &'a mut T {
- type Arg<'de> = T::Arg<'de> where T: 'de, 'a: 'de;
- type Return = T::Return;
- type CallFut<'de> = T::CallFut<'de> where T: 'de, 'a: 'de;
- fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
- (*self).call(arg)
- }
- }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MsgEncoder;
    use futures::{future::Ready, SinkExt};
    use serde::Serialize;
    use std::{
        future::ready,
        io::{Cursor, Seek},
        task::Poll,
    };
    use tokio_util::codec::FramedWrite;

    /// Test message whose payload borrows from the bytes it is deserialized from.
    #[derive(Serialize, Deserialize)]
    struct Msg<'a>(&'a [u8]);

    #[tokio::test]
    async fn read_single_message() {
        macro_rules! test_data {
            () => {
                b"fulcrum"
            };
        }

        /// Resolves to `true` iff the received payload equals `test_data!()`.
        struct TestCb;

        impl DeserCallback for TestCb {
            type Arg<'de> = Msg<'de> where Self: 'de;
            type Return = bool;
            type CallFut<'de> = Ready<Self::Return>;
            fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
                futures::future::ready(arg.0 == test_data!())
            }
        }

        let mut write = FramedWrite::new(Cursor::new(Vec::<u8>::new()), MsgEncoder);
        write.send(Msg(test_data!())).await.unwrap();
        let mut io = write.into_inner();
        // Seek back to the start so the reader sees the bytes that were just written.
        io.rewind().unwrap();
        let mut read = CallbackFramed::new(io);
        let matched = read.next(TestCb).await.unwrap().unwrap();
        assert!(matched);
    }

    /// An in-memory `AsyncRead` that yields at most `window_sz` bytes per read, forcing the
    /// reader to assemble a frame from several partial reads.
    struct WindowedCursor {
        window_sz: usize,
        pos: usize,
        buf: Vec<u8>,
    }

    impl WindowedCursor {
        fn new(data: Vec<u8>, window_sz: usize) -> Self {
            WindowedCursor {
                window_sz,
                pos: 0,
                buf: data,
            }
        }
    }

    impl AsyncRead for WindowedCursor {
        fn poll_read(
            mut self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let this = &mut *self;
            // Clamp to the caller's free space as well as the window: `put_slice` panics if
            // handed more bytes than `buf.remaining()`.
            let len = this
                .window_sz
                .min(buf.remaining())
                .min(this.buf.len() - this.pos);
            buf.put_slice(&this.buf[this.pos..this.pos + len]);
            this.pos += len;
            Poll::Ready(Ok(()))
        }
    }

    /// Returns an owned copy of the received payload.
    struct CopyCallback;

    impl DeserCallback for CopyCallback {
        type Arg<'de> = Msg<'de>;
        type Return = Vec<u8>;
        type CallFut<'de> = std::future::Ready<Self::Return>;
        fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
            ready(arg.0.to_owned())
        }
    }

    #[tokio::test]
    async fn read_in_multiple_parts() {
        const EXPECTED: &[u8] = b"We live in the most interesting of times.";
        let mut write = FramedWrite::new(Cursor::new(Vec::<u8>::new()), MsgEncoder);
        write.send(Msg(EXPECTED)).await.unwrap();
        let data = write.into_inner().into_inner();
        // This will force the CallbackFramed to read the message in multiple iterations.
        let io = WindowedCursor::new(data, EXPECTED.len() / 2);
        let mut read = CallbackFramed::new(io);
        let actual = read.next(CopyCallback).await.unwrap().unwrap();
        assert_eq!(EXPECTED, &actual);
    }
}
|