1#![allow(clippy::module_name_repetitions)]
2
3use std::marker::{PhantomData, Unpin};
4use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc, task::Context, task::Poll};
5
6use futures::stream::{BoxStream, FuturesOrdered, FuturesUnordered};
7use futures::{future::OptionFuture, FutureExt, Stream, StreamExt};
8use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
9use tokio::sync::Notify;
10
11use crate::finalization::{BatchStatus, BatchStatusReceiver};
12use crate::shutdown::ShutdownSignal;
13
/// A [`FinalizerSet`] backed by [`FuturesOrdered`]: finalized entries are
/// yielded in the order they were added, because `FuturesOrdered` produces
/// outputs in push order even when later futures complete first.
pub type OrderedFinalizer<T> = FinalizerSet<T, FuturesOrdered<FinalizerFuture<T>>>;

/// A [`FinalizerSet`] backed by [`FuturesUnordered`]: finalized entries are
/// yielded as soon as their batch status resolves, regardless of the order in
/// which they were added.
pub type UnorderedFinalizer<T> = FinalizerSet<T, FuturesUnordered<FinalizerFuture<T>>>;
25
26#[derive(Debug)]
35pub struct FinalizerSet<T, S> {
36 sender: Option<UnboundedSender<(BatchStatusReceiver, T)>>,
37 flush: Arc<Notify>,
38 _phantom: PhantomData<S>,
39}
40
41impl<T, S> FinalizerSet<T, S>
42where
43 T: Send + Debug + 'static,
44 S: FuturesSet<FinalizerFuture<T>> + Default + Send + Unpin + 'static,
45{
46 #[must_use]
57 pub fn new(shutdown: Option<ShutdownSignal>) -> (Self, BoxStream<'static, (BatchStatus, T)>) {
58 let (todo_tx, todo_rx) = mpsc::unbounded_channel();
59 let flush1 = Arc::new(Notify::new());
60 let flush2 = Arc::clone(&flush1);
61 (
62 Self {
63 sender: Some(todo_tx),
64 flush: flush1,
65 _phantom: PhantomData,
66 },
67 finalizer_stream(shutdown, todo_rx, S::default(), flush2).boxed(),
68 )
69 }
70
71 #[must_use]
76 pub fn maybe_new(
77 maybe: bool,
78 shutdown: Option<ShutdownSignal>,
79 ) -> (Option<Self>, BoxStream<'static, (BatchStatus, T)>) {
80 if maybe {
81 let (finalizer, stream) = Self::new(shutdown);
82 (Some(finalizer), stream)
83 } else {
84 (None, EmptyStream::default().boxed())
85 }
86 }
87
88 pub fn add(&self, entry: T, receiver: BatchStatusReceiver) {
89 if let Some(sender) = &self.sender {
90 if let Err(error) = sender.send((receiver, entry)) {
91 error!(message = "FinalizerSet task ended prematurely.", %error);
92 }
93 }
94 }
95
96 pub fn flush(&self) {
97 self.flush.notify_one();
98 }
99}
100
/// Builds the stream that backs [`FinalizerSet::new`].
///
/// The main loop is a biased `select!` that checks, in this priority order:
/// 1. `shutdown` (only if one was supplied) — leave the loop.
/// 2. `flush` — replace `status_receivers` with an empty set, discarding every
///    pending entry without yielding it.
/// 3. a tracked status resolving — yield `(status, entry)`.
/// 4. a new `(receiver, entry)` pair arriving on `new_entries` — start tracking
///    it; the channel closing (all senders dropped) leaves the loop.
///
/// After the loop, every still-pending status is awaited and yielded, and then
/// the stream ends.
fn finalizer_stream<T, S>(
    shutdown: Option<ShutdownSignal>,
    mut new_entries: UnboundedReceiver<(BatchStatusReceiver, T)>,
    mut status_receivers: S,
    flush: Arc<Notify>,
) -> impl Stream<Item = (BatchStatus, T)>
where
    S: Default + FuturesSet<FinalizerFuture<T>> + Unpin,
{
    // An `OptionFuture` built from `None` resolves immediately, so the shutdown
    // branch must be disabled in that case instead of being treated as a
    // shutdown.
    let handle_shutdown = shutdown.is_some();
    let mut shutdown = OptionFuture::from(shutdown);

    async_stream::stream! {
        loop {
            tokio::select! {
                // Check branches in declaration order: shutdown and flush take
                // precedence over yielding results or accepting new entries.
                biased;
                _ = &mut shutdown, if handle_shutdown => break,
                () = flush.notified() => {
                    // Drop all the existing status receivers and start over.
                    status_receivers = S::default();
                },
                finished = status_receivers.next(), if !status_receivers.is_empty() => match finished {
                    Some((status, entry)) => yield (status, entry),
                    // The `is_empty` guard above prevents this from being reachable.
                    None => unreachable!(),
                },
                new_entry = new_entries.recv() => match new_entry {
                    Some((receiver, entry)) => {
                        status_receivers.push(FinalizerFuture {
                            receiver,
                            entry: Some(entry),
                        });
                    }
                    // Every sender is gone, so no further entries can arrive.
                    None => break,
                },
            }
        }

        // Either shutdown fired or the entry channel closed. Wait for the
        // remaining statuses to come in before ending the stream.
        // NOTE(review): pairs still queued in `new_entries` at this point are
        // never received, so their entries are not yielded — confirm this is
        // intended on shutdown.
        while let Some((status, entry)) = status_receivers.next().await {
            yield (status, entry);
        }

        // Keeps `shutdown` owned by the stream until all final statuses have
        // been yielded. NOTE(review): presumably this delays signalling
        // shutdown completion until the stream is done — confirm against
        // `ShutdownSignal`'s drop behavior.
        drop(shutdown);
    }
}
155
156pub trait FuturesSet<Fut: Future>: Stream<Item = Fut::Output> {
157 fn is_empty(&self) -> bool;
158 fn push(&mut self, future: Fut);
159}
160
161impl<Fut: Future> FuturesSet<Fut> for FuturesOrdered<Fut> {
162 fn is_empty(&self) -> bool {
163 Self::is_empty(self)
164 }
165
166 fn push(&mut self, future: Fut) {
167 Self::push_back(self, future);
168 }
169}
170
171impl<Fut: Future> FuturesSet<Fut> for FuturesUnordered<Fut> {
172 fn is_empty(&self) -> bool {
173 Self::is_empty(self)
174 }
175
176 fn push(&mut self, future: Fut) {
177 Self::push(self, future);
178 }
179}
180
181#[pin_project::pin_project]
182pub struct FinalizerFuture<T> {
183 receiver: BatchStatusReceiver,
184 entry: Option<T>,
185}
186
187impl<T> Future for FinalizerFuture<T> {
188 type Output = (<BatchStatusReceiver as Future>::Output, T);
189 fn poll(mut self: Pin<&mut Self>, ctx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
190 let status = std::task::ready!(self.receiver.poll_unpin(ctx));
191 Poll::Ready((status, self.entry.take().unwrap_or_else(|| unreachable!())))
194 }
195}
196
197#[derive(Clone, Copy)]
198pub struct EmptyStream<T>(PhantomData<T>);
199
200impl<T> Default for EmptyStream<T> {
201 fn default() -> Self {
202 Self(PhantomData)
203 }
204}
205
206impl<T> Stream for EmptyStream<T> {
207 type Item = T;
208
209 fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
210 Poll::Pending
211 }
212
213 fn size_hint(&self) -> (usize, Option<usize>) {
214 (0, Some(0))
215 }
216}