1#![allow(clippy::module_name_repetitions)]
2
3use std::{
4 collections::HashMap,
5 future::Future,
6 pin::Pin,
7 sync::Arc,
8 task::{ready, Context, Poll},
9};
10
11use futures::{future, FutureExt};
12use stream_cancel::{Trigger, Tripwire};
13use tokio::time::{timeout_at, Instant};
14
15use crate::{config::ComponentKey, trigger::DisabledTrigger};
16
/// Adapter for `FutureExt::then` on a `Tripwire`.
///
/// Completes immediately when the tripwire resolved with `closed == true`; otherwise the
/// returned future never completes, so a tripwire that resolved with `false` is treated as
/// "this will never happen".
pub async fn tripwire_handler(closed: bool) {
    if closed {
        return;
    }
    // Not closed: park this future forever.
    std::future::pending::<()>().await;
}
27
28#[derive(Clone, Debug)]
33pub struct ShutdownSignalToken {
34 _shutdown_complete: Arc<Trigger>,
35}
36
37impl ShutdownSignalToken {
38 fn new(shutdown_complete: Trigger) -> Self {
39 Self {
40 _shutdown_complete: Arc::new(shutdown_complete),
41 }
42 }
43}
44
/// Handed to a source so it can learn when shutdown has been requested.
///
/// Awaiting it (see the `Future` impl) yields a [`ShutdownSignalToken`] once the
/// begin-shutdown `Tripwire` resolves with `true`. The source should keep that token alive
/// until it has finished shutting down, because releasing it signals completion.
#[pin_project::pin_project]
#[derive(Clone, Debug)]
pub struct ShutdownSignal {
    /// Fires when shutdown begins. Set to `None` by `poll` once it has resolved, so it is
    /// never polled again.
    #[pin]
    begin_shutdown: Option<Tripwire>,

    /// Token handed out (via `take`) when shutdown begins. While it is held here, dropping
    /// the signal also drops the token.
    shutdown_complete: Option<ShutdownSignalToken>,
}
59
/// Resolves with this signal's [`ShutdownSignalToken`] once the begin-shutdown `Tripwire`
/// resolves with `true`.
///
/// If the tripwire resolves with `false` (per `stream_cancel`: its `Trigger` was disabled
/// rather than cancelled), the signal stays pending forever.
impl Future for ShutdownSignal {
    type Output = ShutdownSignalToken;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.as_mut().project().begin_shutdown.as_pin_mut() {
            // Still waiting for shutdown to begin.
            Some(fut) => {
                let closed = ready!(fut.poll(cx));
                let mut pinned = self.project();
                // The tripwire has resolved; clear it so it is never polled again.
                pinned.begin_shutdown.set(None);
                if closed {
                    // The constructors visible here set `shutdown_complete` to `Some` together
                    // with `begin_shutdown`, and both are cleared together in this branch, so
                    // this `unwrap` should not fail.
                    Poll::Ready(pinned.shutdown_complete.take().unwrap())
                } else {
                    // Shutdown will never be requested: never resolve. No waker is
                    // registered because there is nothing left to wait for.
                    Poll::Pending
                }
            }
            // Already resolved, or the tripwire resolved with `false`.
            // NOTE(review): polling again after `Ready` returns `Pending` forever instead of
            // panicking, so a caller that re-polls a completed signal will hang silently.
            None => Poll::Pending,
        }
    }
}
81
82impl ShutdownSignal {
83 #[must_use]
84 pub fn new(tripwire: Tripwire, trigger: Trigger) -> Self {
85 Self {
86 begin_shutdown: Some(tripwire),
87 shutdown_complete: Some(ShutdownSignalToken::new(trigger)),
88 }
89 }
90
91 #[must_use]
92 pub fn noop() -> Self {
93 let (trigger, tripwire) = Tripwire::new();
94 Self {
95 begin_shutdown: Some(tripwire),
96 shutdown_complete: Some(ShutdownSignalToken::new(trigger)),
97 }
98 }
99
100 #[must_use]
101 pub fn new_wired() -> (Trigger, ShutdownSignal, Tripwire) {
102 let (trigger_shutdown, tripwire) = Tripwire::new();
103 let (trigger, shutdown_done) = Tripwire::new();
104 let shutdown = ShutdownSignal::new(tripwire, trigger);
105
106 (trigger_shutdown, shutdown, shutdown_done)
107 }
108}
109
/// Whether a registered source is internal.
///
/// `shutdown_all` waits on internal sources only after every external source has finished.
type IsInternal = bool;

