lib.rs 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. use anyhow::anyhow;
  2. use log::error;
  3. use nix::{
  4. sys::signal::{self, Signal},
  5. unistd::Pid,
  6. };
  7. use std::{
  8. error::Error,
  9. io,
  10. path::PathBuf,
  11. process::{Child, Command, ExitStatus, Stdio},
  12. str::FromStr,
  13. sync::{
  14. atomic::{AtomicU16, Ordering},
  15. mpsc::{channel, Receiver, TryRecvError},
  16. },
  17. time::{Duration, SystemTime},
  18. };
  19. use tempdir::TempDir;
  20. use tss_esapi::{
  21. tcti_ldr::{TabrmdConfig, TctiNameConf},
  22. Context,
  23. };
  24. pub struct SwtpmHarness {
  25. dir: TempDir,
  26. state_path: PathBuf,
  27. pid_path: PathBuf,
  28. tabrmd: Child,
  29. tabrmd_config: String,
  30. }
  31. impl SwtpmHarness {
  32. const HOST: &'static str = "127.0.0.1";
  33. fn dbus_name(port: u16) -> String {
  34. let port_str: String = port
  35. .to_string()
  36. .chars()
  37. // Shifting each code point by 17 makes the digits into capital letters.
  38. .map(|e| ((e as u8) + 17) as char)
  39. .collect();
  40. format!("com.intel.tss2.Tabrmd.{port_str}")
  41. }
  42. pub fn new() -> anyhow::Result<SwtpmHarness> {
  43. static PORT: AtomicU16 = AtomicU16::new(21901);
  44. let port = PORT.fetch_add(2, Ordering::SeqCst);
  45. let ctrl_port = port + 1;
  46. let dir = TempDir::new(format!("swtpm_harness.{port}").as_str())?;
  47. let dir_path = dir.path();
  48. let dir_path_display = dir_path.display();
  49. let conf_path = dir_path.join("swtpm_setup.conf");
  50. let state_path = dir_path.join("tpm_cred_store.state");
  51. let pid_path = dir_path.join("swtpm.pid");
  52. let dbus_name = Self::dbus_name(port);
  53. let addr = Self::HOST;
  54. std::fs::write(
  55. &conf_path,
  56. r#"# Program invoked for creating certificates
  57. #create_certs_tool= /usr/bin/swtpm_localca
  58. # Comma-separated list (no spaces) of PCR banks to activate by default
  59. active_pcr_banks = sha256
  60. "#,
  61. )?;
  62. Command::new("swtpm_setup")
  63. .stdout(Stdio::null())
  64. .args([
  65. "--tpm2",
  66. "--config",
  67. conf_path.to_str().unwrap(),
  68. "--tpm-state",
  69. format!("{dir_path_display}").as_str(),
  70. ])
  71. .status()?
  72. .success_or_err()?;
  73. Command::new("swtpm")
  74. .args([
  75. "socket",
  76. "--daemon",
  77. "--tpm2",
  78. "--server",
  79. format!("type=tcp,port={port},bindaddr={addr}").as_str(),
  80. "--ctrl",
  81. format!("type=tcp,port={ctrl_port},bindaddr={addr}").as_str(),
  82. "--log",
  83. format!("file={dir_path_display}/log.txt,level=5").as_str(),
  84. "--flags",
  85. "not-need-init,startup-clear",
  86. "--tpmstate",
  87. format!("dir={dir_path_display}").as_str(),
  88. "--pid",
  89. format!("file={}", pid_path.display()).as_str(),
  90. ])
  91. .status()?
  92. .success_or_err()?;
  93. let mut blocker = DbusBlocker::new_session(dbus_name.clone())?;
  94. let tabrmd = Command::new("tpm2-abrmd")
  95. .args([
  96. format!("--tcti=swtpm:host=127.0.0.1,port={port}").as_str(),
  97. "--dbus-name",
  98. dbus_name.as_str(),
  99. "--session",
  100. ])
  101. .spawn()?;
  102. blocker.block(Duration::from_secs(5))?;
  103. Ok(SwtpmHarness {
  104. dir,
  105. state_path,
  106. pid_path,
  107. tabrmd,
  108. tabrmd_config: format!("bus_name={},bus_type=session", Self::dbus_name(port)),
  109. })
  110. }
  111. pub fn tabrmd_config(&self) -> &str {
  112. &self.tabrmd_config
  113. }
  114. pub fn context(&self) -> io::Result<Context> {
  115. let config = TabrmdConfig::from_str(self.tabrmd_config()).box_err()?;
  116. Context::new(TctiNameConf::Tabrmd(config)).box_err()
  117. }
  118. pub fn dir_path(&self) -> &std::path::Path {
  119. self.dir.path()
  120. }
  121. pub fn state_path(&self) -> &std::path::Path {
  122. &self.state_path
  123. }
  124. }
  125. impl Drop for SwtpmHarness {
  126. fn drop(&mut self) {
  127. if let Err(err) = self.tabrmd.kill() {
  128. error!("failed to kill tpm2-abrmd: {err}");
  129. }
  130. let pid_str = std::fs::read_to_string(&self.pid_path).unwrap();
  131. let pid_int = pid_str.parse::<i32>().unwrap();
  132. let pid = Pid::from_raw(pid_int);
  133. signal::kill(pid, Signal::SIGKILL).unwrap();
  134. }
  135. }
  136. trait ExitStatusExt {
  137. fn success_or_err(&self) -> anyhow::Result<()>;
  138. }
  139. impl ExitStatusExt for ExitStatus {
  140. fn success_or_err(&self) -> anyhow::Result<()> {
  141. match self.code() {
  142. Some(0) => Ok(()),
  143. Some(code) => Err(anyhow!("ExitCode was non-zero: {code}")),
  144. None => Err(anyhow!("ExitCode was None")),
  145. }
  146. }
  147. }
  148. /// A DBus message which is sent when the ownership of a name changes.
  149. struct NameOwnerChanged {
  150. name: String,
  151. old_owner: String,
  152. new_owner: String,
  153. }
  154. impl dbus::arg::AppendAll for NameOwnerChanged {
  155. fn append(&self, iter: &mut dbus::arg::IterAppend) {
  156. dbus::arg::RefArg::append(&self.name, iter);
  157. dbus::arg::RefArg::append(&self.old_owner, iter);
  158. dbus::arg::RefArg::append(&self.new_owner, iter);
  159. }
  160. }
  161. impl dbus::arg::ReadAll for NameOwnerChanged {
  162. fn read(iter: &mut dbus::arg::Iter) -> std::result::Result<Self, dbus::arg::TypeMismatchError> {
  163. Ok(NameOwnerChanged {
  164. name: iter.read()?,
  165. old_owner: iter.read()?,
  166. new_owner: iter.read()?,
  167. })
  168. }
  169. }
  170. impl dbus::message::SignalArgs for NameOwnerChanged {
  171. const NAME: &'static str = "NameOwnerChanged";
  172. const INTERFACE: &'static str = "org.freedesktop.DBus";
  173. }
  174. /// A struct used to block until a specific name appears on DBus.
  175. struct DbusBlocker {
  176. receiver: Receiver<()>,
  177. conn: dbus::blocking::Connection,
  178. }
  179. impl DbusBlocker {
  180. fn new_session(name: String) -> io::Result<DbusBlocker> {
  181. use dbus::{blocking::Connection, Message};
  182. const DEST: &str = "org.freedesktop.DBus";
  183. let (sender, receiver) = channel();
  184. let conn = Connection::new_session().box_err()?;
  185. let proxy = conn.with_proxy(DEST, "/org/freedesktop/DBus", Duration::from_secs(1));
  186. let _ = proxy.match_signal(move |h: NameOwnerChanged, _: &Connection, _: &Message| {
  187. let name_appeared = h.name == name;
  188. if name_appeared {
  189. if let Err(err) = sender.send(()) {
  190. error!("failed to send unblocking signal: {err}");
  191. }
  192. }
  193. // This local variable exists to help clarify the logic.
  194. #[allow(clippy::let_and_return)]
  195. let remove_match = !name_appeared;
  196. remove_match
  197. });
  198. Ok(DbusBlocker { receiver, conn })
  199. }
  200. fn block(&mut self, timeout: Duration) -> io::Result<()> {
  201. let time_limit = SystemTime::now() + timeout;
  202. loop {
  203. self.conn.process(Duration::from_millis(100)).box_err()?;
  204. match self.receiver.try_recv() {
  205. Ok(_) => break,
  206. Err(err) => match err {
  207. TryRecvError::Empty => (),
  208. _ => return Err(io::Error::custom(err)),
  209. },
  210. }
  211. if SystemTime::now() > time_limit {
  212. return Err(io::Error::new(
  213. io::ErrorKind::TimedOut,
  214. "timed out waiting for DBUS message",
  215. ));
  216. }
  217. }
  218. Ok(())
  219. }
  220. }
  221. trait IoErrorExt {
  222. fn custom<E: Into<Box<dyn Error + Send + Sync>>>(err: E) -> io::Error {
  223. io::Error::new(io::ErrorKind::Other, err)
  224. }
  225. }
  226. impl IoErrorExt for io::Error {}
  227. trait ResultExt<T, E> {
  228. fn box_err(self) -> Result<T, io::Error>;
  229. }
  230. impl<T, E: Into<Box<dyn Error + Send + Sync>>> ResultExt<T, E> for Result<T, E> {
  231. fn box_err(self) -> Result<T, io::Error> {
  232. self.map_err(io::Error::custom)
  233. }
  234. }