lib.rs 7.9 KB


  1. use log::error;
  2. use nix::{
  3. sys::signal::{self, Signal},
  4. unistd::Pid,
  5. };
  6. use std::{
  7. error::Error,
  8. io,
  9. path::PathBuf,
  10. process::{Child, Command, ExitStatus, Stdio},
  11. str::FromStr,
  12. sync::{
  13. atomic::{AtomicU16, Ordering},
  14. mpsc::{channel, Receiver, TryRecvError},
  15. },
  16. time::{Duration, SystemTime},
  17. };
  18. use tempdir::TempDir;
  19. use tss_esapi::{
  20. tcti_ldr::{TabrmdConfig, TctiNameConf},
  21. Context,
  22. };
/// Test harness which runs a software TPM (`swtpm`) and a `tpm2-abrmd`
/// resource manager connected to it.
///
/// Both processes are started by `SwtpmHarness::new` and killed again when the
/// harness is dropped.
pub struct SwtpmHarness {
    /// Temporary directory holding the swtpm configuration, state, log and PID file.
    dir: TempDir,
    /// Path of `state.bt` inside `dir`; returned by `state_path()`.
    state_path: PathBuf,
    /// Path of the file into which swtpm is told to write its PID (`--pid`).
    pid_path: PathBuf,
    /// Handle to the spawned `tpm2-abrmd` process.
    tabrmd: Child,
    /// Configuration string of the form `bus_name=<name>,bus_type=session`,
    /// used to connect to `tabrmd`.
    tabrmd_config: String,
}
  30. impl SwtpmHarness {
  31. const HOST: &'static str = "127.0.0.1";
  32. fn dbus_name(port: u16) -> String {
  33. let port_str: String = port
  34. .to_string()
  35. .chars()
  36. // Shifting each code point by 17 makes the digits into capital letters.
  37. .map(|e| ((e as u8) + 17) as char)
  38. .collect();
  39. format!("com.intel.tss2.Tabrmd.{port_str}")
  40. }
  41. pub fn new() -> io::Result<SwtpmHarness> {
  42. static PORT: AtomicU16 = AtomicU16::new(21901);
  43. let port = PORT.fetch_add(2, Ordering::SeqCst);
  44. let ctrl_port = port + 1;
  45. let dir = TempDir::new(format!("swtpm_harness.{port}").as_str())?;
  46. let dir_path = dir.path();
  47. let dir_path_display = dir_path.display();
  48. let conf_path = dir_path.join("swtpm_setup.conf");
  49. let state_path = dir_path.join("state.bt");
  50. let pid_path = dir_path.join("swtpm.pid");
  51. let dbus_name = Self::dbus_name(port);
  52. let addr = Self::HOST;
  53. std::fs::write(
  54. &conf_path,
  55. r#"# Program invoked for creating certificates
  56. #create_certs_tool= /usr/bin/swtpm_localca
  57. # Comma-separated list (no spaces) of PCR banks to activate by default
  58. active_pcr_banks = sha256
  59. "#,
  60. )?;
  61. Command::new("swtpm_setup")
  62. .stdout(Stdio::null())
  63. .args([
  64. "--tpm2",
  65. "--config",
  66. conf_path.to_str().unwrap(),
  67. "--tpm-state",
  68. format!("{dir_path_display}").as_str(),
  69. ])
  70. .status()?
  71. .success_or_err()?;
  72. Command::new("swtpm")
  73. .args([
  74. "socket",
  75. "--daemon",
  76. "--tpm2",
  77. "--server",
  78. format!("type=tcp,port={port},bindaddr={addr}").as_str(),
  79. "--ctrl",
  80. format!("type=tcp,port={ctrl_port},bindaddr={addr}").as_str(),
  81. "--log",
  82. format!("file={dir_path_display}/log.txt,level=5").as_str(),
  83. "--flags",
  84. "not-need-init,startup-clear",
  85. "--tpmstate",
  86. format!("dir={dir_path_display}").as_str(),
  87. "--pid",
  88. format!("file={}", pid_path.display()).as_str(),
  89. ])
  90. .status()?
  91. .success_or_err()?;
  92. let mut blocker = DbusBlocker::new_session(dbus_name.clone())?;
  93. let tabrmd = Command::new("tpm2-abrmd")
  94. .args([
  95. format!("--tcti=swtpm:host=127.0.0.1,port={port}").as_str(),
  96. "--dbus-name",
  97. dbus_name.as_str(),
  98. "--session",
  99. ])
  100. .spawn()?;
  101. blocker.block(Duration::from_secs(5))?;
  102. Ok(SwtpmHarness {
  103. dir,
  104. state_path,
  105. pid_path,
  106. tabrmd,
  107. tabrmd_config: format!("bus_name={},bus_type=session", Self::dbus_name(port)),
  108. })
  109. }
  110. pub fn tabrmd_config(&self) -> &str {
  111. &self.tabrmd_config
  112. }
  113. pub fn context(&self) -> io::Result<Context> {
  114. let config = TabrmdConfig::from_str(self.tabrmd_config()).box_err()?;
  115. Context::new(TctiNameConf::Tabrmd(config)).box_err()
  116. }
  117. pub fn dir_path(&self) -> &std::path::Path {
  118. self.dir.path()
  119. }
  120. pub fn state_path(&self) -> &std::path::Path {
  121. &self.state_path
  122. }
  123. }
  124. impl Drop for SwtpmHarness {
  125. fn drop(&mut self) {
  126. if let Err(err) = self.tabrmd.kill() {
  127. error!("failed to kill tpm2-abrmd: {err}");
  128. }
  129. let pid_str = std::fs::read_to_string(&self.pid_path).unwrap();
  130. let pid_int = pid_str.parse::<i32>().unwrap();
  131. let pid = Pid::from_raw(pid_int);
  132. signal::kill(pid, Signal::SIGKILL).unwrap();
  133. }
  134. }
  135. trait ExitStatusExt {
  136. fn success_or_err(&self) -> io::Result<()>;
  137. }
  138. impl ExitStatusExt for ExitStatus {
  139. fn success_or_err(&self) -> io::Result<()> {
  140. match self.code() {
  141. Some(0) => Ok(()),
  142. Some(code) => Err(io::Error::new(
  143. io::ErrorKind::Other,
  144. format!("ExitCode was non-zero: {code}"),
  145. )),
  146. None => Err(io::Error::new(io::ErrorKind::Other, "ExitCode was None")),
  147. }
  148. }
  149. }