/// Tracks, per source `ComponentKey`, the handles used to begin, force, and await that
/// source's shutdown.
#[derive(Debug, Default)]
pub struct SourceShutdownCoordinator {
    /// Cancelled to tell a source to begin shutting down, paired with its internal flag.
    begun_triggers: HashMap<ComponentKey, (IsInternal, Trigger)>,
    /// Cancelled to force a source to stop when its shutdown deadline passes.
    force_triggers: HashMap<ComponentKey, Trigger>,
    /// Resolves once the source's `ShutdownSignal` and every token obtained from it are
    /// dropped.
    complete_tripwires: HashMap<ComponentKey, Tripwire>,
}
118
119impl SourceShutdownCoordinator {
120 pub fn register_source(
124 &mut self,
125 id: &ComponentKey,
126 internal: bool,
127 ) -> (ShutdownSignal, impl Future<Output = ()>) {
128 let (shutdown_begun_trigger, shutdown_begun_tripwire) = Tripwire::new();
129 let (force_shutdown_trigger, force_shutdown_tripwire) = Tripwire::new();
130 let (shutdown_complete_trigger, shutdown_complete_tripwire) = Tripwire::new();
131
132 self.begun_triggers
133 .insert(id.clone(), (internal, shutdown_begun_trigger));
134 self.force_triggers
135 .insert(id.clone(), force_shutdown_trigger);
136 self.complete_tripwires
137 .insert(id.clone(), shutdown_complete_tripwire);
138
139 let shutdown_signal =
140 ShutdownSignal::new(shutdown_begun_tripwire, shutdown_complete_trigger);
141
142 let force_shutdown_tripwire = force_shutdown_tripwire.then(tripwire_handler);
145 (shutdown_signal, force_shutdown_tripwire)
146 }
147
148 pub fn takeover_source(&mut self, id: &ComponentKey, other: &mut Self) {
154 let existing = self.begun_triggers.insert(
155 id.clone(),
156 other.begun_triggers.remove(id).unwrap_or_else(|| {
157 panic!(
158 "Other ShutdownCoordinator didn't have a shutdown_begun_trigger for \"{id}\""
159 )
160 }),
161 );
162 assert!(
163 existing.is_none(),
164 "ShutdownCoordinator already has a shutdown_begin_trigger for source \"{id}\""
165 );
166
167 let existing = self.force_triggers.insert(
168 id.clone(),
169 other.force_triggers.remove(id).unwrap_or_else(|| {
170 panic!(
171 "Other ShutdownCoordinator didn't have a shutdown_force_trigger for \"{id}\""
172 )
173 }),
174 );
175 assert!(
176 existing.is_none(),
177 "ShutdownCoordinator already has a shutdown_force_trigger for source \"{id}\""
178 );
179
180 let existing = self.complete_tripwires.insert(
181 id.clone(),
182 other
183 .complete_tripwires
184 .remove(id)
185 .unwrap_or_else(|| {
186 panic!(
187 "Other ShutdownCoordinator didn't have a shutdown_complete_tripwire for \"{id}\""
188 )
189 }),
190 );
191 assert!(
192 existing.is_none(),
193 "ShutdownCoordinator already has a shutdown_complete_tripwire for source \"{id}\""
194 );
195 }
196
197 pub fn shutdown_all(self, deadline: Option<Instant>) -> impl Future<Output = ()> {
207 let mut internal_sources_complete_futures = Vec::new();
208 let mut external_sources_complete_futures = Vec::new();
209
210 let shutdown_begun_triggers = self.begun_triggers;
211 let mut shutdown_complete_tripwires = self.complete_tripwires;
212 let mut shutdown_force_triggers = self.force_triggers;
213
214 for (id, (internal, trigger)) in shutdown_begun_triggers {
215 trigger.cancel();
216
217 let shutdown_complete_tripwire =
218 shutdown_complete_tripwires.remove(&id).unwrap_or_else(|| {
219 panic!(
220 "shutdown_complete_tripwire for source \"{id}\" not found in the ShutdownCoordinator"
221 )
222 });
223 let shutdown_force_trigger = shutdown_force_triggers.remove(&id).unwrap_or_else(|| {
224 panic!(
225 "shutdown_force_trigger for source \"{id}\" not found in the ShutdownCoordinator"
226 )
227 });
228
229 let source_complete = SourceShutdownCoordinator::shutdown_source_complete(
230 shutdown_complete_tripwire,
231 shutdown_force_trigger,
232 id.clone(),
233 deadline,
234 );
235
236 if internal {
237 internal_sources_complete_futures.push(source_complete);
238 } else {
239 external_sources_complete_futures.push(source_complete);
240 }
241 }
242
243 futures::future::join_all(external_sources_complete_futures)
244 .then(|_| futures::future::join_all(internal_sources_complete_futures))
245 .map(|_| ())
246 }
247
248 pub fn shutdown_source(
259 &mut self,
260 id: &ComponentKey,
261 deadline: Instant,
262 ) -> impl Future<Output = bool> {
263 let (_, begin_shutdown_trigger) = self.begun_triggers.remove(id).unwrap_or_else(|| {
264 panic!(
265 "shutdown_begun_trigger for source \"{id}\" not found in the ShutdownCoordinator"
266 )
267 });
268 begin_shutdown_trigger.cancel();
270
271 let shutdown_complete_tripwire = self
272 .complete_tripwires
273 .remove(id)
274 .unwrap_or_else(|| {
275 panic!(
276 "shutdown_complete_tripwire for source \"{id}\" not found in the ShutdownCoordinator"
277 )
278 });
279 let shutdown_force_trigger = self.force_triggers.remove(id).unwrap_or_else(|| {
280 panic!(
281 "shutdown_force_trigger for source \"{id}\" not found in the ShutdownCoordinator"
282 )
283 });
284 SourceShutdownCoordinator::shutdown_source_complete(
285 shutdown_complete_tripwire,
286 shutdown_force_trigger,
287 id.clone(),
288 Some(deadline),
289 )
290 }
291
292 #[must_use]
294 pub fn shutdown_tripwire(&self) -> future::BoxFuture<'static, ()> {
295 let futures = self
296 .complete_tripwires
297 .values()
298 .cloned()
299 .map(|tripwire| tripwire.then(tripwire_handler).boxed());
300
301 future::join_all(futures)
302 .map(|_| info!("All sources have finished."))
303 .boxed()
304 }
305
306 fn shutdown_source_complete(
307 shutdown_complete_tripwire: Tripwire,
308 shutdown_force_trigger: Trigger,
309 id: ComponentKey,
310 deadline: Option<Instant>,
311 ) -> impl Future<Output = bool> {
312 async move {
313 let fut = shutdown_complete_tripwire.then(tripwire_handler);
314 if let Some(deadline) = deadline {
315 let shutdown_force_trigger = DisabledTrigger::new(shutdown_force_trigger);
317 if timeout_at(deadline, fut).await.is_ok() {
318 shutdown_force_trigger.into_inner().disable();
319 true
320 } else {
321 error!(
322 "Source '{}' failed to shutdown before deadline. Forcing shutdown.",
323 id,
324 );
325 shutdown_force_trigger.into_inner().cancel();
326 false
327 }
328 } else {
329 fut.await;
330 true
331 }
332 }
333 .boxed()
334 }
335}
336
#[cfg(test)]
mod test {
    use tokio::time::{Duration, Instant};

