// scuffle_signal/bootstrap.rs

1use std::sync::Arc;
2
3use scuffle_bootstrap::global::Global;
4use scuffle_bootstrap::service::Service;
5use scuffle_context::ContextFutExt;
6
/// Bootstrap [`Service`] that listens for OS signals and runs the shutdown hooks
/// provided by [`SignalConfig`].
///
/// The first signal starts a graceful shutdown. A second signal, or the configured
/// timeout, forces the shutdown.
#[derive(Debug, Clone, Copy, Default)]
pub struct SignalSvc;
9
10pub trait SignalConfig: Global {
11    fn signals(&self) -> Vec<tokio::signal::unix::SignalKind> {
12        vec![
13            tokio::signal::unix::SignalKind::terminate(),
14            tokio::signal::unix::SignalKind::interrupt(),
15        ]
16    }
17
18    fn timeout(&self) -> Option<std::time::Duration> {
19        Some(std::time::Duration::from_secs(30))
20    }
21
22    fn on_shutdown(self: &Arc<Self>) -> impl std::future::Future<Output = anyhow::Result<()>> + Send {
23        std::future::ready(Ok(()))
24    }
25
26    fn on_force_shutdown(
27        &self,
28        signal: Option<tokio::signal::unix::SignalKind>,
29    ) -> impl std::future::Future<Output = anyhow::Result<()>> + Send {
30        let err = if let Some(signal) = signal {
31            anyhow::anyhow!("received signal, shutting down immediately: {:?}", signal)
32        } else {
33            anyhow::anyhow!("timeout reached, shutting down immediately")
34        };
35
36        std::future::ready(Err(err))
37    }
38}
39
40impl<Global: SignalConfig> Service<Global> for SignalSvc {
41    fn enabled(&self, global: &Arc<Global>) -> impl std::future::Future<Output = anyhow::Result<bool>> + Send {
42        std::future::ready(Ok(!global.signals().is_empty()))
43    }
44
45    async fn run(self, global: Arc<Global>, ctx: scuffle_context::Context) -> anyhow::Result<()> {
46        let timeout = global.timeout();
47
48        let signals = global.signals();
49        let mut handler = crate::SignalHandler::with_signals(signals);
50
51        // Wait for a signal, or for the context to be done.
52        handler.recv().with_context(&ctx).await;
53        global.on_shutdown().await?;
54        drop(ctx);
55
56        tokio::select! {
57            signal = handler.recv() => {
58                global.on_force_shutdown(Some(signal)).await?;
59            },
60            _ = scuffle_context::Handler::global().shutdown() => {}
61            Some(()) = async {
62                if let Some(timeout) = timeout {
63                    tokio::time::sleep(timeout).await;
64                    Some(())
65                } else {
66                    None
67                }
68            } => {
69                global.on_force_shutdown(None).await?;
70            },
71        };
72
73        Ok(())
74    }
75}
76
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use scuffle_bootstrap::global::GlobalWithoutConfig;
    use scuffle_bootstrap::Service;
    use scuffle_future_ext::FutureExt;
    use tokio::signal::unix::SignalKind;

    use super::{SignalConfig, SignalSvc};
    use crate::tests::raise_signal;
    use crate::SignalHandler;

    /// Runs the service with `Global`'s configuration, raises `SIGINT` twice, and
    /// expects the second signal to force the shutdown.
    async fn force_shutdown_two_signals<Global: GlobalWithoutConfig + SignalConfig>() {
        let (ctx, handler) = scuffle_context::Context::new();

        // Block the global context. Presumably this keeps
        // `Handler::global().shutdown()` pending, so the service cannot exit
        // gracefully before the second signal arrives.
        let _global_ctx = scuffle_context::Context::global();

        let svc = SignalSvc;
        let global = <Global as GlobalWithoutConfig>::init().await.unwrap();

        assert!(svc.enabled(&global).await.unwrap());
        let result = tokio::spawn(svc.run(global, ctx));

        // Wait for the service to start
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        raise_signal(SignalKind::interrupt());
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        raise_signal(SignalKind::interrupt());

        match result.with_timeout(tokio::time::Duration::from_millis(100)).await {
            Ok(Ok(Err(e))) => {
                assert_eq!(e.to_string(), "received signal, shutting down immediately: SignalKind(2)");
            }
            Ok(Ok(Ok(()))) => panic!("service shut down gracefully, expected a forced shutdown"),
            Ok(Err(e)) => panic!("service task failed: {e}"),
            Err(_) => panic!("service did not finish in time"),
        }

        assert!(handler
            .shutdown()
            .with_timeout(tokio::time::Duration::from_millis(100))
            .await
            .is_ok());
    }

    /// Global that uses every `SignalConfig` default.
    struct TestGlobal;

    impl GlobalWithoutConfig for TestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self)))
        }
    }

    impl SignalConfig for TestGlobal {}

    #[tokio::test]
    async fn default_bootstrap_service() {
        force_shutdown_two_signals::<TestGlobal>().await;
    }

    /// Global with the default signals but no shutdown timeout.
    struct NoTimeoutTestGlobal;

    impl GlobalWithoutConfig for NoTimeoutTestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self)))
        }
    }

    impl SignalConfig for NoTimeoutTestGlobal {
        fn timeout(&self) -> Option<std::time::Duration> {
            None
        }
    }

    #[tokio::test]
    async fn bootstrap_service_no_timeout() {
        let (ctx, handler) = scuffle_context::Context::new();
        let svc = SignalSvc;
        let global = NoTimeoutTestGlobal::init().await.unwrap();

        assert!(svc.enabled(&global).await.unwrap());
        let result = tokio::spawn(svc.run(global, ctx));

        // Wait for the service to start
        tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;

        raise_signal(SignalKind::interrupt());

        // With no timeout and no second signal, the service can only finish through
        // the global handler shutting down, which is a graceful `Ok(())`. Checking
        // only the join result would also accept a forced-shutdown error.
        assert!(matches!(result.await, Ok(Ok(()))));

        assert!(handler
            .shutdown()
            .with_timeout(tokio::time::Duration::from_millis(100))
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn bootstrap_service_force_shutdown() {
        force_shutdown_two_signals::<NoTimeoutTestGlobal>().await;
    }

    /// Global that listens for no signals, so the service reports itself as disabled.
    struct NoSignalsTestGlobal;

    impl GlobalWithoutConfig for NoSignalsTestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self)))
        }
    }

    impl SignalConfig for NoSignalsTestGlobal {
        fn signals(&self) -> Vec<tokio::signal::unix::SignalKind> {
            vec![]
        }

        fn timeout(&self) -> Option<std::time::Duration> {
            None
        }
    }

    #[tokio::test]
    async fn bootstrap_service_no_signals() {
        let (ctx, handler) = scuffle_context::Context::new();
        let svc = SignalSvc;
        let global = NoSignalsTestGlobal::init().await.unwrap();

        assert!(!svc.enabled(&global).await.unwrap());
        let result = tokio::spawn(svc.run(global, ctx));

        // Wait for the service to start
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        // Make a new handler to catch the raised signal, because the service is
        // expected not to catch it
        let mut signal_handler = SignalHandler::new().with_signal(SignalKind::terminate());

        raise_signal(SignalKind::terminate());

        // Wait for a signal to be received
        assert_eq!(signal_handler.recv().await, SignalKind::terminate());

        // Expected to time out
        assert!(result.with_timeout(tokio::time::Duration::from_millis(100)).await.is_err());

        assert!(handler
            .shutdown()
            .with_timeout(tokio::time::Duration::from_millis(100))
            .await
            .is_ok());
    }

    /// Global with the default signals and a very short (5 ms) shutdown timeout.
    struct SmallTimeoutTestGlobal;

    impl GlobalWithoutConfig for SmallTimeoutTestGlobal {
        fn init() -> impl std::future::Future<Output = anyhow::Result<Arc<Self>>> + Send {
            std::future::ready(Ok(Arc::new(Self)))
        }
    }

    impl SignalConfig for SmallTimeoutTestGlobal {
        fn timeout(&self) -> Option<std::time::Duration> {
            Some(std::time::Duration::from_millis(5))
        }
    }

    #[tokio::test]
    async fn bootstrap_service_timeout_force_shutdown() {
        let (ctx, handler) = scuffle_context::Context::new();

        // Block the global context so the shutdown phase can only end via the timeout.
        let _global_ctx = scuffle_context::Context::global();

        let svc = SignalSvc;
        let global = SmallTimeoutTestGlobal::init().await.unwrap();

        assert!(svc.enabled(&global).await.unwrap());
        let result = tokio::spawn(svc.run(global, ctx));

        // Wait for the service to start
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        raise_signal(SignalKind::terminate());

        match result.with_timeout(tokio::time::Duration::from_millis(100)).await {
            Ok(Ok(Err(e))) => {
                assert_eq!(e.to_string(), "timeout reached, shutting down immediately");
            }
            Ok(Ok(Ok(()))) => panic!("service shut down gracefully, expected a forced shutdown"),
            Ok(Err(e)) => panic!("service task failed: {e}"),
            Err(_) => panic!("service did not finish in time"),
        }

        assert!(handler
            .shutdown()
            .with_timeout(tokio::time::Duration::from_millis(100))
            .await
            .is_ok());
    }
}