Kaynağa Gözat

Added a return value to client trait methods.

Matthew Carr 1 yıl önce
ebeveyn
işleme
3506942bcd

+ 17 - 1
crates/btproto/src/model.rs

@@ -668,7 +668,11 @@ impl MethodModel {
             let kind = ValueKind::State { def: state.clone() };
             outputs.push(ValueModel::new(kind, type_prefix));
         }
-        if !part_of_client {
+        if part_of_client {
+            if def.in_msg().is_none() {
+                outputs.push(ValueModel::new(ValueKind::Return, type_prefix));
+            }
+        } else {
             for dest in def.out_msgs.as_ref().iter() {
                 let kind = ValueKind::new_dest(dest.clone(), messages, part_of_client);
                 outputs.push(ValueModel::new(kind, type_prefix));
@@ -783,6 +787,13 @@ impl ValueModel {
                 let assoc_type = None;
                 (decl, type_name, assoc_type)
             }
+            ValueKind::Return { .. } => {
+                let assoc_type_ident = format_ident!("{type_prefix}Return");
+                let decl = Some(quote! { type #assoc_type_ident; });
+                let type_name = Some(quote! { Self::#assoc_type_ident });
+                let assoc_type = Some(assoc_type_ident);
+                (decl, type_name, assoc_type)
+            }
         };
         Self {
             var_name: kind.var_name(),
@@ -870,6 +881,7 @@ impl ValueModel {
                     quote! {}
                 }
             }
+            ValueKind::Return => quote! {},
         }
     }
 }
@@ -888,6 +900,8 @@ pub(crate) enum ValueKind {
         reply_type: Option<Rc<TokenStream>>,
         part_of_client: bool,
     },
+    /// Represents the return value of a client handle.
+    Return,
 }
 
 impl ValueKind {
@@ -896,6 +910,7 @@ impl ValueKind {
             Self::Msg { def, .. } => def.msg_type.as_ref(),
             Self::State { def, .. } => def.state_trait.as_ref(),
             Self::Dest { def, .. } => def.state.state_ref().state_trait.as_ref(),
+            Self::Return { .. } => return format_ident!("return_var"),
         };
         format_ident!("{}_var", ident.pascal_to_snake())
     }
@@ -917,6 +932,7 @@ impl ValueKind {
             Self::Msg { def, .. } => def.span(),
             Self::State { def, .. } => def.span(),
             Self::Dest { def, .. } => def.span(),
+            Self::Return { .. } => Span::call_site(),
         }
     }
 

+ 9 - 6
crates/btproto/tests/protocol_tests.rs

@@ -58,9 +58,10 @@ fn minimal_syntax() {
     impl Client for ClientImpl {
         actor_name!("minimal_client");
 
-        type OnSendMsgFut = Ready<TransResult<Self, End>>;
+        type OnSendMsgReturn = usize;
+        type OnSendMsgFut = Ready<TransResult<Self, (End, usize)>>;
         fn on_send_msg(self) -> Self::OnSendMsgFut {
-            ready(TransResult::Ok(End))
+            ready(TransResult::Ok((End, 42)))
         }
     }
 }
@@ -99,10 +100,11 @@ fn reply() {
     impl Client for ClientImpl {
         actor_name!("reply_client");
 
+        type OnSendPingReturn = ();
         type OnSendPingClient = Self;
-        type OnSendPingFut = Ready<TransResult<Self, Self>>;
+        type OnSendPingFut = Ready<TransResult<Self, (Self, ())>>;
         fn on_send_ping(self, _ping: <Ping as CallMsg>::Reply) -> Self::OnSendPingFut {
-            ready(TransResult::Ok(self))
+            ready(TransResult::Ok((self, ())))
         }
     }
 }
@@ -145,10 +147,11 @@ fn client_callback() {
     impl Unregistered for UnregisteredImpl {
         actor_name!("callback_client");
 
+        type OnSendRegisterReturn = ();
         type OnSendRegisterRegistered = RegisteredImpl;
-        type OnSendRegisterFut = Ready<TransResult<Self, Self::OnSendRegisterRegistered>>;
+        type OnSendRegisterFut = Ready<TransResult<Self, (Self::OnSendRegisterRegistered, ())>>;
         fn on_send_register(self) -> Self::OnSendRegisterFut {
-            ready(TransResult::Ok(RegisteredImpl))
+            ready(TransResult::Ok((RegisteredImpl, ())))
         }
     }
 

+ 57 - 40
crates/btrun/tests/runtime_tests.rs

@@ -69,7 +69,7 @@ mod ping_pong {
             mut self,
             msg: Ping,
             service: ServiceAddr,
-        ) -> TransResult<Self, ClientHandleManual<T>> {
+        ) -> TransResult<Self, (ClientHandleManual<T>, T::OnSendPingReturn)> {
             let state = if let Some(state) = self.state.take() {
                 state
             } else {
@@ -97,9 +97,9 @@ mod ping_pong {
                     };
                     if let PingProtocolMsgs::PingReply(reply) = reply_enum {
                         match state.on_send_ping(reply).await {
-                            TransResult::Ok(new_state) => {
+                            TransResult::Ok((new_state, return_var)) => {
                                 self.state = Some(PingClientState::End(new_state));
-                                TransResult::Ok(self)
+                                TransResult::Ok((self, return_var))
                             }
                             TransResult::Abort { from, err } => {
                                 self.state = Some(PingClientState::Client(from));
@@ -280,10 +280,11 @@ mod ping_pong {
     impl Client for ClientImpl {
         actor_name!("ping_client");
 
-        type OnSendPingFut = impl Future<Output = TransResult<Self, End>>;
+        type OnSendPingReturn = ();
+        type OnSendPingFut = impl Future<Output = TransResult<Self, (End, ())>>;
         fn on_send_ping(self, _msg: PingReply) -> Self::OnSendPingFut {
             self.counter.fetch_sub(1, Ordering::SeqCst);
-            ready(TransResult::Ok(End))
+            ready(TransResult::Ok((End, ())))
         }
     }
 
@@ -414,12 +415,16 @@ mod client_callback {
     impl Unregistered for UnregisteredState {
         actor_name!("callback_client");
 
+        type OnSendRegisterReturn = ();
         type OnSendRegisterRegistered = RegisteredState;
-        type OnSendRegisterFut = Ready<TransResult<Self, Self::OnSendRegisterRegistered>>;
+        type OnSendRegisterFut = Ready<TransResult<Self, (Self::OnSendRegisterRegistered, ())>>;
         fn on_send_register(self) -> Self::OnSendRegisterFut {
-            ready(TransResult::Ok(RegisteredState {
-                sender: self.sender,
-            }))
+            ready(TransResult::Ok((
+                RegisteredState {
+                    sender: self.sender,
+                },
+                (),
+            )))
         }
     }
 
@@ -522,40 +527,52 @@ mod client_callback {
             self,
             to: ServiceAddr,
             msg: Register,
-        ) -> Result<ClientHandleManual<Init, NewState>> {
-            {
-                let mut guard = self.state.lock().await;
-                let state = guard
-                    .take()
-                    .unwrap_or_else(|| panic!("Logic error. The state was not returned."));
-                let new_state = match state {
-                    ClientStateManual::Unregistered(state) => {
-                        match state.on_send_register().await {
-                            TransResult::Ok(new_state) => {
-                                let msg = ClientCallbackMsgs::Register(msg);
-                                self.runtime
-                                    .send_service(to, self.name.clone(), msg)
-                                    .await?;
-                                ClientStateManual::Registered(new_state)
-                            }
-                            TransResult::Abort { from, err } => {
-                                log::warn!(
-                                    "Aborted transition from the {} state: {}",
-                                    "Unregistered",
-                                    err
-                                );
-                                ClientStateManual::Unregistered(from)
-                            }
-                            TransResult::Fatal { err } => {
-                                return Err(err);
-                            }
+        ) -> TransResult<
+            Self,
+            (
+                ClientHandleManual<Init, NewState>,
+                Init::OnSendRegisterReturn,
+            ),
+        > {
+            let mut guard = self.state.lock().await;
+            let state = guard
+                .take()
+                .unwrap_or_else(|| panic!("Logic error. The state was not returned."));
+            match state {
+                ClientStateManual::Unregistered(state) => match state.on_send_register().await {
+                    TransResult::Ok((new_state, return_var)) => {
+                        let msg = ClientCallbackMsgs::Register(msg);
+                        let result = self.runtime.send_service(to, self.name.clone(), msg).await;
+                        if let Err(err) = result {
+                            return TransResult::Fatal { err };
                         }
+                        *guard = Some(ClientStateManual::Registered(new_state));
+                        drop(guard);
+                        TransResult::Ok((self.new_type(), return_var))
                     }
-                    state => state,
-                };
-                *guard = Some(new_state);
+                    TransResult::Abort { from, err } => {
+                        *guard = Some(ClientStateManual::Unregistered(from));
+                        drop(guard);
+                        return TransResult::Abort { from: self, err };
+                    }
+                    TransResult::Fatal { err } => {
+                        return TransResult::Fatal { err };
+                    }
+                },
+                state => {
+                    let name = state.name();
+                    *guard = Some(state);
+                    drop(guard);
+                    TransResult::Abort {
+                        from: self,
+                        err: bterr!(
+                            "Unexpected state '{}' for '{}' method.",
+                            name,
+                            "send_register"
+                        ),
+                    }
+                }
             }
-            Ok(self.new_type())
         }
     }