/// A DBus message which is sent when the ownership of a name changes.
///
/// The fields are appended and read in declaration order: `name`,
/// `old_owner`, `new_owner`.
struct NameOwnerChanged {
    /// The bus name whose ownership changed.
    name: String,
    /// The previous owner of `name`.
    // NOTE(review): per the D-Bus specification this is presumably empty when
    // the name had no owner before; confirm.
    old_owner: String,
    /// The new owner of `name`.
    // NOTE(review): presumably empty when the name was released; confirm.
    new_owner: String,
}
  156. impl dbus::arg::AppendAll for NameOwnerChanged {
  157. fn append(&self, iter: &mut dbus::arg::IterAppend) {
  158. dbus::arg::RefArg::append(&self.name, iter);
  159. dbus::arg::RefArg::append(&self.old_owner, iter);
  160. dbus::arg::RefArg::append(&self.new_owner, iter);
  161. }
  162. }
  163. impl dbus::arg::ReadAll for NameOwnerChanged {
  164. fn read(iter: &mut dbus::arg::Iter) -> std::result::Result<Self, dbus::arg::TypeMismatchError> {
  165. Ok(NameOwnerChanged {
  166. name: iter.read()?,
  167. old_owner: iter.read()?,
  168. new_owner: iter.read()?,
  169. })
  170. }
  171. }
/// Identifies the signal this type corresponds to: member `NameOwnerChanged`
/// on interface `org.freedesktop.DBus`.
impl dbus::message::SignalArgs for NameOwnerChanged {
    // Name of the signal.
    const NAME: &'static str = "NameOwnerChanged";
    // Interface on which the signal is emitted.
    const INTERFACE: &'static str = "org.freedesktop.DBus";
}
/// A struct used to block until a specific name appears on DBus.
struct DbusBlocker {
    /// Receives a unit value from the signal callback once the awaited name
    /// has been seen.
    receiver: Receiver<()>,
    /// Connection on which the `NameOwnerChanged` match is registered; it is
    /// processed by `block` so that the callback can run.
    conn: dbus::blocking::Connection,
}
  181. impl DbusBlocker {
  182. fn new_session(name: String) -> io::Result<DbusBlocker> {
  183. use dbus::{blocking::Connection, Message};
  184. const DEST: &str = "org.freedesktop.DBus";
  185. let (sender, receiver) = channel();
  186. let conn = Connection::new_session().box_err()?;
  187. let proxy = conn.with_proxy(DEST, "/org/freedesktop/DBus", Duration::from_secs(1));
  188. let _ = proxy.match_signal(move |h: NameOwnerChanged, _: &Connection, _: &Message| {
  189. let name_appeared = h.name == name;
  190. if name_appeared {
  191. if let Err(err) = sender.send(()) {
  192. error!("failed to send unblocking signal: {err}");
  193. }
  194. }
  195. // This local variable exists to help clarify the logic.
  196. #[allow(clippy::let_and_return)]
  197. let remove_match = !name_appeared;
  198. remove_match
  199. });
  200. Ok(DbusBlocker { receiver, conn })
  201. }
  202. fn block(&mut self, timeout: Duration) -> io::Result<()> {
  203. let time_limit = SystemTime::now() + timeout;
  204. loop {
  205. self.conn.process(Duration::from_millis(100)).box_err()?;
  206. match self.receiver.try_recv() {
  207. Ok(_) => break,
  208. Err(err) => match err {
  209. TryRecvError::Empty => (),
  210. _ => return Err(io::Error::custom(err)),
  211. },
  212. }
  213. if SystemTime::now() > time_limit {
  214. return Err(io::Error::new(
  215. io::ErrorKind::TimedOut,
  216. "timed out waiting for DBUS message",
  217. ));
  218. }
  219. }
  220. Ok(())
  221. }
  222. }
  223. trait IoErrorExt {
  224. fn custom<E: Into<Box<dyn Error + Send + Sync>>>(err: E) -> io::Error {
  225. io::Error::new(io::ErrorKind::Other, err)
  226. }
  227. }
  228. impl IoErrorExt for io::Error {}
  229. trait ResultExt<T, E> {
  230. fn box_err(self) -> Result<T, io::Error>;
  231. }
  232. impl<T, E: Into<Box<dyn Error + Send + Sync>>> ResultExt<T, E> for Result<T, E> {
  233. fn box_err(self) -> Result<T, io::Error> {
  234. self.map_err(io::Error::custom)
  235. }
  236. }