1#![allow(clippy::module_name_repetitions)]
2
3use std::{
4 fmt::Debug,
5 future::Future,
6 marker::{PhantomData, Unpin},
7 pin::Pin,
8 sync::Arc,
9 task::{Context, Poll},
10};
11
12use futures::{
13 FutureExt, Stream, StreamExt,
14 future::OptionFuture,
15 stream::{BoxStream, FuturesOrdered, FuturesUnordered},
16};
17use tokio::sync::{
18 Notify,
19 mpsc::{self, UnboundedReceiver, UnboundedSender},
20};
21
22use crate::{
23 finalization::{BatchStatus, BatchStatusReceiver},
24 shutdown::ShutdownSignal,
25};
26
/// A [`FinalizerSet`] whose stream yields finalized entries in the order they
/// were added (pending receivers are held in a `FuturesOrdered`).
pub type OrderedFinalizer<T> = FinalizerSet<T, FuturesOrdered<FinalizerFuture<T>>>;

/// A [`FinalizerSet`] whose stream yields each entry as soon as its batch
/// status resolves, regardless of insertion order (pending receivers are held
/// in a `FuturesUnordered`).
pub type UnorderedFinalizer<T> = FinalizerSet<T, FuturesUnordered<FinalizerFuture<T>>>;
/// Handle used to register entries whose batch status should be awaited.
///
/// Entries passed to `add` are forwarded, together with their
/// `BatchStatusReceiver`, to the stream returned alongside this handle by
/// `new`. That stream yields `(BatchStatus, T)` once each receiver resolves.
/// `S` selects the futures set that holds the pending receivers and therefore
/// the order in which entries are yielded.
#[derive(Debug)]
pub struct FinalizerSet<T, S> {
    // Channel into the finalizer stream; when `None`, `add` does nothing.
    sender: Option<UnboundedSender<(BatchStatusReceiver, T)>>,
    // Shared with the finalizer stream; notifying it makes the stream discard
    // every pending status receiver.
    flush: Arc<Notify>,
    // Records the futures-set type `S` without storing a value of it.
    _phantom: PhantomData<S>,
}
53
54impl<T, S> FinalizerSet<T, S>
55where
56 T: Send + Debug + 'static,
57 S: FuturesSet<FinalizerFuture<T>> + Default + Send + Unpin + 'static,
58{
59 #[must_use]
70 pub fn new(shutdown: Option<ShutdownSignal>) -> (Self, BoxStream<'static, (BatchStatus, T)>) {
71 let (todo_tx, todo_rx) = mpsc::unbounded_channel();
72 let flush1 = Arc::new(Notify::new());
73 let flush2 = Arc::clone(&flush1);
74 (
75 Self {
76 sender: Some(todo_tx),
77 flush: flush1,
78 _phantom: PhantomData,
79 },
80 finalizer_stream(shutdown, todo_rx, S::default(), flush2).boxed(),
81 )
82 }
83
84 #[must_use]
89 pub fn maybe_new(
90 maybe: bool,
91 shutdown: Option<ShutdownSignal>,
92 ) -> (Option<Self>, BoxStream<'static, (BatchStatus, T)>) {
93 if maybe {
94 let (finalizer, stream) = Self::new(shutdown);
95 (Some(finalizer), stream)
96 } else {
97 (None, EmptyStream::default().boxed())
98 }
99 }
100
101 pub fn add(&self, entry: T, receiver: BatchStatusReceiver) {
102 if let Some(sender) = &self.sender
103 && let Err(error) = sender.send((receiver, entry))
104 {
105 error!(message = "FinalizerSet task ended prematurely.", %error);
106 }
107 }
108
109 pub fn flush(&self) {
110 self.flush.notify_one();
111 }
112}
113
/// Builds the stream that awaits pending batch statuses and yields each entry
/// with its status.
///
/// The main loop stops accepting new entries when `shutdown` (if one was
/// given) completes, or when every sender for `new_entries` has been dropped.
/// It then awaits the receivers it already holds and yields their entries
/// before the stream ends. A notification on `flush` discards all pending
/// receivers, so their entries are never yielded.
fn finalizer_stream<T, S>(
    shutdown: Option<ShutdownSignal>,
    mut new_entries: UnboundedReceiver<(BatchStatusReceiver, T)>,
    mut status_receivers: S,
    flush: Arc<Notify>,
) -> impl Stream<Item = (BatchStatus, T)>
where
    S: Default + FuturesSet<FinalizerFuture<T>> + Unpin,
{
    // An `OptionFuture` built from `None` resolves immediately, so the
    // shutdown branch below must be disabled when no signal was supplied.
    let handle_shutdown = shutdown.is_some();
    let mut shutdown = OptionFuture::from(shutdown);

    async_stream::stream! {
        loop {
            tokio::select! {
                // Check branches in the order written: shutdown, then flush,
                // then completed statuses, then newly added entries.
                biased;
                _ = &mut shutdown, if handle_shutdown => break,
                () = flush.notified() => {
                    // Drop every pending status receiver (and its entry) and start over.
                    status_receivers = S::default();
                },
                // Yield finished statuses before accepting more entries.
                finished = status_receivers.next(), if !status_receivers.is_empty() => match finished {
                    Some((status, entry)) => yield (status, entry),
                    // The `is_empty` guard means the set is non-empty here, so
                    // it is not expected to report exhaustion.
                    None => unreachable!(),
                },
                new_entry = new_entries.recv() => match new_entry {
                    Some((receiver, entry)) => {
                        status_receivers.push(FinalizerFuture {
                            receiver,
                            entry: Some(entry),
                        });
                    }
                    // Every sender is gone, so no further entries can arrive.
                    None => break,
                },
            }
        }

        // Shutdown was seen or the entry channel closed: wait for the
        // receivers already accepted and yield their entries before finishing.
        while let Some((status, entry)) = status_receivers.next().await {
            yield (status, entry);
        }

        // NOTE(review): `shutdown` is kept alive until draining has finished,
        // presumably so the shutdown signal is not released before this stream
        // is done. Confirm against ShutdownSignal's drop semantics.
        drop(shutdown);
    }
}
168
169pub trait FuturesSet<Fut: Future>: Stream<Item = Fut::Output> {
170 fn is_empty(&self) -> bool;
171 fn push(&mut self, future: Fut);
172}
173
174impl<Fut: Future> FuturesSet<Fut> for FuturesOrdered<Fut> {
175 fn is_empty(&self) -> bool {
176 Self::is_empty(self)
177 }
178
179 fn push(&mut self, future: Fut) {
180 Self::push_back(self, future);
181 }
182}
183
184impl<Fut: Future> FuturesSet<Fut> for FuturesUnordered<Fut> {
185 fn is_empty(&self) -> bool {
186 Self::is_empty(self)
187 }
188
189 fn push(&mut self, future: Fut) {
190 Self::push(self, future);
191 }
192}
193
194#[pin_project::pin_project]
195pub struct FinalizerFuture<T> {
196 receiver: BatchStatusReceiver,
197 entry: Option<T>,
198}
199
200impl<T> Future for FinalizerFuture<T> {
201 type Output = (<BatchStatusReceiver as Future>::Output, T);
202 fn poll(mut self: Pin<&mut Self>, ctx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
203 let status = std::task::ready!(self.receiver.poll_unpin(ctx));
204 Poll::Ready((status, self.entry.take().unwrap_or_else(|| unreachable!())))
207 }
208}
209
210#[derive(Clone, Copy)]
211pub struct EmptyStream<T>(PhantomData<T>);
212
213impl<T> Default for EmptyStream<T> {
214 fn default() -> Self {
215 Self(PhantomData)
216 }
217}
218
219impl<T> Stream for EmptyStream<T> {
220 type Item = T;
221
222 fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
223 Poll::Pending
224 }
225
226 fn size_hint(&self) -> (usize, Option<usize>) {
227 (0, Some(0))
228 }
229}