secret_stream.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. use positioned_io::Size;
  2. use std::io::{self, Read, Seek, SeekFrom, Write};
  3. use crate::{
  4. bterr,
  5. crypto::{Error, Result, SymKey},
  6. Decompose, Sectored, Split, TryCompose, EMPTY_SLICE,
  7. };
  8. pub use private::SecretStream;
  9. mod private {
  10. use super::*;
  11. // A stream which encrypts all data written to it and decrypts all data read from it.
  12. pub struct SecretStream<T> {
  13. inner: T,
  14. // The sector size of the inner stream. Reads and writes are only executed using buffers of
  15. // this size.
  16. inner_sect_sz: usize,
  17. // The sector size of this stream. Reads and writes are only accepted for buffers of this size.
  18. sect_sz: usize,
  19. key: SymKey,
  20. /// Buffer for ciphertext.
  21. ct_buf: Vec<u8>,
  22. /// Buffer for plaintext.
  23. pt_buf: Vec<u8>,
  24. }
  25. impl<T> SecretStream<T> {
  26. pub fn get_ref(&self) -> &T {
  27. &self.inner
  28. }
  29. pub fn get_mut(&mut self) -> &mut T {
  30. &mut self.inner
  31. }
  32. /// Given an offset into this stream, produces the corresponding offset into the inner stream.
  33. fn inner_offset(&self, outer_offset: u64) -> u64 {
  34. let sect_sz = self.sect_sz as u64;
  35. let inner_sect_sz = self.inner_sect_sz as u64;
  36. // We return the offset into the current sector, plus the size of all previous sectors.
  37. outer_offset % sect_sz + outer_offset / sect_sz * inner_sect_sz
  38. }
  39. /// Given an offset into the inner stream, returns the corresponding offset into this stream.
  40. fn outer_offset(&self, inner_offset: u64) -> u64 {
  41. let sect_sz = self.sect_sz as u64;
  42. let inner_sect_sz = self.inner_sect_sz as u64;
  43. inner_offset % inner_sect_sz + inner_offset / inner_sect_sz * sect_sz
  44. }
  45. }
  46. impl SecretStream<()> {
  47. pub fn new(key: SymKey) -> SecretStream<()> {
  48. SecretStream {
  49. inner: (),
  50. inner_sect_sz: 0,
  51. sect_sz: 0,
  52. key,
  53. ct_buf: Vec::new(),
  54. pt_buf: Vec::new(),
  55. }
  56. }
  57. }
  58. impl<T> Split<SecretStream<&'static [u8]>, T> for SecretStream<T> {
  59. fn split(self) -> (SecretStream<&'static [u8]>, T) {
  60. let new_self = SecretStream {
  61. inner: EMPTY_SLICE,
  62. inner_sect_sz: self.inner_sect_sz,
  63. sect_sz: self.sect_sz,
  64. key: self.key,
  65. ct_buf: self.ct_buf,
  66. pt_buf: self.pt_buf,
  67. };
  68. (new_self, self.inner)
  69. }
  70. fn combine(left: SecretStream<&'static [u8]>, right: T) -> Self {
  71. SecretStream {
  72. inner: right,
  73. inner_sect_sz: left.inner_sect_sz,
  74. sect_sz: left.sect_sz,
  75. key: left.key,
  76. ct_buf: left.ct_buf,
  77. pt_buf: left.pt_buf,
  78. }
  79. }
  80. }
  81. impl<T> Decompose<T> for SecretStream<T> {
  82. fn into_inner(self) -> T {
  83. self.inner
  84. }
  85. }
  86. impl<T, U: Sectored> TryCompose<U, SecretStream<U>> for SecretStream<T> {
  87. type Error = crate::Error;
  88. fn try_compose(mut self, inner: U) -> Result<SecretStream<U>> {
  89. let inner_sect_sz = inner.sector_sz();
  90. let expansion_sz = self.key.expansion_sz();
  91. let sect_sz = inner_sect_sz - expansion_sz;
  92. let block_sz = self.key.block_size();
  93. if 0 != sect_sz % block_sz {
  94. return Err(bterr!(Error::IndivisibleSize {
  95. divisor: block_sz,
  96. actual: sect_sz,
  97. }));
  98. }
  99. self.ct_buf.resize(inner_sect_sz, 0);
  100. self.pt_buf.resize(inner_sect_sz + block_sz, 0);
  101. Ok(SecretStream {
  102. inner,
  103. inner_sect_sz,
  104. sect_sz: inner_sect_sz - expansion_sz,
  105. key: self.key,
  106. ct_buf: self.ct_buf,
  107. pt_buf: self.pt_buf,
  108. })
  109. }
  110. }
  111. impl<T> Sectored for SecretStream<T> {
  112. fn sector_sz(&self) -> usize {
  113. self.sect_sz
  114. }
  115. }
  116. impl<T: Write> Write for SecretStream<T> {
  117. fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
  118. self.assert_sector_sz(buf.len())?;
  119. self.ct_buf.resize(self.inner_sect_sz, 0);
  120. let mut encrypter = self.key.to_encrypter()?;
  121. let mut count = encrypter.update(buf, &mut self.ct_buf)?;
  122. count += encrypter.finalize(&mut self.ct_buf[count..])?;
  123. self.ct_buf.truncate(count);
  124. self.inner.write_all(&self.ct_buf).map(|_| buf.len())
  125. }
  126. fn flush(&mut self) -> io::Result<()> {
  127. self.inner.flush()
  128. }
  129. }
  130. impl<T: Read> Read for SecretStream<T> {
  131. fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
  132. self.assert_sector_sz(buf.len())?;
  133. match self.inner.read_exact(&mut self.ct_buf) {
  134. Ok(_) => (),
  135. Err(err) => {
  136. if err.kind() == io::ErrorKind::UnexpectedEof {
  137. return Ok(0);
  138. } else {
  139. return Err(err);
  140. }
  141. }
  142. }
  143. self.pt_buf
  144. .resize(self.inner_sect_sz + self.key.block_size(), 0);
  145. let mut decrypter = self.key.to_decrypter()?;
  146. let mut count = decrypter.update(&self.ct_buf, &mut self.pt_buf)?;
  147. count += decrypter.finalize(&mut self.pt_buf[count..])?;
  148. self.pt_buf.truncate(count);
  149. buf.copy_from_slice(&self.pt_buf);
  150. Ok(buf.len())
  151. }
  152. }
  153. impl<T: Seek> Seek for SecretStream<T> {
  154. fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
  155. let outer_offset = match pos {
  156. SeekFrom::Start(offset) => offset,
  157. SeekFrom::Current(offset) => {
  158. let inner_offset = self.inner.stream_position()?;
  159. let outer_offset = self.outer_offset(inner_offset);
  160. if offset >= 0 {
  161. outer_offset + offset as u64
  162. } else {
  163. outer_offset - (-offset as u64)
  164. }
  165. }
  166. SeekFrom::End(_) => {
  167. // We can support this once stream_len is stabilized:
  168. // https://github.com/rust-lang/rust/issues/59359
  169. return Err(io::Error::new(
  170. io::ErrorKind::Unsupported,
  171. "seeking from the end of the stream is not supported",
  172. ));
  173. }
  174. };
  175. let inner_offset = self.inner_offset(outer_offset);
  176. self.inner.seek(SeekFrom::Start(inner_offset))?;
  177. Ok(outer_offset)
  178. }
  179. }
  180. impl<U, T: AsRef<U>> AsRef<U> for SecretStream<T> {
  181. fn as_ref(&self) -> &U {
  182. self.inner.as_ref()
  183. }
  184. }
  185. impl<U, T: AsMut<U>> AsMut<U> for SecretStream<T> {
  186. fn as_mut(&mut self) -> &mut U {
  187. self.inner.as_mut()
  188. }
  189. }
  190. impl<T: Size> Size for SecretStream<T> {
  191. fn size(&self) -> io::Result<Option<u64>> {
  192. self.inner.size()
  193. }
  194. }
  195. }
  196. #[cfg(test)]
  197. mod tests {
  198. use crate::{
  199. crypto::SymKeyKind,
  200. test_helpers::{Randomizer, SectoredCursor},
  201. SECTOR_SZ_DEFAULT,
  202. };
  203. use super::*;
  204. fn secret_stream_sequential_test_case(key: SymKey, inner_sect_sz: usize, sect_ct: usize) {
  205. let mut stream = SecretStream::new(key)
  206. .try_compose(SectoredCursor::new(
  207. vec![0u8; inner_sect_sz * sect_ct],
  208. inner_sect_sz,
  209. ))
  210. .expect("compose failed");
  211. let sector_sz = stream.sector_sz();
  212. for k in 0..sect_ct {
  213. let sector = vec![k as u8; sector_sz];
  214. stream.write(&sector).expect("write failed");
  215. }
  216. stream.seek(SeekFrom::Start(0)).expect("seek failed");
  217. for k in 0..sect_ct {
  218. let expected = vec![k as u8; sector_sz];
  219. let mut actual = vec![0u8; sector_sz];
  220. stream.read(&mut actual).expect("read failed");
  221. assert_eq!(expected, actual);
  222. }
  223. }
  224. fn secret_stream_sequential_test_suite(kind: SymKeyKind) {
  225. let key = SymKey::generate(kind).expect("key generation failed");
  226. secret_stream_sequential_test_case(key.clone(), SECTOR_SZ_DEFAULT, 16);
  227. }
  /// Sequential encrypt/decrypt round trip using an AES-256-CBC key.
  #[test]
  fn secret_stream_encrypt_decrypt_are_inverse_aes256cbc() {
      let kind = SymKeyKind::Aes256Cbc;
      secret_stream_sequential_test_suite(kind);
  }
  /// Sequential encrypt/decrypt round trip using an AES-256-CTR key.
  #[test]
  fn secret_stream_encrypt_decrypt_are_inverse_aes256ctr() {
      let kind = SymKeyKind::Aes256Ctr;
      secret_stream_sequential_test_suite(kind);
  }
  236. fn secret_stream_random_access_test_case(
  237. rando: Randomizer,
  238. key: SymKey,
  239. inner_sect_sz: usize,
  240. sect_ct: usize,
  241. ) {
  242. let mut stream = SecretStream::new(key)
  243. .try_compose(SectoredCursor::new(
  244. vec![0u8; inner_sect_sz * sect_ct],
  245. inner_sect_sz,
  246. ))
  247. .expect("compose failed");
  248. let sect_sz = stream.sector_sz();
  249. let indices: Vec<usize> = rando.take(sect_ct).map(|e| e % sect_ct).collect();
  250. for index in indices.iter().map(|e| *e) {
  251. let offset = index * sect_sz;
  252. stream
  253. .seek(SeekFrom::Start(offset as u64))
  254. .expect("seek to write failed");
  255. let sector = vec![index as u8; sect_sz];
  256. stream.write(&sector).expect("write failed");
  257. }
  258. for index in indices.iter().map(|e| *e) {
  259. let offset = index * sect_sz;
  260. stream
  261. .seek(SeekFrom::Start(offset as u64))
  262. .expect("seek to read failed");
  263. let expected = vec![index as u8; sect_sz];
  264. let mut actual = vec![0u8; sect_sz];
  265. stream.read(&mut actual).expect("read failed");
  266. assert_eq!(expected, actual);
  267. }
  268. }
  269. fn secret_stream_random_access_test_suite(kind: SymKeyKind) {
  270. const SEED: [u8; Randomizer::HASH.len()] = [3u8; Randomizer::HASH.len()];
  271. let key = SymKey::generate(kind).expect("key generation failed");
  272. secret_stream_random_access_test_case(
  273. Randomizer::new(SEED),
  274. key.clone(),
  275. SECTOR_SZ_DEFAULT,
  276. 20,
  277. );
  278. secret_stream_random_access_test_case(
  279. Randomizer::new(SEED),
  280. key.clone(),
  281. SECTOR_SZ_DEFAULT,
  282. 800,
  283. );
  284. secret_stream_random_access_test_case(Randomizer::new(SEED), key.clone(), 512, 200);
  285. secret_stream_random_access_test_case(Randomizer::new(SEED), key.clone(), 512, 20);
  286. secret_stream_random_access_test_case(Randomizer::new(SEED), key.clone(), 512, 200);
  287. }
  /// Random-access round trips for every supported symmetric key kind.
  #[test]
  fn secret_stream_random_access() {
      for kind in [SymKeyKind::Aes256Cbc, SymKeyKind::Aes256Ctr] {
          secret_stream_random_access_test_suite(kind);
      }
  }
  293. fn make_secret_stream(
  294. key_kind: SymKeyKind,
  295. num_sectors: usize,
  296. ) -> SecretStream<SectoredCursor<Vec<u8>>> {
  297. let key = SymKey::generate(key_kind).expect("key generation failed");
  298. let inner = SectoredCursor::new(
  299. vec![0u8; num_sectors * SECTOR_SZ_DEFAULT],
  300. SECTOR_SZ_DEFAULT,
  301. );
  302. SecretStream::new(key)
  303. .try_compose(inner)
  304. .expect("compose failed")
  305. }
  /// Seeking with `SeekFrom::Start` to the second sector and reading returns that sector.
  #[test]
  fn secret_stream_seek_from_start() {
      let mut stream = make_secret_stream(SymKeyKind::Aes256Cbc, 3);
      let sector_sz = stream.sector_sz();
      let expected = vec![2u8; sector_sz];
      // Write one sector of ones, one sector of twos and one sector of threes.
      for k in 1..4u8 {
          let sector = vec![k; sector_sz];
          stream.write_all(&sector).expect("writing to stream failed");
      }
      let pos = stream
          .seek(SeekFrom::Start(sector_sz as u64))
          .expect("seek failed");
      assert_eq!(sector_sz as u64, pos);
      // A read from the stream should now return the second sector, which is filled with twos.
      let mut actual = vec![0u8; sector_sz];
      stream
          .read_exact(&mut actual)
          .expect("reading from stream failed");
      assert_eq!(expected, actual);
  }
  /// Seeking back one sector with `SeekFrom::Current` after three writes and reading returns
  /// the last sector.
  #[test]
  fn secret_stream_seek_from_current() {
      let mut stream = make_secret_stream(SymKeyKind::Aes256Cbc, 3);
      let sector_sz = stream.sector_sz();
      let expected = vec![3u8; sector_sz];
      // Write one sector of ones, one sector of twos and one sector of threes.
      for k in 1..4u8 {
          let sector = vec![k; sector_sz];
          stream.write_all(&sector).expect("writing to stream failed");
      }
      let pos = stream
          .seek(SeekFrom::Current(-(sector_sz as i64)))
          .expect("seek failed");
      assert_eq!(2 * sector_sz as u64, pos);
      // A read from the stream should now return the last sector, which is filled with threes.
      let mut actual = vec![0u8; sector_sz];
      stream
          .read_exact(&mut actual)
          .expect("reading from stream failed");
      assert_eq!(expected, actual);
  }
  346. }