123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260 |
- // SPDX-License-Identifier: AGPL-3.0-or-later
- use anyhow::anyhow;
- use log::error;
- use nix::{
- sys::signal::{self, Signal},
- unistd::Pid,
- };
- use std::{
- error::Error,
- io,
- path::PathBuf,
- process::{Child, Command, ExitStatus, Stdio},
- str::FromStr,
- sync::{
- atomic::{AtomicU16, Ordering},
- mpsc::{channel, Receiver, TryRecvError},
- },
- time::{Duration, SystemTime},
- };
- use tempdir::TempDir;
- use tss_esapi::{
- tcti_ldr::{TabrmdConfig, TctiNameConf},
- Context,
- };
- pub struct SwtpmHarness {
- dir: TempDir,
- state_path: PathBuf,
- pid_path: PathBuf,
- tabrmd: Child,
- tabrmd_config: String,
- }
- impl SwtpmHarness {
- const HOST: &'static str = "127.0.0.1";
- fn dbus_name(port: u16) -> String {
- let port_str: String = port
- .to_string()
- .chars()
- // Shifting each code point by 17 makes the digits into capital letters.
- .map(|e| ((e as u8) + 17) as char)
- .collect();
- format!("com.intel.tss2.Tabrmd.{port_str}")
- }
- pub fn new() -> anyhow::Result<SwtpmHarness> {
- static PORT: AtomicU16 = AtomicU16::new(21901);
- let port = PORT.fetch_add(2, Ordering::SeqCst);
- let ctrl_port = port + 1;
- let dir = TempDir::new(format!("swtpm_harness.{port}").as_str())?;
- let dir_path = dir.path();
- let dir_path_display = dir_path.display();
- let conf_path = dir_path.join("swtpm_setup.conf");
- let state_path = dir_path.join("tpm_cred_store.state");
- let pid_path = dir_path.join("swtpm.pid");
- let dbus_name = Self::dbus_name(port);
- let addr = Self::HOST;
- std::fs::write(
- &conf_path,
- r#"# Program invoked for creating certificates
- #create_certs_tool= /usr/bin/swtpm_localca
- # Comma-separated list (no spaces) of PCR banks to activate by default
- active_pcr_banks = sha256
- "#,
- )?;
- Command::new("swtpm_setup")
- .stdout(Stdio::null())
- .args([
- "--tpm2",
- "--config",
- conf_path.to_str().unwrap(),
- "--tpm-state",
- format!("{dir_path_display}").as_str(),
- ])
- .status()?
- .success_or_err()?;
- Command::new("swtpm")
- .args([
- "socket",
- "--daemon",
- "--tpm2",
- "--server",
- format!("type=tcp,port={port},bindaddr={addr}").as_str(),
- "--ctrl",
- format!("type=tcp,port={ctrl_port},bindaddr={addr}").as_str(),
- "--log",
- format!("file={dir_path_display}/log.txt,level=5").as_str(),
- "--flags",
- "not-need-init,startup-clear",
- "--tpmstate",
- format!("dir={dir_path_display}").as_str(),
- "--pid",
- format!("file={}", pid_path.display()).as_str(),
- ])
- .status()?
- .success_or_err()
- .map_err(|err| {
- anyhow!("swtpm {err}. This usually indicates an instance of swtpm is still running. You can rectify this with `killall swtpm`.")
- })?;
- let mut blocker = DbusBlocker::new_session(dbus_name.clone())?;
- let tabrmd = Command::new("tpm2-abrmd")
- .args([
- format!("--tcti=swtpm:host=127.0.0.1,port={port}").as_str(),
- "--dbus-name",
- dbus_name.as_str(),
- "--session",
- ])
- .spawn()?;
- blocker.block(Duration::from_secs(5))?;
- Ok(SwtpmHarness {
- dir,
- state_path,
- pid_path,
- tabrmd,
- tabrmd_config: format!("bus_name={},bus_type=session", Self::dbus_name(port)),
- })
- }
- pub fn tabrmd_config(&self) -> &str {
- &self.tabrmd_config
- }
- pub fn context(&self) -> io::Result<Context> {
- let config = TabrmdConfig::from_str(self.tabrmd_config()).box_err()?;
- Context::new(TctiNameConf::Tabrmd(config)).box_err()
- }
- pub fn dir_path(&self) -> &std::path::Path {
- self.dir.path()
- }
- pub fn state_path(&self) -> &std::path::Path {
- &self.state_path
- }
- }
- impl Drop for SwtpmHarness {
- fn drop(&mut self) {
- if let Err(err) = self.tabrmd.kill() {
- error!("failed to kill tpm2-abrmd: {err}");
- }
- let pid_str = std::fs::read_to_string(&self.pid_path).unwrap();
- let pid_int = pid_str.parse::<i32>().unwrap();
- let pid = Pid::from_raw(pid_int);
- signal::kill(pid, Signal::SIGKILL).unwrap();
- }
- }
- trait ExitStatusExt {
- fn success_or_err(&self) -> anyhow::Result<()>;
- }
- impl ExitStatusExt for ExitStatus {
- fn success_or_err(&self) -> anyhow::Result<()> {
- match self.code() {
- Some(0) => Ok(()),
- Some(code) => Err(anyhow!("ExitCode was non-zero: {code}")),
- None => Err(anyhow!("ExitCode was None")),
- }
- }
- }
- /// A DBus message which is sent when the ownership of a name changes.
- struct NameOwnerChanged {
- name: String,
- old_owner: String,
- new_owner: String,
- }
- impl dbus::arg::AppendAll for NameOwnerChanged {
- fn append(&self, iter: &mut dbus::arg::IterAppend) {
- dbus::arg::RefArg::append(&self.name, iter);
- dbus::arg::RefArg::append(&self.old_owner, iter);
- dbus::arg::RefArg::append(&self.new_owner, iter);
- }
- }
- impl dbus::arg::ReadAll for NameOwnerChanged {
- fn read(iter: &mut dbus::arg::Iter) -> std::result::Result<Self, dbus::arg::TypeMismatchError> {
- Ok(NameOwnerChanged {
- name: iter.read()?,
- old_owner: iter.read()?,
- new_owner: iter.read()?,
- })
- }
- }
- impl dbus::message::SignalArgs for NameOwnerChanged {
- const NAME: &'static str = "NameOwnerChanged";
- const INTERFACE: &'static str = "org.freedesktop.DBus";
- }
- /// A struct used to block until a specific name appears on DBus.
- struct DbusBlocker {
- receiver: Receiver<()>,
- conn: dbus::blocking::Connection,
- }
- impl DbusBlocker {
- fn new_session(name: String) -> io::Result<DbusBlocker> {
- use dbus::{blocking::Connection, Message};
- const DEST: &str = "org.freedesktop.DBus";
- let (sender, receiver) = channel();
- let conn = Connection::new_session().box_err()?;
- let proxy = conn.with_proxy(DEST, "/org/freedesktop/DBus", Duration::from_secs(1));
- let _ = proxy.match_signal(move |h: NameOwnerChanged, _: &Connection, _: &Message| {
- let name_appeared = h.name == name;
- if name_appeared {
- if let Err(err) = sender.send(()) {
- error!("failed to send unblocking signal: {err}");
- }
- }
- // This local variable exists to help clarify the logic.
- #[allow(clippy::let_and_return)]
- let remove_match = !name_appeared;
- remove_match
- });
- Ok(DbusBlocker { receiver, conn })
- }
- fn block(&mut self, timeout: Duration) -> io::Result<()> {
- let time_limit = SystemTime::now() + timeout;
- loop {
- self.conn.process(Duration::from_millis(100)).box_err()?;
- match self.receiver.try_recv() {
- Ok(_) => break,
- Err(err) => match err {
- TryRecvError::Empty => (),
- _ => return Err(io::Error::custom(err)),
- },
- }
- if SystemTime::now() > time_limit {
- return Err(io::Error::new(
- io::ErrorKind::TimedOut,
- "timed out waiting for DBUS message",
- ));
- }
- }
- Ok(())
- }
- }
/// Adds a constructor to [`io::Error`] for wrapping arbitrary errors with kind `Other`.
trait IoErrorExt {
    /// Wraps `err` in an `io::Error` whose kind is [`io::ErrorKind::Other`].
    fn custom<E: Into<Box<dyn Error + Send + Sync>>>(err: E) -> io::Error {
        let kind = io::ErrorKind::Other;
        io::Error::new(kind, err)
    }
}
impl IoErrorExt for io::Error {}
/// Converts the error side of a `Result` into an `io::Error`.
trait ResultExt<T, E> {
    /// Passes `Ok` values through unchanged and wraps any `Err` as an `io::Error` of kind
    /// `Other`.
    fn box_err(self) -> Result<T, io::Error>;
}
impl<T, E> ResultExt<T, E> for Result<T, E>
where
    E: Into<Box<dyn Error + Send + Sync>>,
{
    fn box_err(self) -> Result<T, io::Error> {
        match self {
            Ok(value) => Ok(value),
            Err(source) => Err(<io::Error as IoErrorExt>::custom(source)),
        }
    }
}
|