buf_reader.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. // SPDX-License-Identifier: AGPL-3.0-or-later
  2. //! This module contains the [BufReader] type.
  3. use log::error;
  4. use positioned_io::Size;
  5. use std::io::{self, Cursor, Read, Seek, SeekFrom};
  6. use crate::{bterr, Decompose, Result, Sectored, Split, EMPTY_SLICE};
  7. pub use private::BufReader;
  8. mod private {
  9. use crate::{ReadExt, SeekFromExt, SizeExt};
  10. use super::*;
  /// A buffered reader in the spirit of [std::io::BufReader], except that its buffer can be
  /// supplied by the caller ([BufReader::with_buf]) and reclaimed afterwards
  /// ([BufReader::take_buf]), so one allocation can be reused by many instances.
  ///
  /// The length of the buffer is the sector size this type reports through [Sectored].
  pub struct BufReader<T> {
      /// The buffered bytes together with the offset of the next unread byte within them.
      cursor: Cursor<Vec<u8>>,
      /// The wrapped reader from which the buffer is refilled.
      reader: T,
  }
  17. impl<T> BufReader<T> {
  18. /// Create a new [BufReader] which contains the given buffer and reader type.
  19. pub fn with_buf(buf: Vec<u8>, reader: T) -> Result<BufReader<T>> {
  20. Ok(Self {
  21. cursor: Self::make_cursor(buf)?,
  22. reader,
  23. })
  24. }
  25. /// Returns a reference to the inner reader.
  26. pub fn get_ref(&self) -> &T {
  27. &self.reader
  28. }
  29. /// Extracts the buffer from this [BufReader]. The [BufReader] which is returned contains
  30. /// an empty buffer.
  31. pub fn take_buf(mut self) -> (Self, Vec<u8>) {
  32. let buf = self.cursor.into_inner();
  33. self.cursor = Cursor::new(Vec::new());
  34. (self, buf)
  35. }
  36. fn make_cursor(buf: Vec<u8>) -> Result<Cursor<Vec<u8>>> {
  37. // If buf is zero-length then a call to read will loop forever.
  38. if buf.is_empty() {
  39. return Err(bterr!("the given vector must be non-empty"));
  40. }
  41. let mut cursor = Cursor::new(buf);
  42. cursor.seek(SeekFrom::End(0)).unwrap();
  43. Ok(cursor)
  44. }
  45. /// Returns true if all bytes have been read from the cursor.
  46. fn cursor_is_empty(&self) -> bool {
  47. let cursor_len = self.cursor.get_ref().len() as u64;
  48. let cursor_pos = self.cursor.position();
  49. cursor_pos >= cursor_len
  50. }
  51. }
  52. impl<T: Sectored> BufReader<T> {
  53. /// Creates a new [BufReader] containing the given reader type and a new buffer.
  54. pub fn new(reader: T) -> Result<BufReader<T>> {
  55. let sect_sz = reader.sector_sz();
  56. Ok(Self {
  57. cursor: Self::make_cursor(vec![0u8; sect_sz])?,
  58. reader,
  59. })
  60. }
  61. }
  62. impl<T: Seek> BufReader<T> {
  63. /// Calculates the current position in this stream.
  64. pub fn pos(&mut self) -> io::Result<u64> {
  65. let inner_pos = self.reader.stream_position()?;
  66. // Because the inner stream is ahead of this stream, the current position is the
  67. // position of the inner stream minus the number of bytes remaining in the cursor.
  68. let remaining = self.cursor.get_ref().len() as u64 - self.cursor.position();
  69. let pos = inner_pos - remaining;
  70. Ok(pos)
  71. }
  72. }
  73. impl<T: Read> BufReader<T> {
  74. /// Refills the cursor by reading from the underlying stream.
  75. fn refill(&mut self) -> Result<()> {
  76. self.cursor.rewind()?;
  77. let vec = self.cursor.get_mut();
  78. let read = self.reader.fill_buf(vec)?;
  79. if read == 0 || read == vec.len() {
  80. Ok(())
  81. } else {
  82. Err(bterr!("unexpected number of bytes read: {read}"))
  83. }
  84. }
  85. }
  86. impl<T: Read> Read for BufReader<T> {
  87. fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
  88. if buf.len() == self.sector_sz() && self.cursor_is_empty() {
  89. return self.reader.read(buf);
  90. }
  91. let buf_len_start = buf.len();
  92. while !buf.is_empty() {
  93. let read = match self.cursor.read(buf) {
  94. Ok(read) => read,
  95. Err(err) => {
  96. if buf_len_start == buf.len() {
  97. return Err(err);
  98. } else {
  99. error!("{err}");
  100. break;
  101. }
  102. }
  103. };
  104. buf = &mut buf[read..];
  105. if self.cursor_is_empty() {
  106. if let Err(err) = self.refill() {
  107. if buf_len_start == buf.len() {
  108. return Err(err.into());
  109. } else {
  110. error!("error occurred in BufReader::refill: {err}");
  111. break;
  112. }
  113. }
  114. }
  115. }
  116. Ok(buf_len_start - buf.len())
  117. }
  118. }
  119. impl<T: Seek + Read + Size> Seek for BufReader<T> {
  120. fn seek(&mut self, seek_from: std::io::SeekFrom) -> std::io::Result<u64> {
  121. let pos = self.pos()?;
  122. let new_pos = seek_from.abs(|| Ok(pos), || self.reader.size_or_err())?;
  123. let sect_sz = self.sector_sz64();
  124. let buf_pos = new_pos % sect_sz;
  125. let index = pos / sect_sz;
  126. let new_index = new_pos / sect_sz;
  127. if index != new_index {
  128. // Seek to the new position and invalidate the buffer.
  129. self.reader.seek(SeekFrom::Start(sect_sz * new_index))?;
  130. self.cursor.seek(SeekFrom::End(0))?;
  131. }
  132. if buf_pos != 0 {
  133. // If the buffer position is not at the end then we must refill it.
  134. self.refill()?;
  135. self.cursor.seek(SeekFrom::Start(buf_pos))?;
  136. }
  137. Ok(new_pos)
  138. }
  139. }
  140. impl<U, T: AsRef<U>> AsRef<U> for BufReader<T> {
  141. fn as_ref(&self) -> &U {
  142. self.reader.as_ref()
  143. }
  144. }
  145. impl<U, T: AsMut<U>> AsMut<U> for BufReader<T> {
  146. fn as_mut(&mut self) -> &mut U {
  147. self.reader.as_mut()
  148. }
  149. }
  150. impl<T: Size> Size for BufReader<T> {
  151. fn size(&self) -> std::io::Result<Option<u64>> {
  152. self.reader.size()
  153. }
  154. }
  155. impl<T> Sectored for BufReader<T> {
  156. fn sector_sz(&self) -> usize {
  157. self.cursor.get_ref().len()
  158. }
  159. }
  160. impl<T> Decompose<T> for BufReader<T> {
  161. fn into_inner(self) -> T {
  162. self.reader
  163. }
  164. }
  165. impl<T> Split<BufReader<&'static [u8]>, T> for BufReader<T> {
  166. fn split(self) -> (BufReader<&'static [u8]>, T) {
  167. let reader = BufReader {
  168. cursor: self.cursor,
  169. reader: EMPTY_SLICE,
  170. };
  171. (reader, self.reader)
  172. }
  173. fn combine(left: BufReader<&'static [u8]>, right: T) -> Self {
  174. BufReader {
  175. cursor: left.cursor,
  176. reader: right,
  177. }
  178. }
  179. }
  180. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::test_helpers::{
          random_indices, read_check, read_indices, write_fill, write_indices, Randomizer,
          SectoredCursor,
      };

      #[test]
      fn can_read() {
          const EXPECTED: [u8; 32] = [1u8; 32];
          let mut reader = BufReader::new(SectoredCursor::new(EXPECTED, EXPECTED.len())).unwrap();
          let mut actual = [0u8; EXPECTED.len()];
          reader.read_exact(&mut actual).expect("read failed");
          assert_eq!(EXPECTED, actual);
      }

      /// Tests that the inner [Read] only sees calls to `read` with sector sized buffers.
      #[test]
      fn inner_sees_only_sector_sized_reads() {
          const SECT_SZ: usize = 32;
          const CHUNK_SZ: usize = 8;
          const CHUNKS: usize = SECT_SZ / CHUNK_SZ;
          let data = std::iter::successors(Some(1u8), |prev| Some(*prev + 1))
              .map(|e| [e; CHUNK_SZ])
              .take(CHUNKS)
              .fold(Vec::with_capacity(SECT_SZ), |mut prev, curr| {
                  prev.extend_from_slice(&curr);
                  prev
              });
          // SectoredCursor will panic if it's given a buffer that isn't exactly SECT_SZ bytes long.
          let mut reader =
              BufReader::new(SectoredCursor::new(data, SECT_SZ).require_sect_sz(true)).unwrap();
          let mut actual = [0u8; CHUNK_SZ];
          for k in 1..(CHUNKS + 1) {
              let expected = [k as u8; CHUNK_SZ];
              reader.read_exact(&mut actual).expect("read failed");
              assert_eq!(expected, actual);
          }
      }

      #[test]
      fn sequential_read() {
          const SECT_SZ: usize = 16;
          const SECT_CT: usize = 8;
          let mut cursor = SectoredCursor::new(Vec::new(), SECT_SZ);
          write_fill(&mut cursor, SECT_SZ, SECT_CT);
          cursor.rewind().unwrap();
          let mut reader = BufReader::new(cursor).unwrap();
          read_check(&mut reader, SECT_SZ, SECT_CT);
      }

      /// Tests that a read which is larger than one sector will be handled correctly.
      #[test]
      fn read_larger_than_one_sector() {
          const SECT_SZ: usize = 4;
          const DATA: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
          let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap();
          let mut actual = [0u8; 6];
          reader.read_exact(&mut actual).expect("read failed");
          assert_eq!([0, 1, 2, 3, 4, 5], actual);
      }

      #[test]
      fn random_sector_sized_read() {
          const SECT_SZ: usize = 32;
          const SECT_CT: usize = 10;
          let mut rando = Randomizer::new([7u8; Randomizer::HASH.len()]);
          let indices: Vec<_> = random_indices(&mut rando, SECT_CT).collect();
          let mut cursor = SectoredCursor::new(vec![0u8; SECT_SZ * SECT_CT], SECT_SZ);
          write_indices(&mut cursor, SECT_SZ, indices.iter().cloned());
          let mut reader = BufReader::new(cursor).unwrap();
          read_indices(&mut reader, SECT_SZ, indices.iter().cloned());
      }

      #[test]
      fn seek_with_empty_buffer() {
          const SECT_SZ: usize = 4;
          const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
          const EXPECTED: u64 = 3;
          let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap();
          let new_pos = reader.seek(SeekFrom::Start(EXPECTED)).expect("seek failed");
          assert_eq!(EXPECTED, new_pos);
          let mut actual = [0u8; 1];
          reader.read_exact(&mut actual).expect("read failed");
          assert_eq!(EXPECTED as u8, actual[0]);
      }

      #[test]
      fn seek_to_middle_of_next_sector() {
          const SECT_SZ: usize = 4;
          const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
          const EXPECTED: u64 = 5;
          let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap();
          let mut actual = [0u8; 1];
          // This first read ensures the buffer is filled.
          reader.read_exact(&mut actual).expect("first read failed");
          let new_pos = reader.seek(SeekFrom::Start(EXPECTED)).expect("seek failed");
          assert_eq!(EXPECTED, new_pos);
          reader.read_exact(&mut actual).expect("read failed");
          assert_eq!(EXPECTED as u8, actual[0]);
      }

      #[test]
      fn seek_relative_to_current_position() {
          const SECT_SZ: usize = 4;
          const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
          const EXPECTED: u64 = 3;
          let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap();
          let mut actual = [0u8; SECT_SZ];
          reader.read_exact(&mut actual).expect("first read failed");
          let new_pos = reader.seek(SeekFrom::Current(-1)).expect("seek failed");
          assert_eq!(EXPECTED, new_pos);
          // This read spans the boundary between the first and second sectors.
          reader.read_exact(&mut actual).expect("read failed");
          assert_eq!([3, 4, 5, 6], actual);
      }

      #[test]
      fn seek_relative_to_end() {
          const SECT_SZ: usize = 4;
          const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
          const EXPECTED: u64 = 7;
          let mut reader = BufReader::new(SectoredCursor::new(DATA, SECT_SZ)).unwrap();
          let mut actual = [0u8; SECT_SZ];
          reader.read_exact(&mut actual).expect("first read failed");
          let new_pos = reader.seek(SeekFrom::End(-1)).expect("seek failed");
          assert_eq!(EXPECTED, new_pos);
          // Only one byte remains before EOF, so at least that byte must be returned.
          let read = reader.read(&mut actual).expect("read failed");
          assert!(read >= 1);
          assert_eq!(EXPECTED as u8, actual[0]);
      }
  }