|  | @@ -1,18 +1,20 @@
 | 
	
		
			
				|  |  |  // SPDX-License-Identifier: AGPL-3.0-or-later
 | 
	
		
			
				|  |  | -use positioned_io::Size;
 | 
	
		
			
				|  |  | -use std::io::{self, Read, Seek, SeekFrom, Write};
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  use crate::{
 | 
	
		
			
				|  |  |      bterr,
 | 
	
		
			
				|  |  | -    crypto::{Error, Result, SymKey},
 | 
	
		
			
				|  |  | +    crypto::{Error, HashKind, Result, SymKey, SymParams},
 | 
	
		
			
				|  |  |      Decompose, Sectored, Split, TryCompose, EMPTY_SLICE,
 | 
	
		
			
				|  |  |  };
 | 
	
		
			
				|  |  | +use openssl::symm::{Crypter, Mode};
 | 
	
		
			
				|  |  | +use positioned_io::Size;
 | 
	
		
			
				|  |  | +use std::io::{self, Read, Seek, SeekFrom, Write};
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  pub use private::SecretStream;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  mod private {
 | 
	
		
			
				|  |  |      use super::*;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    const IV_BUF_LEN: usize = HashKind::Sha2_512.len();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      // A stream which encrypts all data written to it and decrypts all data read from it.
 | 
	
		
			
				|  |  |      pub struct SecretStream<T> {
 | 
	
		
			
				|  |  |          inner: T,
 | 
	
	
		
			
				|  | @@ -26,6 +28,7 @@ mod private {
 | 
	
		
			
				|  |  |          ct_buf: Vec<u8>,
 | 
	
		
			
				|  |  |          /// Buffer for plaintext.
 | 
	
		
			
				|  |  |          pt_buf: Vec<u8>,
 | 
	
		
			
				|  |  | +        iv_buf: [u8; IV_BUF_LEN],
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      impl<T> SecretStream<T> {
 | 
	
	
		
			
				|  | @@ -53,6 +56,30 @@ mod private {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    macro_rules! sym_params {
 | 
	
		
			
				|  |  | +        ($self:expr) => {{
 | 
	
		
			
				|  |  | +            let inner_offset = $self.inner.stream_position()?;
 | 
	
		
			
				|  |  | +            let SymParams { cipher, key, iv } = $self.key.params();
 | 
	
		
			
				|  |  | +            let iv = iv.ok_or_else(|| bterr!("no IV was present in block key"))?;
 | 
	
		
			
				|  |  | +            let kind = if iv.len() <= HashKind::Sha2_256.len() {
 | 
	
		
			
				|  |  | +                HashKind::Sha2_256
 | 
	
		
			
				|  |  | +            } else {
 | 
	
		
			
				|  |  | +                HashKind::Sha2_512
 | 
	
		
			
				|  |  | +            };
 | 
	
		
			
				|  |  | +            debug_assert!(iv.len() <= kind.len());
 | 
	
		
			
				|  |  | +            kind.digest(
 | 
	
		
			
				|  |  | +                &mut $self.iv_buf,
 | 
	
		
			
				|  |  | +                [inner_offset.to_le_bytes().as_slice(), iv].into_iter(),
 | 
	
		
			
				|  |  | +            )?;
 | 
	
		
			
				|  |  | +            let iv = &$self.iv_buf[..iv.len()];
 | 
	
		
			
				|  |  | +            Ok::<_, io::Error>(SymParams {
 | 
	
		
			
				|  |  | +                cipher,
 | 
	
		
			
				|  |  | +                key,
 | 
	
		
			
				|  |  | +                iv: Some(iv),
 | 
	
		
			
				|  |  | +            })
 | 
	
		
			
				|  |  | +        }};
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      impl SecretStream<()> {
 | 
	
		
			
				|  |  |          pub fn new(key: SymKey) -> SecretStream<()> {
 | 
	
		
			
				|  |  |              SecretStream {
 | 
	
	
		
			
				|  | @@ -62,6 +89,7 @@ mod private {
 | 
	
		
			
				|  |  |                  key,
 | 
	
		
			
				|  |  |                  ct_buf: Vec::new(),
 | 
	
		
			
				|  |  |                  pt_buf: Vec::new(),
 | 
	
		
			
				|  |  | +                iv_buf: [0u8; IV_BUF_LEN],
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
	
		
			
				|  | @@ -75,6 +103,7 @@ mod private {
 | 
	
		
			
				|  |  |                  key: self.key,
 | 
	
		
			
				|  |  |                  ct_buf: self.ct_buf,
 | 
	
		
			
				|  |  |                  pt_buf: self.pt_buf,
 | 
	
		
			
				|  |  | +                iv_buf: [0u8; IV_BUF_LEN],
 | 
	
		
			
				|  |  |              };
 | 
	
		
			
				|  |  |              (new_self, self.inner)
 | 
	
		
			
				|  |  |          }
 | 
	
	
		
			
				|  | @@ -87,6 +116,7 @@ mod private {
 | 
	
		
			
				|  |  |                  key: left.key,
 | 
	
		
			
				|  |  |                  ct_buf: left.ct_buf,
 | 
	
		
			
				|  |  |                  pt_buf: left.pt_buf,
 | 
	
		
			
				|  |  | +                iv_buf: [0u8; IV_BUF_LEN],
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
	
		
			
				|  | @@ -119,6 +149,7 @@ mod private {
 | 
	
		
			
				|  |  |                  key: self.key,
 | 
	
		
			
				|  |  |                  ct_buf: self.ct_buf,
 | 
	
		
			
				|  |  |                  pt_buf: self.pt_buf,
 | 
	
		
			
				|  |  | +                iv_buf: [0u8; IV_BUF_LEN],
 | 
	
		
			
				|  |  |              })
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
	
		
			
				|  | @@ -129,12 +160,13 @@ mod private {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    impl<T: Write> Write for SecretStream<T> {
 | 
	
		
			
				|  |  | +    impl<T: Write + Seek> Write for SecretStream<T> {
 | 
	
		
			
				|  |  |          fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
 | 
	
		
			
				|  |  |              self.assert_sector_sz(buf.len())?;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +            let SymParams { cipher, key, iv } = sym_params!(self)?;
 | 
	
		
			
				|  |  |              self.ct_buf.resize(self.inner_sect_sz, 0);
 | 
	
		
			
				|  |  | -            let mut encrypter = self.key.to_encrypter()?;
 | 
	
		
			
				|  |  | +            let mut encrypter = Crypter::new(cipher, Mode::Encrypt, key, iv)?;
 | 
	
		
			
				|  |  |              let mut count = encrypter.update(buf, &mut self.ct_buf)?;
 | 
	
		
			
				|  |  |              count += encrypter.finalize(&mut self.ct_buf[count..])?;
 | 
	
		
			
				|  |  |              self.ct_buf.truncate(count);
 | 
	
	
		
			
				|  | @@ -147,10 +179,13 @@ mod private {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    impl<T: Read> Read for SecretStream<T> {
 | 
	
		
			
				|  |  | +    impl<T: Read + Seek> Read for SecretStream<T> {
 | 
	
		
			
				|  |  |          fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
 | 
	
		
			
				|  |  |              self.assert_sector_sz(buf.len())?;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +            // sym_params must be called before reading from the inner stream so that it's position
 | 
	
		
			
				|  |  | +            // will be correct for the IV calculation.
 | 
	
		
			
				|  |  | +            let SymParams { cipher, key, iv } = sym_params!(self)?;
 | 
	
		
			
				|  |  |              match self.inner.read_exact(&mut self.ct_buf) {
 | 
	
		
			
				|  |  |                  Ok(_) => (),
 | 
	
		
			
				|  |  |                  Err(err) => {
 | 
	
	
		
			
				|  | @@ -164,7 +199,7 @@ mod private {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              self.pt_buf
 | 
	
		
			
				|  |  |                  .resize(self.inner_sect_sz + self.key.block_size(), 0);
 | 
	
		
			
				|  |  | -            let mut decrypter = self.key.to_decrypter()?;
 | 
	
		
			
				|  |  | +            let mut decrypter = Crypter::new(cipher, Mode::Decrypt, key, iv)?;
 | 
	
		
			
				|  |  |              let mut count = decrypter.update(&self.ct_buf, &mut self.pt_buf)?;
 | 
	
		
			
				|  |  |              count += decrypter.finalize(&mut self.pt_buf[count..])?;
 | 
	
		
			
				|  |  |              self.pt_buf.truncate(count);
 | 
	
	
		
			
				|  | @@ -248,7 +283,7 @@ mod tests {
 | 
	
		
			
				|  |  |              let expected = vec![k as u8; sector_sz];
 | 
	
		
			
				|  |  |              let mut actual = vec![0u8; sector_sz];
 | 
	
		
			
				|  |  |              stream.read(&mut actual).expect("read failed");
 | 
	
		
			
				|  |  | -            assert_eq!(expected, actual);
 | 
	
		
			
				|  |  | +            assert!(expected == actual);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 |