// SPDX-License-Identifier: AGPL-3.0-or-later //! This module contains the [BufReader] type. use log::error; use positioned_io::Size; use std::io::{self, Cursor, Read, Seek, SeekFrom}; use crate::{bterr, Decompose, Result, Sectored, Split, EMPTY_SLICE}; pub use private::BufReader; mod private { use crate::{ReadExt, SeekFromExt, SizeExt}; use super::*; /// Nearly identical to [std::io::BufReader] but allows a buffer to be reused between /// instantiations. pub struct BufReader { cursor: Cursor>, reader: T, } impl BufReader { /// Create a new [BufReader] which contains the given buffer and reader type. pub fn with_buf(buf: Vec, reader: T) -> Result> { Ok(Self { cursor: Self::make_cursor(buf)?, reader, }) } /// Returns a reference to the inner reader. pub fn get_ref(&self) -> &T { &self.reader } /// Extracts the buffer from this [BufReader]. The [BufReader] which is returned contains /// an empty buffer. pub fn take_buf(mut self) -> (Self, Vec) { let buf = self.cursor.into_inner(); self.cursor = Cursor::new(Vec::new()); (self, buf) } fn make_cursor(buf: Vec) -> Result>> { // If buf is zero-length then a call to read will loop forever. if buf.is_empty() { return Err(bterr!("the given vector must be non-empty")); } let mut cursor = Cursor::new(buf); cursor.seek(SeekFrom::End(0)).unwrap(); Ok(cursor) } /// Returns true if all bytes have been read from the cursor. fn cursor_is_empty(&self) -> bool { let cursor_len = self.cursor.get_ref().len() as u64; let cursor_pos = self.cursor.position(); cursor_pos >= cursor_len } } impl BufReader { /// Creates a new [BufReader] containing the given reader type and a new buffer. pub fn new(reader: T) -> Result> { let sect_sz = reader.sector_sz(); Ok(Self { cursor: Self::make_cursor(vec![0u8; sect_sz])?, reader, }) } } impl BufReader { /// Calculates the current position in this stream. pub fn pos(&mut self) -> io::Result { let inner_pos = self.reader.stream_position()?; // Because the inner stream is ahead of this stream, the current position is the // position of the inner stream minus the number of bytes remaining in the cursor. let remaining = self.cursor.get_ref().len() as u64 - self.cursor.position(); let pos = inner_pos - remaining; Ok(pos) } } impl BufReader { /// Refills the cursor by reading from the underlying stream. fn refill(&mut self) -> Result<()> { self.cursor.rewind()?; let vec = self.cursor.get_mut(); let read = self.reader.fill_buf(vec)?; if read == 0 || read == vec.len() { Ok(()) } else { Err(bterr!("unexpected number of bytes read: {read}")) } } } impl Read for BufReader { fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result { if buf.len() == self.sector_sz() && self.cursor_is_empty() { return self.reader.read(buf); } let buf_len_start = buf.len(); while !buf.is_empty() { let read = match self.cursor.read(buf) { Ok(read) => read, Err(err) => { if buf_len_start == buf.len() { return Err(err); } else { error!("{err}"); break; } } }; buf = &mut buf[read..]; if self.cursor_is_empty() { if let Err(err) = self.refill() { if buf_len_start == buf.len() { return Err(err.into()); } else { error!("error occurred in BufReader::refill: {err}"); break; } } } } Ok(buf_len_start - buf.len()) } } impl Seek for BufReader { fn seek(&mut self, seek_from: std::io::SeekFrom) -> std::io::Result { let pos = self.pos()?; let new_pos = seek_from.abs(|| Ok(pos), || self.reader.size_or_err())?; let sect_sz = self.sector_sz64(); let buf_pos = new_pos % sect_sz; let index = pos / sect_sz; let new_index = new_pos / sect_sz; if index != new_index { // Seek to the new position and invalidate the buffer. self.reader.seek(SeekFrom::Start(sect_sz * new_index))?; self.cursor.seek(SeekFrom::End(0))?; } if buf_pos != 0 { // If the buffer position is not at the end then we must refill it. self.refill()?; self.cursor.seek(SeekFrom::Start(buf_pos))?; } Ok(new_pos) } } impl> AsRef for BufReader { fn as_ref(&self) -> &U { self.reader.as_ref() } } impl> AsMut for BufReader { fn as_mut(&mut self) -> &mut U { self.reader.as_mut() } } impl Size for BufReader { fn size(&self) -> std::io::Result> { self.reader.size() } } impl Sectored for BufReader { fn sector_sz(&self) -> usize { self.cursor.get_ref().len() } } impl Decompose for BufReader { fn into_inner(self) -> T { self.reader } } impl Split, T> for BufReader { fn split(self) -> (BufReader<&'static [u8]>, T) { let reader = BufReader { cursor: self.cursor, reader: EMPTY_SLICE, }; (reader, self.reader) } fn combine(left: BufReader<&'static [u8]>, right: T) -> Self { BufReader { cursor: left.cursor, reader: right, } } } } #[cfg(test)] mod tests { use super::*; use crate::test_helpers::{ random_indices, read_check, read_indices, write_fill, write_indices, Randomizer, SectoredCursor, }; #[test] fn can_read() { const EXPECTED: [u8; 32] = [1u8; 32]; let mut reader = BufReader::new(SectoredCursor::new(EXPECTED, EXPECTED.len())).unwrap(); let mut actual = [0u8; EXPECTED.len()]; reader.read(actual.as_mut()).expect("read failed"); assert_eq!(EXPECTED, actual); } /// Tests that the inner [Read] only sees calls to `read` with sector sized buffers. #[test] fn inner_sees_only_sector_sized_reads() { const SECT_SZ: usize = 32; const CHUNK_SZ: usize = 8; const CHUNKS: usize = SECT_SZ / CHUNK_SZ; let data = std::iter::successors(Some(1u8), |prev| Some(*prev + 1)) .map(|e| [e; CHUNK_SZ]) .take(CHUNKS) .fold(Vec::with_capacity(SECT_SZ), |mut prev, curr| { prev.extend_from_slice(&curr); prev }); // SectoredCursor will panic if it's given a buffer that isn't exactly SECT_SZ bytes long. let mut reader = BufReader::new(SectoredCursor::new(data, SECT_SZ).require_sect_sz(true)).unwrap(); let mut actual = [0u8; CHUNK_SZ]; for k in 1..(CHUNKS + 1) { let expected = [k as u8; CHUNK_SZ]; reader.read(&mut actual).expect("read failed"); assert_eq!(expected, actual); } } #[test] fn sequential_read() { const SECT_SZ: usize = 16; const SECT_CT: usize = 8; let mut cursor = SectoredCursor::new(Vec::new(), SECT_SZ); write_fill(&mut cursor, SECT_SZ, SECT_CT); cursor.rewind().unwrap(); let mut reader = BufReader::new(cursor).unwrap(); read_check(&mut reader, SECT_SZ, SECT_CT); } /// Tests that a read which is larger than one sector will be handled correctly. #[test] fn read_larger_than_one_sector() { const SECT_SZ: usize = 4; const DATA: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7]; let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap(); let mut actual = [0u8; 6]; reader.read(&mut actual).expect("read failed"); assert_eq!([0, 1, 2, 3, 4, 5], actual); } #[test] fn random_sector_sized_read() { const SECT_SZ: usize = 32; const SECT_CT: usize = 10; let mut rando = Randomizer::new([7u8; Randomizer::HASH.len()]); let indices: Vec<_> = random_indices(&mut rando, SECT_CT).collect(); let mut cursor = SectoredCursor::new(vec![0u8; SECT_SZ * SECT_CT], SECT_SZ); write_indices(&mut cursor, SECT_SZ, indices.iter().cloned()); let mut reader = BufReader::new(cursor).unwrap(); read_indices(&mut reader, SECT_SZ, indices.iter().cloned()); } #[test] fn seek_with_empty_buffer() { const SECT_SZ: usize = 4; const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7]; const EXPECTED: u64 = 3; let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap(); let new_pos = reader.seek(SeekFrom::Start(EXPECTED)).expect("seek failed"); assert_eq!(EXPECTED, new_pos); let mut actual = [0u8; 1]; reader.read(&mut actual).expect("read failed"); assert_eq!(EXPECTED as u8, actual[0]); } #[test] fn seek_to_middle_of_next_sector() { const SECT_SZ: usize = 4; const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7]; const EXPECTED: u64 = 5; let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap(); let mut actual = [0u8; 1]; // This first read ensures the buffer is filled. reader.read(&mut actual).expect("first read failed"); let new_pos = reader.seek(SeekFrom::Start(EXPECTED)).expect("seek failed"); assert_eq!(EXPECTED, new_pos); reader.read(&mut actual).expect("read failed"); assert_eq!(EXPECTED as u8, actual[0]); } #[test] fn seek_relative_to_current_position() { const SECT_SZ: usize = 4; const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7]; const EXPECTED: u64 = 3; let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap(); let mut actual = [0u8; SECT_SZ]; reader.read(&mut actual).expect("first read failed"); let new_pos = reader.seek(SeekFrom::Current(-1)).expect("seek failed"); assert_eq!(EXPECTED, new_pos); reader.read(&mut actual).expect("read failed"); assert_eq!(EXPECTED as u8, actual[0]); } #[test] fn seek_relative_to_end() { const SECT_SZ: usize = 4; const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7]; const EXPECTED: u64 = 7; let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap(); let mut actual = [0u8; SECT_SZ]; reader.read(&mut actual).expect("first read failed"); let new_pos = reader.seek(SeekFrom::End(-1)).expect("seek failed"); assert_eq!(EXPECTED, new_pos); reader.read(&mut actual).expect("read failed"); assert_eq!(EXPECTED as u8, actual[0]); } }