vector/components/validation/
sync.rs

1use std::sync::{
2    atomic::{AtomicUsize, Ordering},
3    Arc, Mutex,
4};
5
6use tokio::sync::{oneshot, Notify};
7
8struct WaitGroupState {
9    registered: AtomicUsize,
10    done: AtomicUsize,
11    notify: Notify,
12}
13
14impl WaitGroupState {
15    fn all_children_done(&self) -> bool {
16        self.registered.load(Ordering::Acquire) == self.done.load(Ordering::Acquire)
17    }
18}
19/// A synchronization primitive for waiting for an arbitrary number of processes to rendezvous.
20pub struct WaitGroup {
21    locked: bool,
22    state: Arc<WaitGroupState>,
23}
24
25pub struct WaitGroupChild {
26    done: bool,
27    state: Arc<WaitGroupState>,
28}
29
30impl WaitGroup {
31    /// Creates a new `WaitGroup`.
32    pub fn new() -> Self {
33        Self {
34            locked: false,
35            state: Arc::new(WaitGroupState {
36                registered: AtomicUsize::new(0),
37                done: AtomicUsize::new(0),
38                notify: Notify::new(),
39            }),
40        }
41    }
42
43    /// Creates and attaches a new child to this wait group.
44    ///
45    /// ## Panics
46    ///
47    /// If the caller attempts to add a child after calling `wait_for_children` at least once, this
48    /// method will panic.
49    pub fn add_child(&self) -> WaitGroupChild {
50        if self.locked {
51            panic!("tried to add child after wait group locked");
52        }
53
54        WaitGroupChild::from_state(&self.state)
55    }
56
57    /// Waits until all children have marked themselves as done.
58    ///
59    /// If no children were added to the wait group, or all of them have already completed, this
60    /// function returns immediately.
61    pub async fn wait_for_children(&mut self) {
62        // We "lock" ourselves because, if we did not, then _technically_ we can't be sure that caller
63        // hasn't called this method multiple times, after a new child being added in between...
64        // which messes up the invariant that once we start waiting, nothing else should be added.
65        //
66        // It's easier to do that internally, and panic if `add_child` is called after the first
67        // call to `wait_for_children`, rather than deal with having to make this future
68        // cancellation safe some other way.
69        if !self.locked {
70            self.locked = true;
71        }
72
73        while !self.state.all_children_done() {
74            self.state.notify.notified().await;
75        }
76    }
77}
78
79impl WaitGroupChild {
80    fn from_state(state: &Arc<WaitGroupState>) -> Self {
81        state.registered.fetch_add(1, Ordering::AcqRel);
82
83        Self {
84            done: false,
85            state: Arc::clone(state),
86        }
87    }
88
89    /// Marks this child as done.
90    ///
91    /// If the wait group has been finalized and is waiting for all children to be marked as done,
92    /// and this is the last outstanding child to be marked as done, the wait group will be notified.
93    pub fn mark_as_done(mut self) {
94        self.done = true;
95
96        self.state.done.fetch_add(1, Ordering::SeqCst);
97        if self.state.all_children_done() {
98            self.state.notify.notify_one();
99        }
100    }
101}
102
103impl Drop for WaitGroupChild {
104    fn drop(&mut self) {
105        if !self.done {
106            panic!("wait group child dropped without being marked as done");
107        }
108    }
109}
110
111pub struct WaitTrigger {
112    tx: oneshot::Sender<()>,
113}
114
115pub struct WaitHandle {
116    rx: Option<oneshot::Receiver<()>>,
117}
118
119impl WaitTrigger {
120    /// Creates a new waiter pair.
121    pub fn new() -> (Self, WaitHandle) {
122        let (tx, rx) = oneshot::channel();
123
124        let trigger = Self { tx };
125        let handle = WaitHandle { rx: Some(rx) };
126
127        (trigger, handle)
128    }
129
130    /// Triggers the wait handle to wake up.
131    pub fn trigger(self) {
132        // We don't care if our trigger is actually received, because the receiver side may
133        // intentionally not be used i.e. if the code is generic in a way where only some codepaths
134        // wait to be triggered and others don't, but the trigger must always be called regardless.
135        _ = self.tx.send(());
136    }
137}
138
139impl WaitHandle {
140    /// Waits until triggered.
141    pub async fn wait(&mut self) {
142        match self.rx.as_mut() {
143            Some(rx) => rx
144                .await
145                .expect("paired task no longer holding wait trigger"),
146            None => panic!("tried to await wait trigger signal but has already been received"),
147        }
148
149        // If we're here, we've successfully received the signal, so we consume the
150        // receiver, as it cannot be used/polled again.
151        self.rx.take();
152    }
153}
154
155pub struct Configuring {
156    tasks_started: WaitGroup,
157    tasks_completed: WaitGroup,
158    shutdown_triggers: Mutex<Vec<WaitTrigger>>,
159}
160
161pub struct Started {
162    tasks_completed: Option<WaitGroup>,
163    shutdown_triggers: Vec<WaitTrigger>,
164}
165
/// Coordination primitive for external tasks.
///
/// When validating a component, external resources may be spun up either to provide inputs to the
/// component or to collect its outputs, and other tasks may be spawned to forward data between
/// parts of the topology. The validation runner must be able to confirm that these tasks have
/// started, and later that they have completed, at different stages of the run. This ensures all
/// inputs have been processed and all outputs have been received.
///
/// The coordinator encodes a state machine in its type parameter. Once it has been configured
/// (tasks registered), it can only be driven forward: first waiting for every task to start, then
/// signalling every task to shut down and waiting for them to do so.
///
/// This makes bugs such as registering tasks after already waiting for all tasks to start much
/// harder to write.
pub struct TaskCoordinator<State> {
    /// Current phase of the coordinator (e.g. `Configuring` or `Started`).
    state: State,
    /// Label used as a prefix in the coordinator's log messages.
    name: String,
}
185
186impl TaskCoordinator<()> {
187    /// Creates a new `TaskCoordinator`.
188    pub fn new(name: &str) -> TaskCoordinator<Configuring> {
189        TaskCoordinator {
190            state: Configuring {
191                tasks_started: WaitGroup::new(),
192                tasks_completed: WaitGroup::new(),
193                shutdown_triggers: Mutex::new(Vec::new()),
194            },
195            name: name.to_string(),
196        }
197    }
198}
199
200impl TaskCoordinator<Configuring> {
201    /// Attaches a new child to the wait group that tracks when tasks have started.
202    pub fn track_started(&self) -> WaitGroupChild {
203        self.state.tasks_started.add_child()
204    }
205
206    /// Attaches a new child to the wait group that tracks when tasks have completed.
207    pub fn track_completed(&self) -> WaitGroupChild {
208        self.state.tasks_completed.add_child()
209    }
210
211    /// Registers a handle that will be notified when shutdown is triggered.
212    pub fn register_for_shutdown(&self) -> WaitHandle {
213        let (trigger, handle) = WaitTrigger::new();
214        self.state
215            .shutdown_triggers
216            .lock()
217            .expect("poisoned")
218            .push(trigger);
219        handle
220    }
221
222    /// Waits for all tasks to have marked that they have started.
223    pub async fn started(self) -> TaskCoordinator<Started> {
224        let Configuring {
225            mut tasks_started,
226            tasks_completed,
227            shutdown_triggers,
228        } = self.state;
229
230        tasks_started.wait_for_children().await;
231        trace!("All coordinated tasks reported as having started.");
232
233        TaskCoordinator {
234            state: Started {
235                tasks_completed: Some(tasks_completed),
236                shutdown_triggers: shutdown_triggers.into_inner().expect("poisoned"),
237            },
238            name: self.name,
239        }
240    }
241}
242
243impl TaskCoordinator<Started> {
244    /// Triggers all coordinated tasks to shutdown, and waits for them to mark themselves as completed.
245    pub async fn shutdown(&mut self) {
246        info!("{}: triggering task to shutdown.", self.name);
247
248        // Trigger all registered shutdown handles.
249        for trigger in self.state.shutdown_triggers.drain(..) {
250            trigger.trigger();
251            debug!("{}: shutdown triggered for coordinated tasks.", self.name);
252        }
253
254        // Now simply wait for all of them to mark themselves as completed.
255        debug!(
256            "{}: waiting for coordinated tasks to complete...",
257            self.name
258        );
259        let tasks_completed = self
260            .state
261            .tasks_completed
262            .as_mut()
263            .expect("tasks completed wait group already consumed");
264        tasks_completed.wait_for_children().await;
265
266        info!("{}: task has been shutdown.", self.name);
267    }
268}