// SPDX-License-Identifier: AGPL-3.0-or-later use anyhow::anyhow; use log::error; use nix::{ sys::signal::{self, Signal}, unistd::Pid, }; use std::{ error::Error, io, path::PathBuf, process::{Child, Command, ExitStatus, Stdio}, str::FromStr, sync::{ atomic::{AtomicU16, Ordering}, mpsc::{channel, Receiver, TryRecvError}, }, time::{Duration, SystemTime}, }; use tempdir::TempDir; use tss_esapi::{ tcti_ldr::{TabrmdConfig, TctiNameConf}, Context, }; pub struct SwtpmHarness { dir: TempDir, state_path: PathBuf, pid_path: PathBuf, tabrmd: Child, tabrmd_config: String, } impl SwtpmHarness { const HOST: &'static str = "127.0.0.1"; fn dbus_name(port: u16) -> String { let port_str: String = port .to_string() .chars() // Shifting each code point by 17 makes the digits into capital letters. .map(|e| ((e as u8) + 17) as char) .collect(); format!("com.intel.tss2.Tabrmd.{port_str}") } pub fn new() -> anyhow::Result { static PORT: AtomicU16 = AtomicU16::new(21901); let port = PORT.fetch_add(2, Ordering::SeqCst); let ctrl_port = port + 1; let dir = TempDir::new(format!("swtpm_harness.{port}").as_str())?; let dir_path = dir.path(); let dir_path_display = dir_path.display(); let conf_path = dir_path.join("swtpm_setup.conf"); let state_path = dir_path.join("tpm_cred_store.state"); let pid_path = dir_path.join("swtpm.pid"); let dbus_name = Self::dbus_name(port); let addr = Self::HOST; std::fs::write( &conf_path, r#"# Program invoked for creating certificates #create_certs_tool= /usr/bin/swtpm_localca # Comma-separated list (no spaces) of PCR banks to activate by default active_pcr_banks = sha256 "#, )?; Command::new("swtpm_setup") .stdout(Stdio::null()) .args([ "--tpm2", "--config", conf_path.to_str().unwrap(), "--tpm-state", format!("{dir_path_display}").as_str(), ]) .status()? .success_or_err()?; Command::new("swtpm") .args([ "socket", "--daemon", "--tpm2", "--server", format!("type=tcp,port={port},bindaddr={addr}").as_str(), "--ctrl", format!("type=tcp,port={ctrl_port},bindaddr={addr}").as_str(), "--log", format!("file={dir_path_display}/log.txt,level=5").as_str(), "--flags", "not-need-init,startup-clear", "--tpmstate", format!("dir={dir_path_display}").as_str(), "--pid", format!("file={}", pid_path.display()).as_str(), ]) .status()? .success_or_err() .map_err(|err| { anyhow!("swtpm {err}. This usually indicates an instance of swtpm is still running. You can rectify this with `killall swtpm`.") })?; let mut blocker = DbusBlocker::new_session(dbus_name.clone())?; let tabrmd = Command::new("tpm2-abrmd") .args([ format!("--tcti=swtpm:host=127.0.0.1,port={port}").as_str(), "--dbus-name", dbus_name.as_str(), "--session", ]) .spawn()?; blocker.block(Duration::from_secs(5))?; Ok(SwtpmHarness { dir, state_path, pid_path, tabrmd, tabrmd_config: format!("bus_name={},bus_type=session", Self::dbus_name(port)), }) } pub fn tabrmd_config(&self) -> &str { &self.tabrmd_config } pub fn context(&self) -> io::Result { let config = TabrmdConfig::from_str(self.tabrmd_config()).box_err()?; Context::new(TctiNameConf::Tabrmd(config)).box_err() } pub fn dir_path(&self) -> &std::path::Path { self.dir.path() } pub fn state_path(&self) -> &std::path::Path { &self.state_path } } impl Drop for SwtpmHarness { fn drop(&mut self) { if let Err(err) = self.tabrmd.kill() { error!("failed to kill tpm2-abrmd: {err}"); } let pid_str = std::fs::read_to_string(&self.pid_path).unwrap(); let pid_int = pid_str.parse::().unwrap(); let pid = Pid::from_raw(pid_int); signal::kill(pid, Signal::SIGKILL).unwrap(); } } trait ExitStatusExt { fn success_or_err(&self) -> anyhow::Result<()>; } impl ExitStatusExt for ExitStatus { fn success_or_err(&self) -> anyhow::Result<()> { match self.code() { Some(0) => Ok(()), Some(code) => Err(anyhow!("ExitCode was non-zero: {code}")), None => Err(anyhow!("ExitCode was None")), } } } /// A DBus message which is sent when the ownership of a name changes. struct NameOwnerChanged { name: String, old_owner: String, new_owner: String, } impl dbus::arg::AppendAll for NameOwnerChanged { fn append(&self, iter: &mut dbus::arg::IterAppend) { dbus::arg::RefArg::append(&self.name, iter); dbus::arg::RefArg::append(&self.old_owner, iter); dbus::arg::RefArg::append(&self.new_owner, iter); } } impl dbus::arg::ReadAll for NameOwnerChanged { fn read(iter: &mut dbus::arg::Iter) -> std::result::Result { Ok(NameOwnerChanged { name: iter.read()?, old_owner: iter.read()?, new_owner: iter.read()?, }) } } impl dbus::message::SignalArgs for NameOwnerChanged { const NAME: &'static str = "NameOwnerChanged"; const INTERFACE: &'static str = "org.freedesktop.DBus"; } /// A struct used to block until a specific name appears on DBus. struct DbusBlocker { receiver: Receiver<()>, conn: dbus::blocking::Connection, } impl DbusBlocker { fn new_session(name: String) -> io::Result { use dbus::{blocking::Connection, Message}; const DEST: &str = "org.freedesktop.DBus"; let (sender, receiver) = channel(); let conn = Connection::new_session().box_err()?; let proxy = conn.with_proxy(DEST, "/org/freedesktop/DBus", Duration::from_secs(1)); let _ = proxy.match_signal(move |h: NameOwnerChanged, _: &Connection, _: &Message| { let name_appeared = h.name == name; if name_appeared { if let Err(err) = sender.send(()) { error!("failed to send unblocking signal: {err}"); } } // This local variable exists to help clarify the logic. #[allow(clippy::let_and_return)] let remove_match = !name_appeared; remove_match }); Ok(DbusBlocker { receiver, conn }) } fn block(&mut self, timeout: Duration) -> io::Result<()> { let time_limit = SystemTime::now() + timeout; loop { self.conn.process(Duration::from_millis(100)).box_err()?; match self.receiver.try_recv() { Ok(_) => break, Err(err) => match err { TryRecvError::Empty => (), _ => return Err(io::Error::custom(err)), }, } if SystemTime::now() > time_limit { return Err(io::Error::new( io::ErrorKind::TimedOut, "timed out waiting for DBUS message", )); } } Ok(()) } } trait IoErrorExt { fn custom>>(err: E) -> io::Error { io::Error::new(io::ErrorKind::Other, err) } } impl IoErrorExt for io::Error {} trait ResultExt { fn box_err(self) -> Result; } impl>> ResultExt for Result { fn box_err(self) -> Result { self.map_err(io::Error::custom) } }