vector_common/
shutdown.rs

1#![allow(clippy::module_name_repetitions)]
2
3use std::{
4    collections::HashMap,
5    future::Future,
6    pin::Pin,
7    sync::Arc,
8    task::{ready, Context, Poll},
9};
10
11use futures::{future, FutureExt};
12use stream_cancel::{Trigger, Tripwire};
13use tokio::time::{timeout_at, Instant};
14
15use crate::{config::ComponentKey, trigger::DisabledTrigger};
16
/// Turns the `bool` produced by a `Tripwire` into a unit future.
///
/// When `closed` is `true` the future completes on its first poll. When `closed` is `false`
/// (the tripwire's `Trigger` was disabled rather than cancelled) the future never completes,
/// so "shutdown will not happen" is modelled as "this event never fires".
pub async fn tripwire_handler(closed: bool) {
    if !closed {
        // No waker is ever registered, so this stays pending for good.
        std::future::pending::<()>().await;
    }
}
27
28/// When this struct goes out of scope and its internal refcount goes to 0 it is a signal that its
29/// corresponding `Source` has completed executing and may be cleaned up.  It is the responsibility
30/// of each `Source` to ensure that at least one copy of this handle remains alive for the entire
31/// lifetime of the Source.
32#[derive(Clone, Debug)]
33pub struct ShutdownSignalToken {
34    _shutdown_complete: Arc<Trigger>,
35}
36
37impl ShutdownSignalToken {
38    fn new(shutdown_complete: Trigger) -> Self {
39        Self {
40            _shutdown_complete: Arc::new(shutdown_complete),
41        }
42    }
43}
44
45/// Passed to each `Source` to coordinate the global shutdown process.
46#[pin_project::pin_project]
47#[derive(Clone, Debug)]
48pub struct ShutdownSignal {
49    /// This will be triggered when global shutdown has begun, and is a sign to the Source to begin
50    /// its shutdown process.
51    #[pin]
52    begin_shutdown: Option<Tripwire>,
53
54    /// When a Source allows this to go out of scope it informs the global shutdown coordinator that
55    /// this Source's local shutdown process is complete.
56    /// Optional only so that `poll()` can move the handle out and return it.
57    shutdown_complete: Option<ShutdownSignalToken>,
58}
59
60impl Future for ShutdownSignal {
61    type Output = ShutdownSignalToken;
62
63    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
64        match self.as_mut().project().begin_shutdown.as_pin_mut() {
65            Some(fut) => {
66                let closed = ready!(fut.poll(cx));
67                let mut pinned = self.project();
68                pinned.begin_shutdown.set(None);
69                if closed {
70                    Poll::Ready(pinned.shutdown_complete.take().unwrap())
71                } else {
72                    Poll::Pending
73                }
74            }
75            // TODO: This should almost certainly be a panic to avoid deadlocking in the case of a
76            // poll-after-ready situation.
77            None => Poll::Pending,
78        }
79    }
80}
81
82impl ShutdownSignal {
83    #[must_use]
84    pub fn new(tripwire: Tripwire, trigger: Trigger) -> Self {
85        Self {
86            begin_shutdown: Some(tripwire),
87            shutdown_complete: Some(ShutdownSignalToken::new(trigger)),
88        }
89    }
90
91    #[must_use]
92    pub fn noop() -> Self {
93        let (trigger, tripwire) = Tripwire::new();
94        Self {
95            begin_shutdown: Some(tripwire),
96            shutdown_complete: Some(ShutdownSignalToken::new(trigger)),
97        }
98    }
99
100    #[must_use]
101    pub fn new_wired() -> (Trigger, ShutdownSignal, Tripwire) {
102        let (trigger_shutdown, tripwire) = Tripwire::new();
103        let (trigger, shutdown_done) = Tripwire::new();
104        let shutdown = ShutdownSignal::new(tripwire, trigger);
105
106        (trigger_shutdown, shutdown, shutdown_done)
107    }
108}
109
110type IsInternal = bool;
111
112#[derive(Debug, Default)]
113pub struct SourceShutdownCoordinator {
114    begun_triggers: HashMap<ComponentKey, (IsInternal, Trigger)>,
115    force_triggers: HashMap<ComponentKey, Trigger>,
116    complete_tripwires: HashMap<ComponentKey, Tripwire>,
117}
118
119impl SourceShutdownCoordinator {
120    /// Creates the necessary Triggers and Tripwires for coordinating shutdown of this Source and
121    /// stores them as needed.  Returns the `ShutdownSignal` for this Source as well as a Tripwire
122    /// that will be notified if the Source should be forcibly shut down.
123    pub fn register_source(
124        &mut self,
125        id: &ComponentKey,
126        internal: bool,
127    ) -> (ShutdownSignal, impl Future<Output = ()>) {
128        let (shutdown_begun_trigger, shutdown_begun_tripwire) = Tripwire::new();
129        let (force_shutdown_trigger, force_shutdown_tripwire) = Tripwire::new();
130        let (shutdown_complete_trigger, shutdown_complete_tripwire) = Tripwire::new();
131
132        self.begun_triggers
133            .insert(id.clone(), (internal, shutdown_begun_trigger));
134        self.force_triggers
135            .insert(id.clone(), force_shutdown_trigger);
136        self.complete_tripwires
137            .insert(id.clone(), shutdown_complete_tripwire);
138
139        let shutdown_signal =
140            ShutdownSignal::new(shutdown_begun_tripwire, shutdown_complete_trigger);
141
142        // `force_shutdown_tripwire` resolves even if canceled when we should *not* be shutting down.
143        // `tripwire_handler` handles cancel by never resolving.
144        let force_shutdown_tripwire = force_shutdown_tripwire.then(tripwire_handler);
145        (shutdown_signal, force_shutdown_tripwire)
146    }
147
148    /// Takes ownership of all internal state for the given source from another `ShutdownCoordinator`.
149    ///
150    /// # Panics
151    ///
152    /// Panics if the other coordinator already had its triggers removed.
153    pub fn takeover_source(&mut self, id: &ComponentKey, other: &mut Self) {
154        let existing = self.begun_triggers.insert(
155            id.clone(),
156            other.begun_triggers.remove(id).unwrap_or_else(|| {
157                panic!(
158                    "Other ShutdownCoordinator didn't have a shutdown_begun_trigger for \"{id}\""
159                )
160            }),
161        );
162        assert!(
163            existing.is_none(),
164            "ShutdownCoordinator already has a shutdown_begin_trigger for source \"{id}\""
165        );
166
167        let existing = self.force_triggers.insert(
168            id.clone(),
169            other.force_triggers.remove(id).unwrap_or_else(|| {
170                panic!(
171                    "Other ShutdownCoordinator didn't have a shutdown_force_trigger for \"{id}\""
172                )
173            }),
174        );
175        assert!(
176            existing.is_none(),
177            "ShutdownCoordinator already has a shutdown_force_trigger for source \"{id}\""
178        );
179
180        let existing = self.complete_tripwires.insert(
181            id.clone(),
182            other
183                .complete_tripwires
184                .remove(id)
185                .unwrap_or_else(|| {
186                    panic!(
187                        "Other ShutdownCoordinator didn't have a shutdown_complete_tripwire for \"{id}\""
188                    )
189                }),
190        );
191        assert!(
192            existing.is_none(),
193            "ShutdownCoordinator already has a shutdown_complete_tripwire for source \"{id}\""
194        );
195    }
196
197    /// Sends a signal to begin shutting down to all sources, and returns a future that
198    /// resolves once all sources have either shut down completely, or have been sent the
199    /// force shutdown signal.  The force shutdown signal will be sent to any sources that
200    /// don't cleanly shut down before the given `deadline`.
201    ///
202    /// # Panics
203    ///
204    /// Panics if this coordinator has had its triggers removed (ie
205    /// has been taken over with `Self::takeover_source`).
206    pub fn shutdown_all(self, deadline: Option<Instant>) -> impl Future<Output = ()> {
207        let mut internal_sources_complete_futures = Vec::new();
208        let mut external_sources_complete_futures = Vec::new();
209
210        let shutdown_begun_triggers = self.begun_triggers;
211        let mut shutdown_complete_tripwires = self.complete_tripwires;
212        let mut shutdown_force_triggers = self.force_triggers;
213
214        for (id, (internal, trigger)) in shutdown_begun_triggers {
215            trigger.cancel();
216
217            let shutdown_complete_tripwire =
218                shutdown_complete_tripwires.remove(&id).unwrap_or_else(|| {
219                    panic!(
220                "shutdown_complete_tripwire for source \"{id}\" not found in the ShutdownCoordinator"
221            )
222                });
223            let shutdown_force_trigger = shutdown_force_triggers.remove(&id).unwrap_or_else(|| {
224                panic!(
225                    "shutdown_force_trigger for source \"{id}\" not found in the ShutdownCoordinator"
226                )
227            });
228
229            let source_complete = SourceShutdownCoordinator::shutdown_source_complete(
230                shutdown_complete_tripwire,
231                shutdown_force_trigger,
232                id.clone(),
233                deadline,
234            );
235
236            if internal {
237                internal_sources_complete_futures.push(source_complete);
238            } else {
239                external_sources_complete_futures.push(source_complete);
240            }
241        }
242
243        futures::future::join_all(external_sources_complete_futures)
244            .then(|_| futures::future::join_all(internal_sources_complete_futures))
245            .map(|_| ())
246    }
247
248    /// Sends the signal to the given source to begin shutting down. Returns a future that resolves
249    /// when the source has finished shutting down cleanly or been sent the force shutdown signal.
250    /// The returned future resolves to a bool that indicates if the source shut down cleanly before
251    /// the given `deadline`. If the result is false then that means the source failed to shut down
252    /// before `deadline` and had to be force-shutdown.
253    ///
254    /// # Panics
255    ///
256    /// Panics if this coordinator has had its triggers removed (ie
257    /// has been taken over with `Self::takeover_source`).
258    pub fn shutdown_source(
259        &mut self,
260        id: &ComponentKey,
261        deadline: Instant,
262    ) -> impl Future<Output = bool> {
263        let (_, begin_shutdown_trigger) = self.begun_triggers.remove(id).unwrap_or_else(|| {
264            panic!(
265                "shutdown_begun_trigger for source \"{id}\" not found in the ShutdownCoordinator"
266            )
267        });
268        // This is what actually triggers the source to begin shutting down.
269        begin_shutdown_trigger.cancel();
270
271        let shutdown_complete_tripwire = self
272            .complete_tripwires
273            .remove(id)
274            .unwrap_or_else(|| {
275                panic!(
276                "shutdown_complete_tripwire for source \"{id}\" not found in the ShutdownCoordinator"
277            )
278            });
279        let shutdown_force_trigger = self.force_triggers.remove(id).unwrap_or_else(|| {
280            panic!(
281                "shutdown_force_trigger for source \"{id}\" not found in the ShutdownCoordinator"
282            )
283        });
284        SourceShutdownCoordinator::shutdown_source_complete(
285            shutdown_complete_tripwire,
286            shutdown_force_trigger,
287            id.clone(),
288            Some(deadline),
289        )
290    }
291
292    /// Returned future will finish once all *current* sources have finished.
293    #[must_use]
294    pub fn shutdown_tripwire(&self) -> future::BoxFuture<'static, ()> {
295        let futures = self
296            .complete_tripwires
297            .values()
298            .cloned()
299            .map(|tripwire| tripwire.then(tripwire_handler).boxed());
300
301        future::join_all(futures)
302            .map(|_| info!("All sources have finished."))
303            .boxed()
304    }
305
306    fn shutdown_source_complete(
307        shutdown_complete_tripwire: Tripwire,
308        shutdown_force_trigger: Trigger,
309        id: ComponentKey,
310        deadline: Option<Instant>,
311    ) -> impl Future<Output = bool> {
312        async move {
313            let fut = shutdown_complete_tripwire.then(tripwire_handler);
314            if let Some(deadline) = deadline {
315                // Call `shutdown_force_trigger.disable()` on drop.
316                let shutdown_force_trigger = DisabledTrigger::new(shutdown_force_trigger);
317                if timeout_at(deadline, fut).await.is_ok() {
318                    shutdown_force_trigger.into_inner().disable();
319                    true
320                } else {
321                    error!(
322                        "Source '{}' failed to shutdown before deadline. Forcing shutdown.",
323                        id,
324                    );
325                    shutdown_force_trigger.into_inner().cancel();
326                    false
327                }
328            } else {
329                fut.await;
330                true
331            }
332        }
333        .boxed()
334    }
335}
336
#[cfg(test)]
mod test {
    use tokio::time::{Duration, Instant};

    use super::*;

    /// Deadline used by both tests: one second from now.
    fn one_second_from_now() -> Instant {
        Instant::now() + Duration::from_secs(1)
    }

    #[tokio::test]
    async fn shutdown_coordinator_shutdown_source_clean() {
        let key = ComponentKey::from("test");
        let mut coordinator = SourceShutdownCoordinator::default();

        let (signal, _) = coordinator.register_source(&key, false);
        let completion = coordinator.shutdown_source(&key, one_second_from_now());

        // Dropping the signal releases the only token, which reports a clean shutdown.
        drop(signal);

        assert!(completion.await);
    }

    #[tokio::test]
    async fn shutdown_coordinator_shutdown_source_force() {
        let key = ComponentKey::from("test");
        let mut coordinator = SourceShutdownCoordinator::default();

        let (_signal, force) = coordinator.register_source(&key, false);
        let completion = coordinator.shutdown_source(&key, one_second_from_now());

        // The signal is kept alive, so the coordinator sees a source that is still running and
        // must force it down once the deadline passes.
        let clean = completion.await;
        assert!(!clean);

        // A single poll must observe the force-shutdown signal.
        assert_eq!(force.now_or_never(), Some(()));
    }
}