lib.rs 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. // SPDX-License-Identifier: AGPL-3.0-or-later
  2. use anyhow::anyhow;
  3. use log::error;
  4. use nix::{
  5. sys::signal::{self, Signal},
  6. unistd::Pid,
  7. };
  8. use std::{
  9. error::Error,
  10. io,
  11. path::PathBuf,
  12. process::{Child, Command, ExitStatus, Stdio},
  13. str::FromStr,
  14. sync::{
  15. atomic::{AtomicU16, Ordering},
  16. mpsc::{channel, Receiver, TryRecvError},
  17. },
  18. time::{Duration, SystemTime},
  19. };
  20. use tempdir::TempDir;
  21. use tss_esapi::{
  22. tcti_ldr::{TabrmdConfig, TctiNameConf},
  23. Context,
  24. };
  25. pub struct SwtpmHarness {
  26. dir: TempDir,
  27. state_path: PathBuf,
  28. pid_path: PathBuf,
  29. tabrmd: Child,
  30. tabrmd_config: String,
  31. }
  32. impl SwtpmHarness {
  33. const HOST: &'static str = "127.0.0.1";
  34. fn dbus_name(port: u16) -> String {
  35. let port_str: String = port
  36. .to_string()
  37. .chars()
  38. // Shifting each code point by 17 makes the digits into capital letters.
  39. .map(|e| ((e as u8) + 17) as char)
  40. .collect();
  41. format!("com.intel.tss2.Tabrmd.{port_str}")
  42. }
  43. pub fn new() -> anyhow::Result<SwtpmHarness> {
  44. static PORT: AtomicU16 = AtomicU16::new(21901);
  45. let port = PORT.fetch_add(2, Ordering::SeqCst);
  46. let ctrl_port = port + 1;
  47. let dir = TempDir::new(format!("swtpm_harness.{port}").as_str())?;
  48. let dir_path = dir.path();
  49. let dir_path_display = dir_path.display();
  50. let conf_path = dir_path.join("swtpm_setup.conf");
  51. let state_path = dir_path.join("tpm_cred_store.state");
  52. let pid_path = dir_path.join("swtpm.pid");
  53. let dbus_name = Self::dbus_name(port);
  54. let addr = Self::HOST;
  55. std::fs::write(
  56. &conf_path,
  57. r#"# Program invoked for creating certificates
  58. #create_certs_tool= /usr/bin/swtpm_localca
  59. # Comma-separated list (no spaces) of PCR banks to activate by default
  60. active_pcr_banks = sha256
  61. "#,
  62. )?;
  63. Command::new("swtpm_setup")
  64. .stdout(Stdio::null())
  65. .args([
  66. "--tpm2",
  67. "--config",
  68. conf_path.to_str().unwrap(),
  69. "--tpm-state",
  70. format!("{dir_path_display}").as_str(),
  71. ])
  72. .status()?
  73. .success_or_err()?;
  74. Command::new("swtpm")
  75. .args([
  76. "socket",
  77. "--daemon",
  78. "--tpm2",
  79. "--server",
  80. format!("type=tcp,port={port},bindaddr={addr}").as_str(),
  81. "--ctrl",
  82. format!("type=tcp,port={ctrl_port},bindaddr={addr}").as_str(),
  83. "--log",
  84. format!("file={dir_path_display}/log.txt,level=5").as_str(),
  85. "--flags",
  86. "not-need-init,startup-clear",
  87. "--tpmstate",
  88. format!("dir={dir_path_display}").as_str(),
  89. "--pid",
  90. format!("file={}", pid_path.display()).as_str(),
  91. ])
  92. .status()?
  93. .success_or_err()
  94. .map_err(|err| {
  95. anyhow!("swtpm {err}. This usually indicates an instance of swtpm is still running. You can rectify this with `killall swtpm`.")
  96. })?;
  97. let mut blocker = DbusBlocker::new_session(dbus_name.clone())?;
  98. let tabrmd = Command::new("tpm2-abrmd")
  99. .args([
  100. format!("--tcti=swtpm:host=127.0.0.1,port={port}").as_str(),
  101. "--dbus-name",
  102. dbus_name.as_str(),
  103. "--session",
  104. ])
  105. .spawn()?;
  106. blocker.block(Duration::from_secs(5))?;
  107. Ok(SwtpmHarness {
  108. dir,
  109. state_path,
  110. pid_path,
  111. tabrmd,
  112. tabrmd_config: format!("bus_name={},bus_type=session", Self::dbus_name(port)),
  113. })
  114. }
  115. pub fn tabrmd_config(&self) -> &str {
  116. &self.tabrmd_config
  117. }
  118. pub fn context(&self) -> io::Result<Context> {
  119. let config = TabrmdConfig::from_str(self.tabrmd_config()).box_err()?;
  120. Context::new(TctiNameConf::Tabrmd(config)).box_err()
  121. }
  122. pub fn dir_path(&self) -> &std::path::Path {
  123. self.dir.path()
  124. }
  125. pub fn state_path(&self) -> &std::path::Path {
  126. &self.state_path
  127. }
  128. }
  129. impl Drop for SwtpmHarness {
  130. fn drop(&mut self) {
  131. if let Err(err) = self.tabrmd.kill() {
  132. error!("failed to kill tpm2-abrmd: {err}");
  133. }
  134. let pid_str = std::fs::read_to_string(&self.pid_path).unwrap();
  135. let pid_int = pid_str.parse::<i32>().unwrap();
  136. let pid = Pid::from_raw(pid_int);
  137. signal::kill(pid, Signal::SIGKILL).unwrap();
  138. }
  139. }
  140. trait ExitStatusExt {
  141. fn success_or_err(&self) -> anyhow::Result<()>;
  142. }
  143. impl ExitStatusExt for ExitStatus {
  144. fn success_or_err(&self) -> anyhow::Result<()> {
  145. match self.code() {
  146. Some(0) => Ok(()),
  147. Some(code) => Err(anyhow!("ExitCode was non-zero: {code}")),
  148. None => Err(anyhow!("ExitCode was None")),
  149. }
  150. }
  151. }
  152. /// A DBus message which is sent when the ownership of a name changes.
  153. struct NameOwnerChanged {
  154. name: String,
  155. old_owner: String,
  156. new_owner: String,
  157. }
  158. impl dbus::arg::AppendAll for NameOwnerChanged {
  159. fn append(&self, iter: &mut dbus::arg::IterAppend) {
  160. dbus::arg::RefArg::append(&self.name, iter);
  161. dbus::arg::RefArg::append(&self.old_owner, iter);
  162. dbus::arg::RefArg::append(&self.new_owner, iter);
  163. }
  164. }
  165. impl dbus::arg::ReadAll for NameOwnerChanged {
  166. fn read(iter: &mut dbus::arg::Iter) -> std::result::Result<Self, dbus::arg::TypeMismatchError> {
  167. Ok(NameOwnerChanged {
  168. name: iter.read()?,
  169. old_owner: iter.read()?,
  170. new_owner: iter.read()?,
  171. })
  172. }
  173. }
  174. impl dbus::message::SignalArgs for NameOwnerChanged {
  175. const NAME: &'static str = "NameOwnerChanged";
  176. const INTERFACE: &'static str = "org.freedesktop.DBus";
  177. }
  178. /// A struct used to block until a specific name appears on DBus.
  179. struct DbusBlocker {
  180. receiver: Receiver<()>,
  181. conn: dbus::blocking::Connection,
  182. }
  183. impl DbusBlocker {
  184. fn new_session(name: String) -> io::Result<DbusBlocker> {
  185. use dbus::{blocking::Connection, Message};
  186. const DEST: &str = "org.freedesktop.DBus";
  187. let (sender, receiver) = channel();
  188. let conn = Connection::new_session().box_err()?;
  189. let proxy = conn.with_proxy(DEST, "/org/freedesktop/DBus", Duration::from_secs(1));
  190. let _ = proxy.match_signal(move |h: NameOwnerChanged, _: &Connection, _: &Message| {
  191. let name_appeared = h.name == name;
  192. if name_appeared {
  193. if let Err(err) = sender.send(()) {
  194. error!("failed to send unblocking signal: {err}");
  195. }
  196. }
  197. // This local variable exists to help clarify the logic.
  198. #[allow(clippy::let_and_return)]
  199. let remove_match = !name_appeared;
  200. remove_match
  201. });
  202. Ok(DbusBlocker { receiver, conn })
  203. }
  204. fn block(&mut self, timeout: Duration) -> io::Result<()> {
  205. let time_limit = SystemTime::now() + timeout;
  206. loop {
  207. self.conn.process(Duration::from_millis(100)).box_err()?;
  208. match self.receiver.try_recv() {
  209. Ok(_) => break,
  210. Err(err) => match err {
  211. TryRecvError::Empty => (),
  212. _ => return Err(io::Error::custom(err)),
  213. },
  214. }
  215. if SystemTime::now() > time_limit {
  216. return Err(io::Error::new(
  217. io::ErrorKind::TimedOut,
  218. "timed out waiting for DBUS message",
  219. ));
  220. }
  221. }
  222. Ok(())
  223. }
  224. }
  /// Extension adding a constructor for `io::Error`s of kind `Other`.
  trait IoErrorExt {
      /// Wraps any error (or message) convertible into a boxed
      /// `Error + Send + Sync` in an `io::Error` of kind `io::ErrorKind::Other`.
      fn custom<E>(err: E) -> io::Error
      where
          E: Into<Box<dyn Error + Send + Sync>>,
      {
          let payload: Box<dyn Error + Send + Sync> = err.into();
          io::Error::new(io::ErrorKind::Other, payload)
      }
  }
  impl IoErrorExt for io::Error {}
  /// Extension turning any `Result` whose error can be boxed into an `io::Result`.
  trait ResultExt<T, E> {
      /// Passes `Ok` values through unchanged.
      ///
      /// An `Err` is wrapped in an `io::Error` of kind `io::ErrorKind::Other` that
      /// carries the original error as its payload.
      fn box_err(self) -> Result<T, io::Error>;
  }
  impl<T, E> ResultExt<T, E> for Result<T, E>
  where
      E: Into<Box<dyn Error + Send + Sync>>,
  {
      fn box_err(self) -> Result<T, io::Error> {
          self.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
      }
  }