    use super::*;
    use crate::shutdown::SourceShutdownCoordinator;

    /// Deadline shared by the tests below: one second from now.
    fn one_second_from_now() -> Instant {
        Instant::now() + Duration::from_secs(1)
    }

    /// A source that releases its signal before the deadline counts as a clean shutdown.
    #[tokio::test]
    async fn shutdown_coordinator_shutdown_source_clean() {
        let key = ComponentKey::from("test");
        let mut coordinator = SourceShutdownCoordinator::default();

        let (signal, _) = coordinator.register_source(&key, false);
        let completion = coordinator.shutdown_source(&key, one_second_from_now());

        // Dropping the signal releases the completion token it holds.
        drop(signal);

        let completed_cleanly = completion.await;
        assert!(completed_cleanly);
    }

    /// A source that holds on to its signal past the deadline is forced to stop.
    #[tokio::test]
    async fn shutdown_coordinator_shutdown_source_force() {
        let key = ComponentKey::from("test");
        let mut coordinator = SourceShutdownCoordinator::default();

        let (_signal, force_tripwire) = coordinator.register_source(&key, false);
        let completion = coordinator.shutdown_source(&key, one_second_from_now());

        let completed_cleanly = completion.await;
        assert!(!completed_cleanly);

        let forced = futures::poll!(force_tripwire.boxed());
        assert_eq!(forced, Poll::Ready(()));
    }
}
378}