buf_reader.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. use log::error;
  2. use positioned_io::Size;
  3. use std::io::{self, Cursor, Read, Seek, SeekFrom};
  4. use crate::{bterr, Decompose, Result, Sectored, Split, EMPTY_SLICE};
  5. pub use private::BufReader;
  6. mod private {
  7. use crate::{ReadExt, SeekFromExt, SizeExt};
  8. use super::*;
  /// A reader which buffers data read from an inner stream.
  ///
  /// Bytes are handed out of an in-memory buffer; when that buffer has been consumed it is
  /// refilled from the inner stream. The length of the buffer is what this type reports as its
  /// sector size.
  #[derive(Debug)]
  pub struct BufReader<T> {
      /// The buffered bytes together with the read position inside them. A position at (or
      /// past) the end of the vector means that the buffer holds nothing left to read.
      cursor: Cursor<Vec<u8>>,
      /// The inner stream which the buffer is filled from.
      reader: T,
  }
  13. impl<T> BufReader<T> {
  14. pub fn with_buf(buf: Vec<u8>, reader: T) -> Result<BufReader<T>> {
  15. Ok(Self {
  16. cursor: Self::make_cursor(buf)?,
  17. reader,
  18. })
  19. }
  20. pub fn get_ref(&self) -> &T {
  21. &self.reader
  22. }
  23. pub fn get_mut(&mut self) -> &mut T {
  24. &mut self.reader
  25. }
  26. /// Extracts the buffer from this [BufReader]. The [BufReader] which is returned contains
  27. /// an empty buffer.
  28. pub fn take_buf(mut self) -> (Self, Vec<u8>) {
  29. let buf = self.cursor.into_inner();
  30. self.cursor = Cursor::new(Vec::new());
  31. (self, buf)
  32. }
  33. fn make_cursor(buf: Vec<u8>) -> Result<Cursor<Vec<u8>>> {
  34. // If buf is zero-length then a call to read will loop forever.
  35. if buf.is_empty() {
  36. return Err(bterr!("the given vector must be non-empty"));
  37. }
  38. let mut cursor = Cursor::new(buf);
  39. cursor.seek(SeekFrom::End(0)).unwrap();
  40. Ok(cursor)
  41. }
  42. /// Returns true if all bytes have been read from the cursor.
  43. fn cursor_is_empty(&self) -> bool {
  44. let cursor_len = self.cursor.get_ref().len() as u64;
  45. let cursor_pos = self.cursor.position();
  46. cursor_pos >= cursor_len
  47. }
  48. }
  49. impl<T: Sectored> BufReader<T> {
  50. pub fn new(reader: T) -> Result<BufReader<T>> {
  51. let sect_sz = reader.sector_sz();
  52. Ok(Self {
  53. cursor: Self::make_cursor(vec![0u8; sect_sz])?,
  54. reader,
  55. })
  56. }
  57. }
  58. impl<T: Seek> BufReader<T> {
  59. /// Calculates the current position in this stream.
  60. pub fn pos(&mut self) -> io::Result<u64> {
  61. let inner_pos = self.reader.stream_position()?;
  62. // Because the inner stream is ahead of this stream, the current position is the
  63. // position of the inner stream minus the number of bytes remaining in the cursor.
  64. let remaining = self.cursor.get_ref().len() as u64 - self.cursor.position();
  65. let pos = inner_pos - remaining;
  66. Ok(pos)
  67. }
  68. }
  69. impl<T: Read> BufReader<T> {
  70. /// Refills the cursor by reading from the underlying stream.
  71. fn refill(&mut self) -> Result<()> {
  72. self.cursor.rewind()?;
  73. let vec = self.cursor.get_mut();
  74. let read = self.reader.fill_buf(vec)?;
  75. if read == 0 || read == vec.len() {
  76. Ok(())
  77. } else {
  78. Err(bterr!("unexpected number of bytes read: {read}"))
  79. }
  80. }
  81. }
  82. impl<T: Read> Read for BufReader<T> {
  83. fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
  84. if buf.len() == self.sector_sz() && self.cursor_is_empty() {
  85. return self.reader.read(buf);
  86. }
  87. let buf_len_start = buf.len();
  88. while !buf.is_empty() {
  89. let read = match self.cursor.read(buf) {
  90. Ok(read) => read,
  91. Err(err) => {
  92. if buf_len_start == buf.len() {
  93. return Err(err);
  94. } else {
  95. error!("{err}");
  96. break;
  97. }
  98. }
  99. };
  100. buf = &mut buf[read..];
  101. if self.cursor_is_empty() {
  102. if let Err(err) = self.refill() {
  103. if buf_len_start == buf.len() {
  104. return Err(err.into());
  105. } else {
  106. error!("error occurred in BufReader::refill: {err}");
  107. break;
  108. }
  109. }
  110. }
  111. }
  112. Ok(buf_len_start - buf.len())
  113. }
  114. }
  115. impl<T: Seek + Read + Size> Seek for BufReader<T> {
  116. fn seek(&mut self, seek_from: std::io::SeekFrom) -> std::io::Result<u64> {
  117. let pos = self.pos()?;
  118. let new_pos = seek_from.abs(|| Ok(pos), || self.reader.size_or_err())?;
  119. let sect_sz = self.sector_sz64();
  120. let buf_pos = new_pos % sect_sz;
  121. let index = pos / sect_sz;
  122. let new_index = new_pos / sect_sz;
  123. if index != new_index {
  124. // Seek to the new position and invalidate the buffer.
  125. self.reader.seek(SeekFrom::Start(sect_sz * new_index))?;
  126. self.cursor.seek(SeekFrom::End(0))?;
  127. }
  128. if buf_pos != 0 {
  129. // If the buffer position is not at the end then we must refill it.
  130. self.refill()?;
  131. self.cursor.seek(SeekFrom::Start(buf_pos))?;
  132. }
  133. Ok(new_pos)
  134. }
  135. }
  136. impl<U, T: AsRef<U>> AsRef<U> for BufReader<T> {
  137. fn as_ref(&self) -> &U {
  138. self.reader.as_ref()
  139. }
  140. }
  141. impl<U, T: AsMut<U>> AsMut<U> for BufReader<T> {
  142. fn as_mut(&mut self) -> &mut U {
  143. self.reader.as_mut()
  144. }
  145. }
  146. impl<T: Size> Size for BufReader<T> {
  147. fn size(&self) -> std::io::Result<Option<u64>> {
  148. self.reader.size()
  149. }
  150. }
  151. impl<T> Sectored for BufReader<T> {
  152. fn sector_sz(&self) -> usize {
  153. self.cursor.get_ref().len()
  154. }
  155. }
  156. impl<T> Decompose<T> for BufReader<T> {
  157. fn into_inner(self) -> T {
  158. self.reader
  159. }
  160. }
  161. impl<T> Split<BufReader<&'static [u8]>, T> for BufReader<T> {
  162. fn split(self) -> (BufReader<&'static [u8]>, T) {
  163. let reader = BufReader {
  164. cursor: self.cursor,
  165. reader: EMPTY_SLICE,
  166. };
  167. (reader, self.reader)
  168. }
  169. fn combine(left: BufReader<&'static [u8]>, right: T) -> Self {
  170. BufReader {
  171. cursor: left.cursor,
  172. reader: right,
  173. }
  174. }
  175. }
  176. }
  177. #[cfg(test)]
  178. mod tests {
  179. use super::*;
  180. use crate::test_helpers::{
  181. random_indices, read_check, read_indices, write_fill, write_indices, Randomizer,
  182. SectoredCursor,
  183. };
  /// Reads a buffer that is exactly one sector long and checks its contents.
  #[test]
  fn can_read() {
      const EXPECTED: [u8; 32] = [1u8; 32];
      let inner = SectoredCursor::new(EXPECTED, EXPECTED.len());
      let mut reader = BufReader::new(inner).unwrap();

      let mut actual = [0u8; EXPECTED.len()];
      reader.read(&mut actual[..]).expect("read failed");

      assert_eq!(EXPECTED, actual);
  }
  /// Tests that the inner [Read] only sees calls to `read` with sector sized buffers.
  #[test]
  fn inner_sees_only_sector_sized_reads() {
      const SECT_SZ: usize = 32;
      const CHUNK_SZ: usize = 8;
      const CHUNKS: usize = SECT_SZ / CHUNK_SZ;

      // One sector made of CHUNKS runs, where the k-th run (counting from 1) repeats the byte k.
      let mut data = Vec::with_capacity(SECT_SZ);
      for value in 1..=CHUNKS as u8 {
          data.extend_from_slice(&[value; CHUNK_SZ]);
      }

      // SectoredCursor will panic if it's given a buffer that isn't exactly SECT_SZ bytes long.
      let inner = SectoredCursor::new(data, SECT_SZ).require_sect_sz(true);
      let mut reader = BufReader::new(inner).unwrap();

      let mut actual = [0u8; CHUNK_SZ];
      for k in 1..=CHUNKS {
          reader.read(&mut actual).expect("read failed");
          let expected = [k as u8; CHUNK_SZ];
          assert_eq!(expected, actual);
      }
  }
  /// Writes several sectors into a cursor and checks them by reading through a [BufReader].
  #[test]
  fn sequential_read() {
      const SECT_SZ: usize = 16;
      const SECT_CT: usize = 8;

      let mut inner = SectoredCursor::new(Vec::new(), SECT_SZ);
      write_fill(&mut inner, SECT_SZ, SECT_CT);
      inner.rewind().unwrap();

      let mut reader = BufReader::new(inner).unwrap();
      read_check(&mut reader, SECT_SZ, SECT_CT);
  }
  /// Tests that a read which is larger than one sector will be handled correctly.
  #[test]
  fn read_larger_than_one_sector() {
      const SECT_SZ: usize = 4;
      const DATA: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
      const EXPECTED: [u8; 6] = [0, 1, 2, 3, 4, 5];

      let inner = SectoredCursor::new(DATA, SECT_SZ);
      let mut reader = BufReader::new(inner).unwrap();

      let mut actual = [0u8; 6];
      reader.read(&mut actual).expect("read failed");

      assert_eq!(EXPECTED, actual);
  }
  /// Writes sectors in a random order and reads them back in that same order.
  #[test]
  fn random_sector_sized_read() {
      const SECT_SZ: usize = 32;
      const SECT_CT: usize = 10;

      let seed = [7u8; Randomizer::HASH.len()];
      let mut rando = Randomizer::new(seed);
      let indices: Vec<_> = random_indices(&mut rando, SECT_CT).collect();

      let backing = vec![0u8; SECT_SZ * SECT_CT];
      let mut inner = SectoredCursor::new(backing, SECT_SZ);
      write_indices(&mut inner, SECT_SZ, indices.iter().cloned());

      let mut reader = BufReader::new(inner).unwrap();
      read_indices(&mut reader, SECT_SZ, indices.iter().cloned());
  }
  /// Seeks into the middle of the first sector before anything has been buffered.
  #[test]
  fn seek_with_empty_buffer() {
      const SECT_SZ: usize = 4;
      const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
      const TARGET: u64 = 3;

      let inner = SectoredCursor::new(DATA, SECT_SZ);
      let mut reader = BufReader::new(inner).unwrap();

      let landed = reader.seek(SeekFrom::Start(TARGET)).expect("seek failed");
      assert_eq!(TARGET, landed);

      let mut byte = [0u8; 1];
      reader.read(&mut byte).expect("read failed");
      assert_eq!(TARGET as u8, byte[0]);
  }
  /// Seeks from a buffered first sector into the middle of the second one.
  #[test]
  fn seek_to_middle_of_next_sector() {
      const SECT_SZ: usize = 4;
      const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
      const TARGET: u64 = 5;

      let inner = SectoredCursor::new(DATA, SECT_SZ);
      let mut reader = BufReader::new(inner).unwrap();
      let mut byte = [0u8; 1];

      // This first read ensures the buffer is filled.
      reader.read(&mut byte).expect("first read failed");

      let landed = reader.seek(SeekFrom::Start(TARGET)).expect("seek failed");
      assert_eq!(TARGET, landed);

      reader.read(&mut byte).expect("read failed");
      assert_eq!(TARGET as u8, byte[0]);
  }
  /// Reads one whole sector and then seeks one byte backwards from the current position.
  #[test]
  fn seek_relative_to_current_position() {
      const SECT_SZ: usize = 4;
      const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
      const TARGET: u64 = 3;

      let inner = SectoredCursor::new(DATA, SECT_SZ);
      let mut reader = BufReader::new(inner).unwrap();
      let mut sector = [0u8; SECT_SZ];

      reader.read(&mut sector).expect("first read failed");

      let landed = reader.seek(SeekFrom::Current(-1)).expect("seek failed");
      assert_eq!(TARGET, landed);

      reader.read(&mut sector).expect("read failed");
      assert_eq!(TARGET as u8, sector[0]);
  }
  /// Reads one whole sector and then seeks to the last byte relative to the end of the stream.
  #[test]
  fn seek_relative_to_end() {
      const SECT_SZ: usize = 4;
      const DATA: [u8; 2 * SECT_SZ] = [0, 1, 2, 3, 4, 5, 6, 7];
      const TARGET: u64 = 7;

      let inner = SectoredCursor::new(DATA, SECT_SZ);
      let mut reader = BufReader::new(inner).unwrap();
      let mut sector = [0u8; SECT_SZ];

      reader.read(&mut sector).expect("first read failed");

      let landed = reader.seek(SeekFrom::End(-1)).expect("seek failed");
      assert_eq!(TARGET, landed);

      reader.read(&mut sector).expect("read failed");
      assert_eq!(TARGET as u8, sector[0]);
  }
  297. }
  298. }