vector_stream/batcher/mod.rs
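//! A stream adapter that groups items from a wrapped stream into batches,
//! emitting each batch when the [`BatchConfig`] reports it full, when the
//! flush timeout elapses, or when the wrapped stream ends.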

pub mod config;
pub mod data;
pub mod limiter;

use std::{
    pin::Pin,
    task::{ready, Context, Poll},
};

pub use config::BatchConfig;
use futures::{
    stream::{Fuse, Stream},
    Future, StreamExt,
};
use pin_project::pin_project;
use tokio::time::Sleep;

#[pin_project]
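/// Batches items from the wrapped stream, emitting a batch when it fills up,
/// when the flush timeout fires, or when the wrapped stream ends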
pub struct Batcher<S, C> {
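    /// The batch configuration, which also accumulates the in-progress batch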
    state: C,

    #[pin]
    /// The stream this `Batcher` wraps
    stream: Fuse<S>,

    #[pin]
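    /// Deadline for flushing a partial batch; armed when the first item of a
    /// batch arrives and cleared whenever a batch is emitted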
    timer: Maybe<Sleep>,
}

/// An `Option`, but with pin projection
#[pin_project(project = MaybeProj)]
pub enum Maybe<T> {
    Some(#[pin] T),
    None,
}

impl<S, C> Batcher<S, C>
where
    S: Stream,
    C: BatchConfig<S::Item>,
{
    pub fn new(stream: S, config: C) -> Self {
        Self {
            state: config,
            stream: stream.fuse(),
            timer: Maybe::None,
        }
    }
}

impl<S, C> Stream for Batcher<S, C>
where
    S: Stream,
    C: BatchConfig<S::Item>,
{
    type Item = C::Batch;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let mut this = self.as_mut().project();
            match this.stream.poll_next(cx) {
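                // The wrapped stream has ended: flush any partial batch, then finish.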
                Poll::Ready(None) => {
                    return {
                        if this.state.len() == 0 {
                            Poll::Ready(None)
                        } else {
                            Poll::Ready(Some(this.state.take_batch()))
                        }
                    }
                }
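                // A new item arrived: push it if it fits, emitting immediately once
                // the batch fills; the first item of a batch arms the flush timer.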
                Poll::Ready(Some(item)) => {
                    let (item_fits, item_metadata) = this.state.item_fits_in_batch(&item);
                    if item_fits {
                        this.state.push(item, item_metadata);
                        if this.state.is_batch_full() {
                            this.timer.set(Maybe::None);
                            return Poll::Ready(Some(this.state.take_batch()));
                        } else if this.state.len() == 1 {
                            this.timer
                                .set(Maybe::Some(tokio::time::sleep(this.state.timeout())));
                        }
                    } else {
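                        // The item doesn't fit: take the current batch first (which
                        // empties the accumulator), then seed the next batch with this
                        // item and re-arm the flush timer before returning.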
                        let output = Poll::Ready(Some(this.state.take_batch()));
                        this.state.push(item, item_metadata);
                        this.timer
                            .set(Maybe::Some(tokio::time::sleep(this.state.timeout())));
                        return output;
                    }
                }
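                // No input is ready: if a flush timer is armed, emit the partial
                // batch once it fires; otherwise wait for more items.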
                Poll::Pending => {
                    return {
                        if let MaybeProj::Some(timer) = this.timer.as_mut().project() {
                            ready!(timer.poll(cx));
                            this.timer.set(Maybe::None);
                            debug_assert!(
                                this.state.len() != 0,
                                "timer should have been cancelled"
                            );
                            Poll::Ready(Some(this.state.take_batch()))
                        } else {
                            Poll::Pending
                        }
                    }
                }
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

#[cfg(test)]
#[allow(clippy::similar_names)]
mod test {
    use std::{num::NonZeroUsize, time::Duration};

    use futures::stream;

    use super::*;
    use crate::BatcherSettings;

    #[tokio::test]
    async fn item_limit() {
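        // Max two items per batch, so the three inputs split into [1, 2] and [3].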
        let stream = stream::iter([1, 2, 3]);
        let settings = BatcherSettings::new(
            Duration::from_millis(100),
            NonZeroUsize::new(10000).unwrap(),
            NonZeroUsize::new(2).unwrap(),
        );
        let batcher = Batcher::new(stream, settings.as_item_size_config(|x: &u32| *x as usize));
        let batches: Vec<_> = batcher.collect().await;
        assert_eq!(batches, vec![vec![1, 2], vec![3],]);
    }

    #[tokio::test]
    async fn size_limit() {
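        // A five-byte size budget (each item's "size" is its value) forces a
        // flush whenever the next item would overflow the current batch.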
        let batcher = Batcher::new(
            stream::iter([1, 2, 3, 4, 5, 6, 2, 3, 1]),
            BatcherSettings::new(
                Duration::from_millis(100),
                NonZeroUsize::new(5).unwrap(),
                NonZeroUsize::new(100).unwrap(),
            )
            .as_item_size_config(|x: &u32| *x as usize),
        );
        let batches: Vec<_> = batcher.collect().await;
        assert_eq!(
            batches,
            vec![
                vec![1, 2],
                vec![3],
                vec![4],
                vec![5],
                vec![6],
                vec![2, 3],
                vec![1],
            ]
        );
    }

    #[tokio::test]
    async fn timeout_limit() {
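        // Pause tokio's clock so the flush timeout can be advanced deterministically.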
        tokio::time::pause();

        let timeout = Duration::from_millis(100);
        let stream = stream::iter([1, 2]).chain(stream::pending());
        let batcher = Batcher::new(
            stream,
            BatcherSettings::new(
                timeout,
                NonZeroUsize::new(5).unwrap(),
                NonZeroUsize::new(100).unwrap(),
            )
            .as_item_size_config(|x: &u32| *x as usize),
        );

        tokio::pin!(batcher);
        let mut next = batcher.next();
        assert_eq!(futures::poll!(&mut next), Poll::Pending);
        tokio::time::advance(timeout).await;
        let batch = next.await;
        assert_eq!(batch, Some(vec![1, 2]));
    }
}
186}