|
- // SPDX-License-Identifier: AGPL-3.0-or-later
- //! This module contains the [BufReader] type.
- use log::error;
- use positioned_io::Size;
- use std::io::{self, Cursor, Read, Seek, SeekFrom};
- use crate::{bterr, Decompose, Result, Sectored, Split, EMPTY_SLICE};
- pub use private::BufReader;
- mod private {
- use crate::{ReadExt, SeekFromExt, SizeExt};
- use super::*;
- /// A buffered reader in the spirit of [std::io::BufReader], with the difference that the
- /// buffer can be supplied by the caller and taken back out again, which allows a single
- /// allocation to be reused between instantiations.
- pub struct BufReader<T> {
-     /// The buffer, together with the position of the next unread byte within it.
-     cursor: Cursor<Vec<u8>>,
-     /// The underlying reader which the buffer is filled from.
-     reader: T,
- }
- impl<T> BufReader<T> {
-     /// Builds a [BufReader] which wraps `reader` and uses `buf` as its buffer.
-     ///
-     /// An error is returned if `buf` has a length of zero.
-     pub fn with_buf(buf: Vec<u8>, reader: T) -> Result<BufReader<T>> {
-         let cursor = Self::make_cursor(buf)?;
-         Ok(Self { cursor, reader })
-     }
-     /// Borrows the inner reader.
-     pub fn get_ref(&self) -> &T {
-         let Self { reader, .. } = self;
-         reader
-     }
-     /// Removes the buffer from this [BufReader] and hands it to the caller. The [BufReader]
-     /// which comes back alongside it is left holding an empty buffer.
-     pub fn take_buf(mut self) -> (Self, Vec<u8>) {
-         // The default cursor wraps an empty vector and is positioned at zero.
-         let buf = std::mem::take(&mut self.cursor).into_inner();
-         (self, buf)
-     }
-     /// Wraps `buf` in a cursor which is positioned at the end of the buffer, so that the
-     /// buffer is initially considered to hold no unread bytes.
-     fn make_cursor(buf: Vec<u8>) -> Result<Cursor<Vec<u8>>> {
-         // A zero-length buffer would cause a call to read to loop forever.
-         if buf.is_empty() {
-             return Err(bterr!("the given vector must be non-empty"));
-         }
-         let end = buf.len() as u64;
-         let mut cursor = Cursor::new(buf);
-         cursor.set_position(end);
-         Ok(cursor)
-     }
-     /// Returns true when the cursor has no unread bytes left in it.
-     fn cursor_is_empty(&self) -> bool {
-         let consumed = self.cursor.position();
-         let total = self.cursor.get_ref().len() as u64;
-         total <= consumed
-     }
- }
- impl<T: Sectored> BufReader<T> {
-     /// Wraps `reader` in a new [BufReader] with a freshly allocated, zero-filled buffer
-     /// whose length is the sector size reported by `reader`.
-     pub fn new(reader: T) -> Result<BufReader<T>> {
-         let buf = vec![0u8; reader.sector_sz()];
-         Self::with_buf(buf, reader)
-     }
- }
- impl<T: Seek> BufReader<T> {
-     /// Returns the current position of this stream.
-     ///
-     /// The inner stream runs ahead of this one by however many bytes are still unread in
-     /// the buffer, so that amount is subtracted from the inner stream's position.
-     pub fn pos(&mut self) -> io::Result<u64> {
-         let buffered = self.cursor.get_ref().len() as u64;
-         let unread = buffered - self.cursor.position();
-         self.reader
-             .stream_position()
-             .map(|inner_pos| inner_pos - unread)
-     }
- }
- impl<T: Read> BufReader<T> {
-     /// Refills the buffer by reading from the underlying stream.
-     ///
-     /// When a whole buffer's worth of bytes was read, the cursor is moved to the start of
-     /// the buffer. When no bytes were read (end of stream), or when an error occurs, the
-     /// cursor is left at the end of the buffer so that the buffer is considered empty and
-     /// its old contents are never handed out as if they were fresh data.
-     ///
-     /// An error is returned if the underlying read fails or if it yields a number of bytes
-     /// which is neither zero nor the buffer length.
-     fn refill(&mut self) -> Result<()> {
-         // Invalidate the buffer up front; it only becomes valid again after a full read.
-         let end = self.cursor.get_ref().len() as u64;
-         self.cursor.set_position(end);
-         let vec = self.cursor.get_mut();
-         let read = self.reader.fill_buf(vec)?;
-         let filled = read == vec.len();
-         if filled {
-             self.cursor.set_position(0);
-             Ok(())
-         } else if read == 0 {
-             // End of the underlying stream, the buffer stays empty.
-             Ok(())
-         } else {
-             Err(bterr!("unexpected number of bytes read: {read}"))
-         }
-     }
- }
- impl<T: Read> Read for BufReader<T> {
-     /// Reads into `buf`, serving bytes from the buffer and refilling it from the inner
-     /// reader whenever it runs out.
-     ///
-     /// If `buf` is exactly one sector long and the buffer is empty, the read is passed
-     /// straight to the inner reader. Otherwise bytes are copied until `buf` is full, the
-     /// inner reader reaches its end, or an error occurs. An error is only returned when no
-     /// bytes were read; if some bytes were already copied the error is logged and the
-     /// number of bytes copied so far is returned.
-     fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
-         if buf.len() == self.sector_sz() && self.cursor_is_empty() {
-             return self.reader.read(buf);
-         }
-         let buf_len_start = buf.len();
-         while !buf.is_empty() {
-             let read = match self.cursor.read(buf) {
-                 Ok(read) => read,
-                 Err(err) => {
-                     if buf_len_start == buf.len() {
-                         return Err(err);
-                     } else {
-                         error!("{err}");
-                         break;
-                     }
-                 }
-             };
-             buf = &mut buf[read..];
-             if self.cursor_is_empty() {
-                 match self.refill() {
-                     // Nothing could be buffered, so the inner reader has no more data.
-                     // Stopping here reports end of stream instead of looping forever.
-                     Ok(()) if self.cursor_is_empty() => break,
-                     Ok(()) => (),
-                     Err(err) => {
-                         if buf_len_start == buf.len() {
-                             return Err(err.into());
-                         } else {
-                             error!("error occurred in BufReader::refill: {err}");
-                             break;
-                         }
-                     }
-                 }
-             }
-         }
-         Ok(buf_len_start - buf.len())
-     }
- }
- impl<T: Seek + Read + Size> Seek for BufReader<T> {
-     /// Moves this stream to the position described by `seek_from` and returns the new
-     /// absolute position.
-     ///
-     /// If the target lies in a different sector than the current position, the inner
-     /// reader is moved to the start of that sector and the buffer is invalidated. If the
-     /// buffer already holds the target sector, only the position within the buffer is
-     /// changed. Otherwise the buffer is refilled when the target is not on a sector
-     /// boundary.
-     ///
-     /// An error is returned if this reader's buffer is empty (see `take_buf`), or if
-     /// any of the underlying operations fail.
-     fn seek(&mut self, seek_from: std::io::SeekFrom) -> std::io::Result<u64> {
-         let sect_sz = self.sector_sz64();
-         if sect_sz == 0 {
-             // Avoids a division by zero below.
-             return Err(io::Error::new(
-                 io::ErrorKind::InvalidInput,
-                 "cannot seek in a BufReader with an empty buffer",
-             ));
-         }
-         let pos = self.pos()?;
-         let new_pos = seek_from.abs(|| Ok(pos), || self.reader.size_or_err())?;
-         let buf_pos = new_pos % sect_sz;
-         let index = pos / sect_sz;
-         let new_index = new_pos / sect_sz;
-         if index != new_index {
-             // Seek to the new position and invalidate the buffer.
-             self.reader.seek(SeekFrom::Start(sect_sz * new_index))?;
-             self.cursor.seek(SeekFrom::End(0))?;
-         }
-         if !self.cursor_is_empty() {
-             // The buffer already holds the target sector and the inner reader is positioned
-             // after it. Refilling here would load the following sector instead, so only the
-             // position within the buffer is changed.
-             self.cursor.set_position(buf_pos);
-         } else if buf_pos != 0 {
-             // The buffer is empty and the target is in the middle of a sector, so the sector
-             // has to be loaded before the position within it can be set.
-             self.refill()?;
-             self.cursor.seek(SeekFrom::Start(buf_pos))?;
-         }
-         Ok(new_pos)
-     }
- }
- impl<U, T> AsRef<U> for BufReader<T>
- where
-     T: AsRef<U>,
- {
-     /// Forwards to the [AsRef] implementation of the inner reader.
-     fn as_ref(&self) -> &U {
-         AsRef::<U>::as_ref(&self.reader)
-     }
- }
- impl<U, T> AsMut<U> for BufReader<T>
- where
-     T: AsMut<U>,
- {
-     /// Forwards to the [AsMut] implementation of the inner reader.
-     fn as_mut(&mut self) -> &mut U {
-         AsMut::<U>::as_mut(&mut self.reader)
-     }
- }
- impl<T> Size for BufReader<T>
- where
-     T: Size,
- {
-     /// Reports the size of the inner reader.
-     fn size(&self) -> std::io::Result<Option<u64>> {
-         Size::size(&self.reader)
-     }
- }
- impl<T> Sectored for BufReader<T> {
-     /// The sector size of a [BufReader] is the length of its buffer.
-     fn sector_sz(&self) -> usize {
-         let buf: &Vec<u8> = self.cursor.get_ref();
-         buf.len()
-     }
- }
- impl<T> Decompose<T> for BufReader<T> {
-     /// Consumes this [BufReader], dropping its buffer, and returns the inner reader.
-     fn into_inner(self) -> T {
-         let Self { reader, .. } = self;
-         reader
-     }
- }
- impl<T> Split<BufReader<&'static [u8]>, T> for BufReader<T> {
-     /// Separates the inner reader from the buffer. The buffer, including its current
-     /// position, is kept in a [BufReader] whose reader is [EMPTY_SLICE].
-     fn split(self) -> (BufReader<&'static [u8]>, T) {
-         let Self { cursor, reader } = self;
-         let left = BufReader {
-             cursor,
-             reader: EMPTY_SLICE,
-         };
-         (left, reader)
-     }
-     /// Puts the buffer held by `left` back together with the reader `right`.
-     fn combine(left: BufReader<&'static [u8]>, right: T) -> Self {
-         let BufReader { cursor, .. } = left;
-         BufReader {
-             cursor,
-             reader: right,
-         }
-     }
- }
- }
- #[cfg(test)]
- mod tests {
- use super::*;
- use crate::test_helpers::{
- random_indices, read_check, read_indices, write_fill, write_indices, Randomizer,
- SectoredCursor,
- };
- /// Tests that a sector sized read returns the data held by the inner reader.
- #[test]
- fn can_read() {
-     const EXPECTED: [u8; 32] = [1u8; 32];
-     let mut reader = BufReader::new(SectoredCursor::new(EXPECTED, EXPECTED.len())).unwrap();
-     let mut actual = [0u8; EXPECTED.len()];
-     // The number of bytes read is checked so that a short read cannot go unnoticed.
-     let read = reader.read(actual.as_mut()).expect("read failed");
-     assert_eq!(EXPECTED.len(), read);
-     assert_eq!(EXPECTED, actual);
- }
- /// Tests that the inner [Read] only sees calls to `read` with sector sized buffers.
- #[test]
- fn inner_sees_only_sector_sized_reads() {
-     const SECT_SZ: usize = 32;
-     const CHUNK_SZ: usize = 8;
-     const CHUNKS: usize = SECT_SZ / CHUNK_SZ;
-     // One sector made up of CHUNKS chunks, where the k-th chunk (counting from 1) is filled
-     // with the byte k.
-     let data: Vec<u8> = (1..=CHUNKS as u8).flat_map(|e| [e; CHUNK_SZ]).collect();
-     // SectoredCursor will panic if it's given a buffer that isn't exactly SECT_SZ bytes long.
-     let mut reader =
-         BufReader::new(SectoredCursor::new(data, SECT_SZ).require_sect_sz(true)).unwrap();
-     let mut actual = [0u8; CHUNK_SZ];
-     for k in 1..=CHUNKS {
-         let expected = [k as u8; CHUNK_SZ];
-         let read = reader.read(&mut actual).expect("read failed");
-         assert_eq!(CHUNK_SZ, read);
-         assert_eq!(expected, actual);
-     }
- }
- /// Writes a run of sectors to the inner stream using `write_fill`, then checks them
- /// through a [BufReader] with `read_check`.
- #[test]
- fn sequential_read() {
-     const SECT_CT: usize = 8;
-     const SECT_SZ: usize = 16;
-     let mut inner = SectoredCursor::new(Vec::new(), SECT_SZ);
-     write_fill(&mut inner, SECT_SZ, SECT_CT);
-     // Go back to the start so the sectors which were just written can be read.
-     inner.rewind().unwrap();
-     let mut buffered = BufReader::new(inner).unwrap();
-     read_check(&mut buffered, SECT_SZ, SECT_CT);
- }
- /// Tests that a read which is larger than one sector will be handled correctly.
- #[test]
- fn read_larger_than_one_sector() {
-     const SECT_SZ: usize = 4;
-     const DATA: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
-     let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap();
-     let mut actual = [0u8; 6];
-     // A single call must fill the whole buffer even though it spans two sectors.
-     let read = reader.read(&mut actual).expect("read failed");
-     assert_eq!(actual.len(), read);
-     assert_eq!([0, 1, 2, 3, 4, 5], actual);
- }
- /// Writes sectors in a pseudo-random order using `write_indices`, then reads them
- /// through a [BufReader] in the same order using `read_indices`.
- #[test]
- fn random_sector_sized_read() {
-     const SECT_CT: usize = 10;
-     const SECT_SZ: usize = 32;
-     let mut rando = Randomizer::new([7u8; Randomizer::HASH.len()]);
-     let order = random_indices(&mut rando, SECT_CT).collect::<Vec<_>>();
-     let mut inner = SectoredCursor::new(vec![0u8; SECT_SZ * SECT_CT], SECT_SZ);
-     write_indices(&mut inner, SECT_SZ, order.iter().cloned());
-     let mut buffered = BufReader::new(inner).unwrap();
-     read_indices(&mut buffered, SECT_SZ, order.iter().cloned());
- }
- /// Tests seeking to the middle of a sector before anything has been read.
- #[test]
- fn seek_with_empty_buffer() {
-     const SECT_SZ: usize = 4;
-     const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
-     const EXPECTED: u64 = 3;
-     let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap();
-     let new_pos = reader.seek(SeekFrom::Start(EXPECTED)).expect("seek failed");
-     assert_eq!(EXPECTED, new_pos);
-     let mut actual = [0u8; 1];
-     let read = reader.read(&mut actual).expect("read failed");
-     assert_eq!(actual.len(), read);
-     assert_eq!(EXPECTED as u8, actual[0]);
- }
- /// Tests seeking from a buffered sector to the middle of the sector which follows it.
- #[test]
- fn seek_to_middle_of_next_sector() {
-     const SECT_SZ: usize = 4;
-     const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
-     const EXPECTED: u64 = 5;
-     let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap();
-     let mut actual = [0u8; 1];
-     // This first read ensures the buffer is filled.
-     let read = reader.read(&mut actual).expect("first read failed");
-     assert_eq!(actual.len(), read);
-     let new_pos = reader.seek(SeekFrom::Start(EXPECTED)).expect("seek failed");
-     assert_eq!(EXPECTED, new_pos);
-     let read = reader.read(&mut actual).expect("read failed");
-     assert_eq!(actual.len(), read);
-     assert_eq!(EXPECTED as u8, actual[0]);
- }
- /// Tests seeking backwards relative to the current position, into the previous sector.
- #[test]
- fn seek_relative_to_current_position() {
-     const SECT_SZ: usize = 4;
-     const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
-     const EXPECTED: u64 = 3;
-     let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap();
-     let mut actual = [0u8; SECT_SZ];
-     // Consume the whole first sector so that the seek below has to go back into it.
-     let read = reader.read(&mut actual).expect("first read failed");
-     assert_eq!(SECT_SZ, read);
-     let new_pos = reader.seek(SeekFrom::Current(-1)).expect("seek failed");
-     assert_eq!(EXPECTED, new_pos);
-     // Only the byte at the new position is of interest, so exactly one byte is read.
-     let read = reader.read(&mut actual[..1]).expect("read failed");
-     assert_eq!(1, read);
-     assert_eq!(EXPECTED as u8, actual[0]);
- }
- /// Tests seeking relative to the end of the stream.
- #[test]
- fn seek_relative_to_end() {
-     const SECT_SZ: usize = 4;
-     const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
-     const EXPECTED: u64 = 7;
-     let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap();
-     let mut actual = [0u8; SECT_SZ];
-     let read = reader.read(&mut actual).expect("first read failed");
-     assert_eq!(SECT_SZ, read);
-     let new_pos = reader.seek(SeekFrom::End(-1)).expect("seek failed");
-     assert_eq!(EXPECTED, new_pos);
-     // Only the last byte of the stream is left, so exactly one byte is read.
-     let read = reader.read(&mut actual[..1]).expect("read failed");
-     assert_eq!(1, read);
-     assert_eq!(EXPECTED as u8, actual[0]);
- }
- }
|