sign_stream.rs 14 KB


  1. // SPDX-License-Identifier: AGPL-3.0-or-later
  2. pub use private::SignStream;
  3. mod private {
  4. use crate::{
  5. bterr,
  6. crypto::{Error, Result, Signature, Signer, Verifier},
  7. Block, BlockMeta, Decompose, MetaAccess, ReadExt, Sectored,
  8. };
  9. use anyhow::anyhow;
  10. use btserde::{read_from, to_vec, write_to};
  11. use std::io::{self, Read, Seek, SeekFrom, Write};
  /// Length of the `SignStream::index_bytes` array, which holds the `btserde` serialization
  /// of the current sector index. It is sized to the width of a `u64`.
  const INDEX_BYTES_LEN: usize = (u64::BITS / 8) as usize;
  14. /// A stream which signs each sector of data written and verifies each sector read.
  15. ///
  16. /// Note that the correct [BlockId] needs to be configured for the inner stream before the
  17. /// first call to `read` or `write`. Upon the first such call the `id_bytes` field will be
  18. /// initialized by serializing the current value of the [BlockId] obtained from the inner stream
  19. /// using the [MetaAccess] trait.
  20. pub struct SignStream<T, C> {
  21. inner: T,
  22. creds: C,
  23. /// The `btserde` serialization of the `BlockId` of this stream. This data is signed
  24. /// along with the index of each sector to ensure sectors cannot be reordered or moved
  25. /// between blocks.
  26. id_bytes: Option<Vec<u8>>,
  27. /// The 0-based index of the next sector to be read or written.
  28. index: u64,
  29. /// The `btserde` serialization of index.
  30. index_bytes: [u8; INDEX_BYTES_LEN],
  31. /// The sector size of this stream. This is the size of the buffer expected by the
  32. /// `read` and `write` methods.
  33. out_sz: usize,
  34. /// The length of the `data` field expected from [Signature] instances.
  35. sig_len: usize,
  36. }
  37. impl<T: Sectored, C: Signer> SignStream<T, C> {
  38. #[allow(dead_code)]
  39. pub fn new(inner: T, creds: C) -> Result<SignStream<T, C>> {
  40. // TODO: This is way too brittle. If creds that produce a different sized signature
  41. // are ever used in the future, then this will break.
  42. let (extra, sig_len) = {
  43. let sig = creds.sign(std::iter::empty())?;
  44. let vec = to_vec(&sig)?;
  45. (vec.len(), sig.data.len())
  46. };
  47. let in_sz = inner.sector_sz();
  48. if in_sz < extra {
  49. return Err(bterr!("sector size is too small"));
  50. }
  51. let out_sz = in_sz - extra;
  52. Ok(SignStream {
  53. inner,
  54. creds,
  55. id_bytes: None,
  56. index: 0,
  57. index_bytes: [0u8; INDEX_BYTES_LEN],
  58. out_sz,
  59. sig_len,
  60. })
  61. }
  62. }
  63. impl<T, C> SignStream<T, C> {
  64. fn out_to_index(&self, outer_offset: u64) -> u64 {
  65. outer_offset / self.out_sz as u64
  66. }
  67. /// Asserts that the `data` field in the given [Signature] is the correct length.
  68. fn assert_sig_len(&self, sig: &Signature) -> Result<()> {
  69. let actual = sig.data.len();
  70. if self.sig_len != actual {
  71. Err(bterr!(Error::IncorrectSize {
  72. expected: self.sig_len,
  73. actual,
  74. }))
  75. } else {
  76. Ok(())
  77. }
  78. }
  79. fn set_index(&mut self, index: u64) -> Result<()> {
  80. let mut slice = self.index_bytes.as_mut_slice();
  81. write_to(&index, &mut slice)?;
  82. self.index = index;
  83. Ok(())
  84. }
  85. fn incr_index(&mut self) -> Result<()> {
  86. self.set_index(self.index + 1)
  87. }
  88. }
  89. impl<T: MetaAccess, C> SignStream<T, C> {
  90. fn sig_input<'b, 's: 'b>(
  91. &'s self,
  92. buf: &'b [u8],
  93. ) -> Result<impl Iterator<Item = &'b [u8]>> {
  94. let id_bytes = self
  95. .id_bytes
  96. .as_ref()
  97. .ok_or_else(|| bterr!("id_bytes has not been initialized"))?;
  98. Ok([id_bytes, self.index_bytes.as_slice(), buf].into_iter())
  99. }
  100. fn init_id_bytes(&mut self) -> Result<()> {
  101. if self.id_bytes.is_none() {
  102. let vec = to_vec(self.inner.meta_body().block_id()?)?;
  103. self.id_bytes = Some(vec);
  104. }
  105. Ok(())
  106. }
  107. }
  108. impl<T: Sectored, C> SignStream<T, C> {
  109. fn index_to_in(&self, index: u64) -> u64 {
  110. self.inner.offset_at(index)
  111. }
  112. fn index_to_out(&self, index: u64) -> u64 {
  113. self.out_sz as u64 * index
  114. }
  115. }
  116. impl<T: Sectored + Seek, C> SignStream<T, C> {
  117. fn reset_inner_pos(&mut self) -> io::Result<u64> {
  118. self.inner
  119. .seek(SeekFrom::Start(self.index_to_in(self.index)))
  120. }
  121. }
  122. impl<T, C> Sectored for SignStream<T, C> {
  123. fn sector_sz(&self) -> usize {
  124. self.out_sz
  125. }
  126. }
  127. impl<T: Write + MetaAccess, C: Signer> Write for SignStream<T, C> {
  128. fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
  129. self.assert_sector_sz(buf.len())?;
  130. self.inner.write_all(buf)?;
  131. self.init_id_bytes()?;
  132. let sig = self.creds.sign(self.sig_input(buf)?)?;
  133. self.assert_sig_len(&sig)?;
  134. write_to(&sig, &mut self.inner)?;
  135. self.incr_index()?;
  136. Ok(self.out_sz)
  137. }
  138. fn flush(&mut self) -> std::io::Result<()> {
  139. self.inner.flush()
  140. }
  141. }
  142. impl<T: Read + Seek + Sectored + MetaAccess, C: Verifier> Read for SignStream<T, C> {
  143. fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
  144. self.assert_at_least_sector_sz(buf.len())?;
  145. let buf = &mut buf[..self.sector_sz()];
  146. let read = self.inner.fill_buf(buf)?;
  147. if 0 == read {
  148. return Ok(0);
  149. }
  150. self.assert_sector_sz(read)?;
  151. let sig: Signature = match read_from(&mut self.inner) {
  152. Ok(sig) => sig,
  153. Err(err) => {
  154. self.reset_inner_pos()?;
  155. return Err(err.into());
  156. }
  157. };
  158. self.init_id_bytes()?;
  159. let result = self.creds.verify(self.sig_input(buf)?, sig.as_slice());
  160. if let Err(err) = result {
  161. self.reset_inner_pos()?;
  162. return Err(bterr!(err).into());
  163. }
  164. if let Err(err) = self.incr_index() {
  165. self.reset_inner_pos()?;
  166. return Err(bterr!(err).into());
  167. }
  168. Ok(read)
  169. }
  170. }
  171. impl<T: Seek + Sectored, C> Seek for SignStream<T, C> {
  172. fn seek(&mut self, out_pos: std::io::SeekFrom) -> std::io::Result<u64> {
  173. let out_pos = match out_pos {
  174. SeekFrom::Start(from_start) => from_start,
  175. SeekFrom::Current(from_curr) => {
  176. self.index_to_out(self.index).wrapping_add_signed(from_curr)
  177. }
  178. SeekFrom::End(_) => {
  179. return Err(crate::Error::new(
  180. anyhow!("seek from end is not supported")
  181. .context(io::ErrorKind::Unsupported),
  182. )
  183. .into());
  184. }
  185. };
  186. let index = self.out_to_index(out_pos);
  187. let in_pos = self.index_to_in(index);
  188. self.inner.seek(SeekFrom::Start(in_pos))?;
  189. self.set_index(index)?;
  190. Ok(out_pos)
  191. }
  192. }
  193. impl<T: AsRef<BlockMeta>, C> AsRef<BlockMeta> for SignStream<T, C> {
  194. fn as_ref(&self) -> &BlockMeta {
  195. self.inner.as_ref()
  196. }
  197. }
  198. impl<T: AsMut<BlockMeta>, C> AsMut<BlockMeta> for SignStream<T, C> {
  199. fn as_mut(&mut self) -> &mut BlockMeta {
  200. self.inner.as_mut()
  201. }
  202. }
  203. impl<T: MetaAccess, C> MetaAccess for SignStream<T, C> {}
  204. impl<T: Block + Sectored, C: Signer + Verifier> Block for SignStream<T, C> {
  205. fn flush_meta(&mut self) -> crate::Result<()> {
  206. self.inner.flush_meta()
  207. }
  208. }
  209. impl<T, C> Decompose<T> for SignStream<T, C> {
  210. fn into_inner(self) -> T {
  211. self.inner
  212. }
  213. }
  214. }
  215. #[cfg(test)]
  216. mod tests {
  217. use std::io::{Read, Seek, SeekFrom, Write};
  218. use super::*;
  219. use crate::{
  220. crypto::ConcreteCreds,
  221. test_helpers::{node_creds, Randomizer, SectoredCursor},
  222. Decompose, Sectored, SECTOR_SZ_DEFAULT,
  223. };
  224. fn sign_stream_with_block_id(
  225. sect_sz: usize,
  226. ) -> SignStream<SectoredCursor<Vec<u8>>, &'static ConcreteCreds> {
  227. let cursor = SectoredCursor::new(Vec::new(), sect_sz).require_sect_sz(false);
  228. SignStream::new(cursor, node_creds()).unwrap()
  229. }
  230. fn sign_stream_with_sz(
  231. sect_sz: usize,
  232. ) -> SignStream<SectoredCursor<Vec<u8>>, &'static ConcreteCreds> {
  233. sign_stream_with_block_id(sect_sz)
  234. }
  235. fn sign_stream() -> SignStream<SectoredCursor<Vec<u8>>, &'static ConcreteCreds> {
  236. sign_stream_with_sz(SECTOR_SZ_DEFAULT)
  237. }
  /// Constructing a stream over an empty cursor must succeed.
  #[test]
  fn new_empty() {
      let _stream = sign_stream();
  }
  /// Writing one full sector must succeed and consume the whole buffer.
  #[test]
  fn write() {
      let mut stream = sign_stream();
      let data = vec![1u8; stream.sector_sz()];
      let written = stream.write(&data).expect("write failed");
      assert_eq!(data.len(), written);
  }
  /// Seeking to the start of the second outer sector must place the inner stream at the
  /// start of its second sector.
  #[test]
  fn seek() {
      let in_sect_sz = SECTOR_SZ_DEFAULT;
      let mut stream = sign_stream_with_sz(in_sect_sz);
      let out_sect_sz = stream.sector_sz();
      let expected: u64 = out_sect_sz.try_into().unwrap();
      let data = vec![1u8; out_sect_sz];
      let written = stream.write(&data).expect("first write failed");
      assert_eq!(out_sect_sz, written);
      let written = stream.write(&data).expect("second write failed");
      assert_eq!(out_sect_sz, written);
      let actual = stream.seek(SeekFrom::Start(expected)).expect("seek failed");
      assert_eq!(expected, actual);
      let expected: u64 = in_sect_sz.try_into().unwrap();
      let actual = stream
          .into_inner()
          .stream_position()
          .expect("stream_position failed");
      assert_eq!(expected, actual);
  }
  /// A single sector written and read back must round-trip unchanged.
  #[test]
  fn write_read_once() {
      let mut stream = sign_stream();
      let sect_sz = stream.sector_sz();
      let expected = vec![1u8; sect_sz];
      let written = stream.write(&expected).expect("write failed");
      assert_eq!(sect_sz, written);
      stream.seek(SeekFrom::Start(0)).expect("seek failed");
      let mut actual = vec![0u8; sect_sz];
      let read = stream.read(&mut actual).expect("read failed");
      assert_eq!(sect_sz, read);
      assert_eq!(expected, actual);
  }
  /// Clears `vec` and refills it with `value` up to its current capacity. Returns `vec` so
  /// it can be passed straight to `read`/`write`.
  fn fill_vec<T: Clone>(vec: &mut Vec<T>, value: T) -> &mut Vec<T> {
      vec.clear();
      let target_len = vec.capacity();
      vec.resize(target_len, value);
      vec
  }
  /// Sectors written sequentially must read back sequentially with their original contents.
  #[test]
  fn write_read_many() {
      const ITER: u8 = 16;
      let mut stream = sign_stream();
      let sect_sz = stream.sector_sz();
      let mut expected = Vec::with_capacity(sect_sz);
      let mut actual = Vec::with_capacity(sect_sz);
      for k in 0..ITER {
          fill_vec(&mut expected, k);
          let written = stream.write(&expected).expect("write failed");
          assert_eq!(sect_sz, written);
      }
      stream.seek(SeekFrom::Start(0)).expect("seek failed");
      for k in 0..ITER {
          let read = stream.read(fill_vec(&mut actual, 0)).expect("read failed");
          assert_eq!(sect_sz, read);
          fill_vec(&mut expected, k);
          assert_eq!(expected, actual);
      }
  }
  /// Overwriting sectors in pseudo-random order and reading them back must return the data
  /// last written to each sector.
  #[test]
  fn write_read_random() {
      const ITER: usize = 16;
      let mut stream = sign_stream();
      let sect_sz = stream.sector_sz();
      let rando = Randomizer::new([37; Randomizer::HASH.len()]);
      let indices: Vec<u64> = rando.take(ITER).map(|e| (e % ITER) as u64).collect();
      let mut expected = Vec::with_capacity(sect_sz);
      let mut actual = Vec::with_capacity(sect_sz);
      // Fill the stream with zeros.
      for _ in 0..ITER {
          let written = stream
              .write(fill_vec(&mut expected, 0))
              .expect("write failed");
          assert_eq!(sect_sz, written);
      }
      for index in indices.iter().copied() {
          let offset = stream.offset_at(index);
          stream.seek(SeekFrom::Start(offset)).expect("seek failed");
          fill_vec(&mut expected, (index + 1) as u8);
          let written = stream.write(&expected).expect("write failed");
          assert_eq!(sect_sz, written);
      }
      for index in indices.iter().copied() {
          let offset = stream.offset_at(index);
          stream.seek(SeekFrom::Start(offset)).expect("seek failed");
          fill_vec(&mut expected, (index + 1) as u8);
          let read = stream.read(fill_vec(&mut actual, 0)).expect("read failed");
          assert_eq!(sect_sz, read);
          assert_eq!(expected, actual);
      }
  }
  /// Tampering with a signed sector in the inner stream must make the next read fail
  /// signature verification.
  #[test]
  fn modify_inner_is_err() {
      let in_sect_sz = SECTOR_SZ_DEFAULT;
      let inner = SectoredCursor::new(Vec::new(), in_sect_sz).require_sect_sz(false);
      let mut stream = SignStream::new(inner, node_creds()).expect("SignStream::new failed");
      let out_sect_sz = stream.sector_sz();
      let sect = vec![0u8; out_sect_sz];
      let written = stream.write(&sect).expect("write failed");
      assert_eq!(out_sect_sz, written);
      let mut inner = stream.into_inner();
      inner.seek(SeekFrom::Start(0)).expect("seek failed");
      // Flip the first data byte of the signed sector.
      let written = inner.write(&[1u8]).expect("second write failed");
      assert_eq!(1, written);
      inner.seek(SeekFrom::Start(0)).expect("seek failed");
      let mut stream =
          SignStream::new(inner, node_creds()).expect("second SignStream::new failed");
      let mut buf = vec![0u8; out_sect_sz];
      let result = stream.read(&mut buf);
      let actual_err = result.err().unwrap().into_inner().unwrap();
      let actual = format!("{actual_err}");
      assert_eq!("invalid signature", actual);
  }
  355. }