vector/components/validation/
sync.rs1use std::sync::{
2 atomic::{AtomicUsize, Ordering},
3 Arc, Mutex,
4};
5
6use tokio::sync::{oneshot, Notify};
7
8struct WaitGroupState {
9 registered: AtomicUsize,
10 done: AtomicUsize,
11 notify: Notify,
12}
13
14impl WaitGroupState {
15 fn all_children_done(&self) -> bool {
16 self.registered.load(Ordering::Acquire) == self.done.load(Ordering::Acquire)
17 }
18}
19pub struct WaitGroup {
21 locked: bool,
22 state: Arc<WaitGroupState>,
23}
24
25pub struct WaitGroupChild {
26 done: bool,
27 state: Arc<WaitGroupState>,
28}
29
30impl WaitGroup {
31 pub fn new() -> Self {
33 Self {
34 locked: false,
35 state: Arc::new(WaitGroupState {
36 registered: AtomicUsize::new(0),
37 done: AtomicUsize::new(0),
38 notify: Notify::new(),
39 }),
40 }
41 }
42
43 pub fn add_child(&self) -> WaitGroupChild {
50 if self.locked {
51 panic!("tried to add child after wait group locked");
52 }
53
54 WaitGroupChild::from_state(&self.state)
55 }
56
57 pub async fn wait_for_children(&mut self) {
62 if !self.locked {
70 self.locked = true;
71 }
72
73 while !self.state.all_children_done() {
74 self.state.notify.notified().await;
75 }
76 }
77}
78
79impl WaitGroupChild {
80 fn from_state(state: &Arc<WaitGroupState>) -> Self {
81 state.registered.fetch_add(1, Ordering::AcqRel);
82
83 Self {
84 done: false,
85 state: Arc::clone(state),
86 }
87 }
88
89 pub fn mark_as_done(mut self) {
94 self.done = true;
95
96 self.state.done.fetch_add(1, Ordering::SeqCst);
97 if self.state.all_children_done() {
98 self.state.notify.notify_one();
99 }
100 }
101}
102
103impl Drop for WaitGroupChild {
104 fn drop(&mut self) {
105 if !self.done {
106 panic!("wait group child dropped without being marked as done");
107 }
108 }
109}
110
111pub struct WaitTrigger {
112 tx: oneshot::Sender<()>,
113}
114
115pub struct WaitHandle {
116 rx: Option<oneshot::Receiver<()>>,
117}
118
119impl WaitTrigger {
120 pub fn new() -> (Self, WaitHandle) {
122 let (tx, rx) = oneshot::channel();
123
124 let trigger = Self { tx };
125 let handle = WaitHandle { rx: Some(rx) };
126
127 (trigger, handle)
128 }
129
130 pub fn trigger(self) {
132 _ = self.tx.send(());
136 }
137}
138
139impl WaitHandle {
140 pub async fn wait(&mut self) {
142 match self.rx.as_mut() {
143 Some(rx) => rx
144 .await
145 .expect("paired task no longer holding wait trigger"),
146 None => panic!("tried to await wait trigger signal but has already been received"),
147 }
148
149 self.rx.take();
152 }
153}
154
155pub struct Configuring {
156 tasks_started: WaitGroup,
157 tasks_completed: WaitGroup,
158 shutdown_triggers: Mutex<Vec<WaitTrigger>>,
159}
160
161pub struct Started {
162 tasks_completed: Option<WaitGroup>,
163 shutdown_triggers: Vec<WaitTrigger>,
164}
165
/// Coordinates startup and shutdown of a group of tasks, with `State` as a typestate.
///
/// Created in the [`Configuring`] state by `TaskCoordinator::new`, and moved to [`Started`] by
/// `started`.
pub struct TaskCoordinator<State> {
    state: State,
    /// Label used as the prefix of this coordinator's log messages.
    name: String,
}
185
186impl TaskCoordinator<()> {
187 pub fn new(name: &str) -> TaskCoordinator<Configuring> {
189 TaskCoordinator {
190 state: Configuring {
191 tasks_started: WaitGroup::new(),
192 tasks_completed: WaitGroup::new(),
193 shutdown_triggers: Mutex::new(Vec::new()),
194 },
195 name: name.to_string(),
196 }
197 }
198}
199
200impl TaskCoordinator<Configuring> {
201 pub fn track_started(&self) -> WaitGroupChild {
203 self.state.tasks_started.add_child()
204 }
205
206 pub fn track_completed(&self) -> WaitGroupChild {
208 self.state.tasks_completed.add_child()
209 }
210
211 pub fn register_for_shutdown(&self) -> WaitHandle {
213 let (trigger, handle) = WaitTrigger::new();
214 self.state
215 .shutdown_triggers
216 .lock()
217 .expect("poisoned")
218 .push(trigger);
219 handle
220 }
221
222 pub async fn started(self) -> TaskCoordinator<Started> {
224 let Configuring {
225 mut tasks_started,
226 tasks_completed,
227 shutdown_triggers,
228 } = self.state;
229
230 tasks_started.wait_for_children().await;
231 trace!("All coordinated tasks reported as having started.");
232
233 TaskCoordinator {
234 state: Started {
235 tasks_completed: Some(tasks_completed),
236 shutdown_triggers: shutdown_triggers.into_inner().expect("poisoned"),
237 },
238 name: self.name,
239 }
240 }
241}
242
243impl TaskCoordinator<Started> {
244 pub async fn shutdown(&mut self) {
246 info!("{}: triggering task to shutdown.", self.name);
247
248 for trigger in self.state.shutdown_triggers.drain(..) {
250 trigger.trigger();
251 debug!("{}: shutdown triggered for coordinated tasks.", self.name);
252 }
253
254 debug!(
256 "{}: waiting for coordinated tasks to complete...",
257 self.name
258 );
259 let tasks_completed = self
260 .state
261 .tasks_completed
262 .as_mut()
263 .expect("tasks completed wait group already consumed");
264 tasks_completed.wait_for_children().await;
265
266 info!("{}: task has been shutdown.", self.name);
267 }
268}