lib.rs 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916
  1. #![feature(impl_trait_in_assoc_type)]
  2. use std::{
  3. any::Any,
  4. collections::{hash_map, HashMap},
  5. fmt::Display,
  6. future::{ready, Future, Ready},
  7. marker::PhantomData,
  8. net::IpAddr,
  9. ops::DerefMut,
  10. pin::Pin,
  11. sync::Arc,
  12. };
  13. use btlib::{bterr, crypto::Creds, error::StringError, BlockPath, Result};
  14. use btserde::{from_slice, to_vec, write_to};
  15. use bttp::{DeserCallback, MsgCallback, Receiver, Replier, Transmitter};
  16. use kernel::{kernel, SpawnReq};
  17. use serde::{Deserialize, Serialize};
  18. use tokio::{
  19. sync::{mpsc, oneshot, Mutex, RwLock},
  20. task::AbortHandle,
  21. };
  22. mod kernel;
  23. pub mod model;
  24. use model::*;
  /// Declares a new [Runtime] which listens for messages at the given IP address and uses the given
  /// [Creds]. Runtimes are intended to be created once in a process's lifetime and continue running
  /// until the process exits.
  ///
  /// The expansion defines a `lazy_static` named `$name` of type `&'static Runtime`. The first time
  /// it is dereferenced both the [Runtime] and its network receiver are created.
  ///
  /// Note that the `$creds` expression is evaluated twice: once to build the [Runtime] and once to
  /// build its receiver.
  #[macro_export]
  macro_rules! declare_runtime {
      ($name:ident, $ip_addr:expr, $creds:expr) => {
          ::lazy_static::lazy_static! {
              static ref $name: &'static $crate::Runtime = {
                  ::lazy_static::lazy_static! {
                      static ref RUNTIME: $crate::Runtime = $crate::Runtime::_new($creds).unwrap();
                      // Qualified with `$crate` so the macro also works in crates which have not
                      // imported `_new_receiver`.
                      static ref RECEIVER: ::bttp::Receiver =
                          $crate::_new_receiver($ip_addr, $creds, &*RUNTIME);
                  }
                  // By dereferencing RECEIVER we ensure it is started.
                  let _ = &*RECEIVER;
                  &*RUNTIME
              };
          }
      };
  }
  44. /// This function is not intended to be called by downstream crates.
  45. #[doc(hidden)]
  46. pub fn _new_receiver<C>(ip_addr: IpAddr, creds: Arc<C>, runtime: &'static Runtime) -> Receiver
  47. where
  48. C: 'static + Send + Sync + Creds,
  49. {
  50. let callback = RuntimeCallback::new(runtime);
  51. Receiver::new(ip_addr, creds, callback).unwrap()
  52. }
  53. /// Type used to implement an actor's mailbox.
  54. pub type Mailbox<T> = mpsc::Receiver<Envelope<T>>;
  55. /// An actor runtime.
  56. ///
  57. /// Actors can be activated by the runtime and execute autonomously until they return. Running
  58. /// actors can be sent messages using the `send` method, which does not wait for a response from the
  59. /// recipient. If a reply is needed, then `call` can be used, which returns a future that will
  60. /// be ready when the reply has been received.
  61. pub struct Runtime {
  62. path: Arc<BlockPath>,
  63. handles: RwLock<HashMap<ActorId, ActorHandle>>,
  64. peers: RwLock<HashMap<Arc<BlockPath>, Transmitter>>,
  65. registry: RwLock<HashMap<ServiceId, ServiceRecord>>,
  66. kernel_sender: mpsc::Sender<SpawnReq>,
  67. }
  68. impl Runtime {
  69. /// The size of the buffer to use for the channel between [Runtime] and [kernel] used for
  70. /// spawning tasks.
  71. const SPAWN_REQ_BUF_SZ: usize = 16;
  72. /// This method is not intended to be called directly by downstream crates. Use the macro
  73. /// [declare_runtime] to create a [Runtime].
  74. ///
  75. /// If you create a non-static [Runtime], your process will panic when it is dropped.
  76. #[doc(hidden)]
  77. pub fn _new<C: 'static + Send + Sync + Creds>(creds: Arc<C>) -> Result<Runtime> {
  78. let path = Arc::new(creds.bind_path()?);
  79. let (sender, receiver) = mpsc::channel(Self::SPAWN_REQ_BUF_SZ);
  80. tokio::task::spawn(kernel(receiver));
  81. Ok(Runtime {
  82. path,
  83. handles: RwLock::new(HashMap::new()),
  84. peers: RwLock::new(HashMap::new()),
  85. registry: RwLock::new(HashMap::new()),
  86. kernel_sender: sender,
  87. })
  88. }
  89. pub fn path(&self) -> &Arc<BlockPath> {
  90. &self.path
  91. }
  92. /// Returns the number of actors that are currently executing in this [Runtime].
  93. pub async fn num_running(&self) -> usize {
  94. let guard = self.handles.read().await;
  95. guard.len()
  96. }
  97. /// Sends a message to the actor identified by the given [ActorName].
  98. pub async fn send<T: 'static + SendMsg>(
  99. &self,
  100. to: ActorName,
  101. from: ActorName,
  102. msg: T,
  103. ) -> Result<()> {
  104. if to.path().as_ref() == self.path.as_ref() {
  105. let guard = self.handles.read().await;
  106. if let Some(handle) = guard.get(&to.act_id()) {
  107. handle.send(from, msg).await
  108. } else {
  109. Err(bterr!("invalid actor name"))
  110. }
  111. } else {
  112. let guard = self.peers.read().await;
  113. if let Some(peer) = guard.get(to.path()) {
  114. let buf = to_vec(&msg)?;
  115. let wire_msg = WireMsg {
  116. to,
  117. from,
  118. payload: &buf,
  119. };
  120. peer.send(wire_msg).await
  121. } else {
  122. todo!("Discover the network location of the recipient runtime and connect to it.")
  123. }
  124. }
  125. }
  126. /// Sends a message to the service identified by [ServiceName].
  127. pub async fn send_service<T: 'static + SendMsg>(
  128. &'static self,
  129. to: ServiceAddr,
  130. from: ActorName,
  131. msg: T,
  132. ) -> Result<()> {
  133. if to.path().as_ref() == self.path.as_ref() {
  134. let actor_id = self.service_provider(&to).await?;
  135. let handles = self.handles.read().await;
  136. if let Some(handle) = handles.get(&actor_id) {
  137. handle.send(from, msg).await
  138. } else {
  139. panic!(
  140. "Service record '{}' had a non-existent actor with ID '{}'.",
  141. to.service_id(),
  142. actor_id
  143. );
  144. }
  145. } else {
  146. todo!("Send the message to an appropriate peer.")
  147. }
  148. }
  149. /// Sends a message to the actor identified by the given [ActorName] and returns a future which
  150. /// is ready when a reply has been received.
  151. pub async fn call<T: 'static + CallMsg>(
  152. &self,
  153. to: ActorName,
  154. from: ActorName,
  155. msg: T,
  156. ) -> Result<T::Reply> {
  157. if to.path().as_ref() == self.path.as_ref() {
  158. let guard = self.handles.read().await;
  159. if let Some(handle) = guard.get(&to.act_id()) {
  160. handle.call_through(msg).await
  161. } else {
  162. Err(bterr!("invalid actor name"))
  163. }
  164. } else {
  165. let guard = self.peers.read().await;
  166. if let Some(peer) = guard.get(to.path()) {
  167. let buf = to_vec(&msg)?;
  168. let wire_msg = WireMsg {
  169. to,
  170. from,
  171. payload: &buf,
  172. };
  173. peer.call(wire_msg, ReplyCallback::<T>::new()).await?
  174. } else {
  175. todo!("Use the filesystem to find the address of the recipient and connect to it.")
  176. }
  177. }
  178. }
  179. /// Calls a service identified by [ServiceName].
  180. pub async fn call_service<T: 'static + CallMsg>(
  181. &'static self,
  182. to: ServiceAddr,
  183. msg: T,
  184. ) -> Result<T::Reply> {
  185. if to.path().as_ref() == self.path.as_ref() {
  186. let actor_id = self.service_provider(&to).await?;
  187. let handles = self.handles.read().await;
  188. if let Some(handle) = handles.get(&actor_id) {
  189. handle.call_through(msg).await
  190. } else {
  191. panic!(
  192. "Service record '{}' had a non-existent actor with ID '{}'.",
  193. to.service_id(),
  194. actor_id
  195. );
  196. }
  197. } else {
  198. todo!("Send the message to an appropriate peer.")
  199. }
  200. }
  201. fn service_not_registered_err(id: &ServiceId) -> btlib::Error {
  202. bterr!("Service is not registered: '{id}'")
  203. }
  204. async fn service_provider(&'static self, to: &ServiceAddr) -> Result<ActorId> {
  205. let actor_id = {
  206. let registry = self.registry.read().await;
  207. if let Some(record) = registry.get(to.service_id()) {
  208. record.actor_ids.first().copied()
  209. } else {
  210. return Err(Self::service_not_registered_err(to.service_id()));
  211. }
  212. };
  213. let actor_id = if let Some(actor_id) = actor_id {
  214. actor_id
  215. } else {
  216. let mut registry = self.registry.write().await;
  217. if let Some(record) = registry.get_mut(to.service_id()) {
  218. // It's possible that another thread got the write lock before us and they
  219. // already spawned an actor.
  220. if record.actor_ids.is_empty() {
  221. let spawner = record.spawner.as_ref();
  222. let actor_name = spawner(self).await?;
  223. let actor_id = actor_name.act_id();
  224. record.actor_ids.push(actor_id);
  225. actor_id
  226. } else {
  227. record.actor_ids[0]
  228. }
  229. } else {
  230. return Err(Self::service_not_registered_err(to.service_id()));
  231. }
  232. };
  233. Ok(actor_id)
  234. }
  235. /// Spawns a new actor using the given activator function and returns a handle to it.
  236. pub async fn spawn<Msg, F, Fut>(&'static self, activator: F) -> ActorName
  237. where
  238. Msg: 'static + CallMsg,
  239. Fut: 'static + Send + Future<Output = ()>,
  240. F: FnOnce(&'static Runtime, Mailbox<Msg>, ActorId) -> Fut,
  241. {
  242. let mut guard = self.handles.write().await;
  243. let act_id = {
  244. let mut act_id = ActorId::new();
  245. while guard.contains_key(&act_id) {
  246. act_id = ActorId::new();
  247. }
  248. act_id
  249. };
  250. let act_name = self.actor_name(act_id);
  251. let (tx, rx) = mpsc::channel::<Envelope<Msg>>(MAILBOX_LIMIT);
  252. // The deliverer closure is responsible for deserializing messages received over the wire
  253. // and delivering them to the actor's mailbox, as well as sending replies to call messages.
  254. let deliverer = {
  255. let buffer = Arc::new(Mutex::new(Vec::<u8>::new()));
  256. let tx = tx.clone();
  257. let act_name = act_name.clone();
  258. move |envelope: WireEnvelope| {
  259. let (wire_msg, replier) = envelope.into_parts();
  260. let result = from_slice(wire_msg.payload);
  261. let buffer = buffer.clone();
  262. let tx = tx.clone();
  263. let act_name = act_name.clone();
  264. let fut: FutureResult = Box::pin(async move {
  265. let msg = result?;
  266. if let Some(mut replier) = replier {
  267. let (envelope, rx) = Envelope::new_call(msg);
  268. tx.send(envelope).await.map_err(|_| {
  269. bterr!("failed to deliver message. Recipient may have halted.")
  270. })?;
  271. match rx.await {
  272. Ok(reply) => {
  273. let mut guard = buffer.lock().await;
  274. guard.clear();
  275. write_to(&reply, guard.deref_mut())?;
  276. let wire_reply = WireReply::Ok(&guard);
  277. replier.reply(wire_reply).await
  278. }
  279. Err(err) => replier.reply_err(err.to_string(), None).await,
  280. }
  281. } else {
  282. tx.send(Envelope::new_send(act_name, msg))
  283. .await
  284. .map_err(|_| {
  285. bterr!("failed to deliver message. Recipient may have halted.")
  286. })
  287. }
  288. });
  289. fut
  290. }
  291. };
  292. let (req, receiver) = SpawnReq::new(activator(self, rx, act_id));
  293. self.kernel_sender
  294. .send(req)
  295. .await
  296. .unwrap_or_else(|err| panic!("The kernel has panicked: {err}"));
  297. let handle = receiver
  298. .await
  299. .unwrap_or_else(|err| panic!("Kernel failed to send abort handle: {err}"));
  300. let actor_handle = ActorHandle::new(handle, tx, deliverer);
  301. guard.insert(act_id, actor_handle);
  302. act_name
  303. }
  304. /// Registers a service activation closure for [ServiceId]. An error is returned if the
  305. /// [ServiceId] has already been registered.
  306. pub async fn register<Msg, F>(&self, id: ServiceId, spawner: F) -> Result<ServiceName>
  307. where
  308. Msg: 'static + CallMsg,
  309. F: 'static
  310. + Send
  311. + Sync
  312. + Fn(&'static Runtime) -> Pin<Box<dyn Future<Output = Result<ActorName>>>>,
  313. {
  314. let mut guard = self.registry.write().await;
  315. match guard.entry(id.clone()) {
  316. hash_map::Entry::Vacant(entry) => {
  317. entry.insert(ServiceRecord::new(spawner));
  318. Ok(ServiceName::new(self.path().clone(), id.clone()))
  319. }
  320. hash_map::Entry::Occupied(_) => {
  321. log::info!("Updated registration for service '{id}'.");
  322. Ok(ServiceName::new(self.path().clone(), id))
  323. }
  324. }
  325. }
  326. /// Removes the registration for the service with the given ID.
  327. ///
  328. /// If a vector reference is given in `service_providers`, the service providers which
  329. /// are part of the deregistered service are appended to it. Otherwise, their
  330. /// handles are dropped and their tasks are aborted.
  331. ///
  332. /// A [RuntimeError::BadServiceId] error is returned if there is no service registration with
  333. /// the given ID in this runtime.
  334. pub async fn deregister(
  335. &self,
  336. id: &ServiceId,
  337. service_providers: Option<&mut Vec<ActorHandle>>,
  338. ) -> Result<()> {
  339. let record = {
  340. let mut registry = self.registry.write().await;
  341. if let Some(record) = registry.remove(id) {
  342. record
  343. } else {
  344. return Err(RuntimeError::BadServiceId(id.clone()).into());
  345. }
  346. };
  347. let mut handles = self.handles.write().await;
  348. let removed = record
  349. .actor_ids
  350. .into_iter()
  351. .flat_map(|act_id| handles.remove(&act_id));
  352. // If a vector was provided, we put all the removed service providers in it. Otherwise
  353. // we just drop them.
  354. if let Some(service_providers) = service_providers {
  355. service_providers.extend(removed);
  356. } else {
  357. for _ in removed {}
  358. }
  359. Ok(())
  360. }
  361. /// Returns the [ActorHandle] for the actor with the given name.
  362. ///
  363. /// If there is no such actor in this runtime then a [RuntimeError::BadActorName] error is
  364. /// returned.
  365. ///
  366. /// Note that the actor will be aborted when the given handle is dropped (unless it has already
  367. /// returned when the handle is dropped), and no further messages will be delivered to it by
  368. /// this runtime.
  369. pub async fn take(&self, name: &ActorName) -> Result<ActorHandle> {
  370. if name.path().as_ref() == self.path.as_ref() {
  371. let mut guard = self.handles.write().await;
  372. if let Some(handle) = guard.remove(&name.act_id()) {
  373. Ok(handle)
  374. } else {
  375. Err(RuntimeError::BadActorName(name.clone()).into())
  376. }
  377. } else {
  378. Err(RuntimeError::BadActorName(name.clone()).into())
  379. }
  380. }
  381. /// Returns the name of the actor in this runtime with the given actor ID.
  382. pub fn actor_name(&self, act_id: ActorId) -> ActorName {
  383. ActorName::new(self.path.clone(), act_id)
  384. }
  385. }
  386. impl Drop for Runtime {
  387. fn drop(&mut self) {
  388. panic!("A Runtime was dropped. Panicking to avoid undefined behavior.");
  389. }
  390. }
  391. /// Closure type used to spawn new service providers.
  392. type Spawner =
  393. Box<dyn Send + Sync + Fn(&'static Runtime) -> Pin<Box<dyn Future<Output = Result<ActorName>>>>>;
  394. struct ServiceRecord {
  395. spawner: Spawner,
  396. actor_ids: Vec<ActorId>,
  397. }
  398. impl ServiceRecord {
  399. fn new<F>(spawner: F) -> Self
  400. where
  401. F: 'static
  402. + Send
  403. + Sync
  404. + Fn(&'static Runtime) -> Pin<Box<dyn Future<Output = Result<ActorName>>>>,
  405. {
  406. Self {
  407. spawner: Box::new(spawner),
  408. actor_ids: Vec::new(),
  409. }
  410. }
  411. }
  412. #[derive(Debug, Clone, PartialEq, Eq)]
  413. pub enum RuntimeError {
  414. BadActorName(ActorName),
  415. BadServiceId(ServiceId),
  416. }
  417. impl Display for RuntimeError {
  418. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  419. match self {
  420. Self::BadActorName(name) => write!(f, "bad actor name: {name}"),
  421. Self::BadServiceId(service_id) => {
  422. write!(f, "service ID is not registered: {service_id}")
  423. }
  424. }
  425. }
  426. }
  427. impl std::error::Error for RuntimeError {}
  428. /// Deserializes replies sent over the wire.
  429. struct ReplyCallback<T> {
  430. _phantom: PhantomData<T>,
  431. }
  432. impl<T: CallMsg> ReplyCallback<T> {
  433. fn new() -> Self {
  434. Self {
  435. _phantom: PhantomData,
  436. }
  437. }
  438. }
  439. impl<T: CallMsg> Default for ReplyCallback<T> {
  440. fn default() -> Self {
  441. Self::new()
  442. }
  443. }
  444. impl<T: CallMsg> DeserCallback for ReplyCallback<T> {
  445. type Arg<'de> = WireReply<'de> where T: 'de;
  446. type Return = Result<T::Reply>;
  447. type CallFut<'de> = Ready<Self::Return> where T: 'de, T::Reply: 'de;
  448. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  449. let result = match arg {
  450. WireReply::Ok(slice) => from_slice(slice).map_err(|err| err.into()),
  451. WireReply::Err(msg) => Err(StringError::new(msg.to_string()).into()),
  452. };
  453. ready(result)
  454. }
  455. }
  456. struct SendReplyCallback {
  457. replier: Option<Replier>,
  458. }
  459. impl SendReplyCallback {
  460. fn new(replier: Replier) -> Self {
  461. Self {
  462. replier: Some(replier),
  463. }
  464. }
  465. }
  466. impl DeserCallback for SendReplyCallback {
  467. type Arg<'de> = WireReply<'de>;
  468. type Return = Result<()>;
  469. type CallFut<'de> = impl 'de + Future<Output = Self::Return>;
  470. fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
  471. async move {
  472. if let Some(mut replier) = self.replier.take() {
  473. replier.reply(arg).await
  474. } else {
  475. Ok(())
  476. }
  477. }
  478. }
  479. }
  480. /// This struct implements the server callback for network messages.
  481. #[derive(Clone)]
  482. struct RuntimeCallback {
  483. rt: &'static Runtime,
  484. }
  485. impl RuntimeCallback {
  486. fn new(rt: &'static Runtime) -> Self {
  487. Self { rt }
  488. }
  489. async fn deliver_local(&self, msg: WireMsg<'_>, replier: Option<Replier>) -> Result<()> {
  490. let guard = self.rt.handles.read().await;
  491. if let Some(handle) = guard.get(&msg.to.act_id()) {
  492. let envelope = if let Some(replier) = replier {
  493. WireEnvelope::Call { msg, replier }
  494. } else {
  495. WireEnvelope::Send { msg }
  496. };
  497. (handle.deliverer)(envelope).await
  498. } else {
  499. Err(bterr!("invalid actor name: {}", msg.to))
  500. }
  501. }
  502. async fn route_msg(&self, msg: WireMsg<'_>, replier: Option<Replier>) -> Result<()> {
  503. let guard = self.rt.peers.read().await;
  504. if let Some(tx) = guard.get(msg.to.path()) {
  505. if let Some(replier) = replier {
  506. let callback = SendReplyCallback::new(replier);
  507. tx.call(msg, callback).await?
  508. } else {
  509. tx.send(msg).await
  510. }
  511. } else {
  512. Err(bterr!(
  513. "unable to deliver message to peer at '{}'",
  514. msg.to.path()
  515. ))
  516. }
  517. }
  518. }
  519. impl MsgCallback for RuntimeCallback {
  520. type Arg<'de> = WireMsg<'de>;
  521. type CallFut<'de> = impl 'de + Future<Output = Result<()>>;
  522. fn call<'de>(&'de self, arg: bttp::MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
  523. async move {
  524. let (_, body, replier) = arg.into_parts();
  525. if body.to.path() == self.rt.path() {
  526. self.deliver_local(body, replier).await
  527. } else {
  528. self.route_msg(body, replier).await
  529. }
  530. }
  531. }
  532. }
  /// Capacity of each actor's mailbox channel, i.e. how many messages can be queued for an actor
  /// before senders have to wait.
  const MAILBOX_LIMIT: usize = 1 << 5;
  535. /// The type of messages sent over the wire between runtimes.
  536. #[derive(Serialize, Deserialize)]
  537. struct WireMsg<'a> {
  538. to: ActorName,
  539. from: ActorName,
  540. payload: &'a [u8],
  541. }
  542. impl<'a> WireMsg<'a> {
  543. #[allow(dead_code)]
  544. fn new(to: ActorName, from: ActorName, payload: &'a [u8]) -> Self {
  545. Self { to, from, payload }
  546. }
  547. }
  548. impl<'a> bttp::CallMsg<'a> for WireMsg<'a> {
  549. type Reply<'r> = WireReply<'r>;
  550. }
  551. impl<'a> bttp::SendMsg<'a> for WireMsg<'a> {}
  552. #[derive(Serialize, Deserialize)]
  553. enum WireReply<'a> {
  554. Ok(&'a [u8]),
  555. Err(&'a str),
  556. }
  557. /// A wrapper around [WireMsg] which indicates whether a call or send was executed.
  558. enum WireEnvelope<'de> {
  559. Send { msg: WireMsg<'de> },
  560. Call { msg: WireMsg<'de>, replier: Replier },
  561. }
  562. impl<'de> WireEnvelope<'de> {
  563. fn into_parts(self) -> (WireMsg<'de>, Option<Replier>) {
  564. match self {
  565. Self::Send { msg } => (msg, None),
  566. Self::Call { msg, replier } => (msg, Some(replier)),
  567. }
  568. }
  569. }
  570. pub enum EnvelopeKind<T: CallMsg> {
  571. Call {
  572. reply: Option<oneshot::Sender<T::Reply>>,
  573. },
  574. Send {
  575. from: ActorName,
  576. },
  577. }
  578. impl<T: CallMsg> EnvelopeKind<T> {
  579. pub fn name(&self) -> &'static str {
  580. match self {
  581. Self::Call { .. } => "Call",
  582. Self::Send { .. } => "Send",
  583. }
  584. }
  585. }
  586. /// Wrapper around a message type `T` which indicates who the message is from and, if the message
  587. /// was dispatched with `call`, provides a channel to reply to it.
  588. pub struct Envelope<T: CallMsg> {
  589. msg: T,
  590. kind: EnvelopeKind<T>,
  591. }
  592. impl<T: CallMsg> Envelope<T> {
  593. pub fn new(msg: T, kind: EnvelopeKind<T>) -> Self {
  594. Self { msg, kind }
  595. }
  596. /// Creates a new envelope containing the given message which does not expect a reply.
  597. fn new_send(from: ActorName, msg: T) -> Self {
  598. Self {
  599. kind: EnvelopeKind::Send { from },
  600. msg,
  601. }
  602. }
  603. /// Creates a new envelope containing the given message which expects exactly one reply.
  604. fn new_call(msg: T) -> (Self, oneshot::Receiver<T::Reply>) {
  605. let (tx, rx) = oneshot::channel::<T::Reply>();
  606. let envelope = Self {
  607. kind: EnvelopeKind::Call { reply: Some(tx) },
  608. msg,
  609. };
  610. (envelope, rx)
  611. }
  612. /// Returns the name of the actor which sent this message.
  613. pub fn from(&self) -> Option<&ActorName> {
  614. match &self.kind {
  615. EnvelopeKind::Send { from } => Some(from),
  616. _ => None,
  617. }
  618. }
  619. /// Returns a reference to the message in this envelope.
  620. pub fn msg(&self) -> &T {
  621. &self.msg
  622. }
  623. /// Sends a reply to this message.
  624. ///
  625. /// If this message is not expecting a reply, or if this message has already been replied to,
  626. /// then an error is returned.
  627. pub fn reply(&mut self, reply: T::Reply) -> Result<()> {
  628. match &mut self.kind {
  629. EnvelopeKind::Call { reply: tx } => {
  630. if let Some(tx) = tx.take() {
  631. tx.send(reply).map_err(|_| bterr!("Failed to send reply."))
  632. } else {
  633. Err(bterr!("Reply has already been sent."))
  634. }
  635. }
  636. _ => Err(bterr!("Can't reply to '{}' messages.", self.kind.name())),
  637. }
  638. }
  639. /// Returns true if this message expects a reply and it has not already been replied to.
  640. pub fn needs_reply(&self) -> bool {
  641. matches!(&self.kind, EnvelopeKind::Call { .. })
  642. }
  643. pub fn split(self) -> (T, EnvelopeKind<T>) {
  644. (self.msg, self.kind)
  645. }
  646. }
  647. type FutureResult = Pin<Box<dyn Send + Future<Output = Result<()>>>>;
  648. pub struct ActorHandle {
  649. handle: AbortHandle,
  650. sender: Box<dyn Send + Sync + Any>,
  651. deliverer: Box<dyn Send + Sync + Fn(WireEnvelope<'_>) -> FutureResult>,
  652. }
  653. impl ActorHandle {
  654. fn new<T, F>(handle: AbortHandle, sender: mpsc::Sender<Envelope<T>>, deliverer: F) -> Self
  655. where
  656. T: 'static + CallMsg,
  657. F: 'static + Send + Sync + Fn(WireEnvelope<'_>) -> FutureResult,
  658. {
  659. Self {
  660. handle,
  661. sender: Box::new(sender),
  662. deliverer: Box::new(deliverer),
  663. }
  664. }
  665. fn sender<T: 'static + CallMsg>(&self) -> Result<&mpsc::Sender<Envelope<T>>> {
  666. self.sender
  667. .downcast_ref::<mpsc::Sender<Envelope<T>>>()
  668. .ok_or_else(|| bterr!("Attempt to send message as the wrong type."))
  669. }
  670. /// Sends a message to the actor represented by this handle.
  671. pub async fn send<T: 'static + SendMsg>(&self, from: ActorName, msg: T) -> Result<()> {
  672. let sender = self.sender()?;
  673. sender
  674. .send(Envelope::new_send(from, msg))
  675. .await
  676. .map_err(|_| bterr!("failed to enqueue message"))?;
  677. Ok(())
  678. }
  679. pub async fn call_through<T: 'static + CallMsg>(&self, msg: T) -> Result<T::Reply> {
  680. let sender = self.sender()?;
  681. let (envelope, rx) = Envelope::new_call(msg);
  682. sender
  683. .send(envelope)
  684. .await
  685. .map_err(|_| bterr!("failed to enqueue call"))?;
  686. let reply = rx.await?;
  687. Ok(reply)
  688. }
  689. pub fn abort(&self) {
  690. self.handle.abort();
  691. }
  692. }
  693. impl Drop for ActorHandle {
  694. fn drop(&mut self) {
  695. self.abort();
  696. }
  697. }
  /// Sets up variable declarations and logging configuration to facilitate testing with a [Runtime].
  ///
  /// The expansion declares `RUNTIME_ADDR`, `RUNTIME_CREDS`, a [Runtime] named `RUNTIME`, the tokio
  /// runtime `ASYNC_RT`, and a `#[ctor]` function which configures logging when the test binary is
  /// loaded.
  ///
  /// Macros and crates are referenced by absolute paths, so the invoking module does not need to
  /// import `lazy_static!` or `declare_runtime!` itself. The invoking crate must still depend on
  /// `lazy_static`, `tokio`, `bttp`, `btlib`, `btlib_tests`, `ctor` and `env_logger`.
  #[macro_export]
  macro_rules! test_setup {
      () => {
          const RUNTIME_ADDR: ::std::net::IpAddr =
              ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1));
          ::lazy_static::lazy_static! {
              static ref RUNTIME_CREDS: ::std::sync::Arc<::btlib::crypto::ConcreteCreds> = {
                  let test_store = &::btlib_tests::TEST_STORE;
                  ::btlib::crypto::CredStore::node_creds(test_store).unwrap()
              };
          }
          $crate::declare_runtime!(RUNTIME, RUNTIME_ADDR, RUNTIME_CREDS.clone());
          ::lazy_static::lazy_static! {
              /// A tokio async runtime.
              ///
              /// When the `#[tokio::test]` attribute is used on a test, a new current thread runtime
              /// is created for each test
              /// (source: https://docs.rs/tokio/latest/tokio/attr.test.html#current-thread-runtime).
              /// This creates a problem, because the first test thread to access the `RUNTIME` static
              /// will initialize its `Receiver` in its runtime, which will stop running at the end of
              /// the test. Hence subsequent tests will not be able to send remote messages to this
              /// `Runtime`.
              ///
              /// By creating a single async runtime which is used by all of the tests, we can avoid this
              /// problem.
              static ref ASYNC_RT: ::tokio::runtime::Runtime = ::tokio::runtime::Builder
                  ::new_current_thread()
                  .enable_all()
                  .build()
                  .unwrap();
          }
          /// The log level to use when running tests.
          const LOG_LEVEL: &str = "warn";
          #[::ctor::ctor]
          #[allow(non_snake_case)]
          fn ctor() {
              ::std::env::set_var("RUST_LOG", ::std::format!("{},quinn=WARN", LOG_LEVEL));
              let mut builder = ::env_logger::Builder::from_default_env();
              ::btlib::log::BuilderExt::btformat(&mut builder).init();
          }
      };
  }
  #[cfg(test)]
  pub mod test {
      use super::*;
      use btlib::crypto::{CredStore, CredsPriv};
      use btlib_tests::TEST_STORE;
      use bttp::BlockAddr;
      use lazy_static::lazy_static;
      use serde::{Deserialize, Serialize};

      use crate::CallMsg;

      test_setup!();

      #[derive(Serialize, Deserialize)]
      struct EchoMsg(String);

      impl CallMsg for EchoMsg {
          type Reply = EchoMsg;
      }

      /// Actor which answers every call with the message it was sent. It panics if it receives a
      /// send.
      async fn echo(
          _rt: &'static Runtime,
          mut mailbox: mpsc::Receiver<Envelope<EchoMsg>>,
          _act_id: ActorId,
      ) {
          while let Some(envelope) = mailbox.recv().await {
              let (msg, kind) = envelope.split();
              match kind {
                  EnvelopeKind::Call { reply } => {
                      let replier = reply.expect("The reply has already been sent.");
                      if replier.send(msg).is_err() {
                          panic!("failed to send reply");
                      }
                  }
                  EnvelopeKind::Send { .. } => panic!("Expected EchoMsg to be a Call Message."),
              }
          }
      }

      #[test]
      fn local_call() {
          ASYNC_RT.block_on(async {
              const EXPECTED: &str = "hello";
              let name = RUNTIME.spawn(echo).await;
              let from = ActorName::new(name.path().clone(), ActorId::new());
              let reply = RUNTIME
                  .call(name.clone(), from, EchoMsg(EXPECTED.into()))
                  .await
                  .unwrap();
              assert_eq!(EXPECTED, reply.0);
              RUNTIME.take(&name).await.unwrap();
          })
      }

      /// Tests the `num_running` method.
      ///
      /// This test uses its own runtime and so can use the `#[tokio::test]` attribute.
      #[tokio::test]
      async fn num_running() {
          declare_runtime!(
              LOCAL_RT,
              // This needs to be different from the address where `RUNTIME` is listening.
              IpAddr::from([127, 0, 0, 2]),
              TEST_STORE.node_creds().unwrap()
          );
          assert_eq!(0, LOCAL_RT.num_running().await);
          let name = LOCAL_RT.spawn(echo).await;
          assert_eq!(1, LOCAL_RT.num_running().await);
          LOCAL_RT.take(&name).await.unwrap();
          assert_eq!(0, LOCAL_RT.num_running().await);
      }

      #[test]
      fn remote_call() {
          ASYNC_RT.block_on(async {
              const EXPECTED: &str = "hello";
              let actor_name = RUNTIME.spawn(echo).await;
              let bind_path = Arc::new(RUNTIME_CREDS.bind_path().unwrap());
              let block_addr = Arc::new(BlockAddr::new(RUNTIME_ADDR, bind_path));
              let transmitter = Transmitter::new(block_addr, RUNTIME_CREDS.clone())
                  .await
                  .unwrap();
              let buf = to_vec(&EchoMsg(EXPECTED.to_string())).unwrap();
              let wire_msg =
                  WireMsg::new(actor_name.clone(), RUNTIME.actor_name(ActorId::new()), &buf);
              let reply = transmitter
                  .call(wire_msg, ReplyCallback::<EchoMsg>::new())
                  .await
                  .unwrap()
                  .unwrap();
              assert_eq!(EXPECTED, reply.0);
              RUNTIME.take(&actor_name).await.unwrap();
          });
      }
  }