vector_stream/batcher/
mod.rs1pub mod config;
2pub mod data;
3pub mod limiter;
4
5use std::{
6 pin::Pin,
7 task::{ready, Context, Poll},
8};
9
10pub use config::BatchConfig;
11use futures::{
12 stream::{Fuse, Stream},
13 Future, StreamExt,
14};
15use pin_project::pin_project;
16use tokio::time::Sleep;
17
18#[pin_project]
19pub struct Batcher<S, C> {
20 state: C,
21
22 #[pin]
23 stream: Fuse<S>,
25
26 #[pin]
27 timer: Maybe<Sleep>,
28}
29
30#[pin_project(project = MaybeProj)]
32pub enum Maybe<T> {
33 Some(#[pin] T),
34 None,
35}
36
37impl<S, C> Batcher<S, C>
38where
39 S: Stream,
40 C: BatchConfig<S::Item>,
41{
42 pub fn new(stream: S, config: C) -> Self {
43 Self {
44 state: config,
45 stream: stream.fuse(),
46 timer: Maybe::None,
47 }
48 }
49}
50
51impl<S, C> Stream for Batcher<S, C>
52where
53 S: Stream,
54 C: BatchConfig<S::Item>,
55{
56 type Item = C::Batch;
57
58 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
59 loop {
60 let mut this = self.as_mut().project();
61 match this.stream.poll_next(cx) {
62 Poll::Ready(None) => {
63 return {
64 if this.state.len() == 0 {
65 Poll::Ready(None)
66 } else {
67 Poll::Ready(Some(this.state.take_batch()))
68 }
69 }
70 }
71 Poll::Ready(Some(item)) => {
72 let (item_fits, item_metadata) = this.state.item_fits_in_batch(&item);
73 if item_fits {
74 this.state.push(item, item_metadata);
75 if this.state.is_batch_full() {
76 this.timer.set(Maybe::None);
77 return Poll::Ready(Some(this.state.take_batch()));
78 } else if this.state.len() == 1 {
79 this.timer
80 .set(Maybe::Some(tokio::time::sleep(this.state.timeout())));
81 }
82 } else {
83 let output = Poll::Ready(Some(this.state.take_batch()));
84 this.state.push(item, item_metadata);
85 this.timer
86 .set(Maybe::Some(tokio::time::sleep(this.state.timeout())));
87 return output;
88 }
89 }
90 Poll::Pending => {
91 return {
92 if let MaybeProj::Some(timer) = this.timer.as_mut().project() {
93 ready!(timer.poll(cx));
94 this.timer.set(Maybe::None);
95 debug_assert!(
96 this.state.len() != 0,
97 "timer should have been cancelled"
98 );
99 Poll::Ready(Some(this.state.take_batch()))
100 } else {
101 Poll::Pending
102 }
103 }
104 }
105 }
106 }
107 }
108
109 fn size_hint(&self) -> (usize, Option<usize>) {
110 self.stream.size_hint()
111 }
112}
113
114#[cfg(test)]
115#[allow(clippy::similar_names)]
116mod test {
117 use std::{num::NonZeroUsize, time::Duration};
118
119 use futures::stream;
120
121 use super::*;
122 use crate::BatcherSettings;
123
124 #[tokio::test]
125 async fn item_limit() {
126 let stream = stream::iter([1, 2, 3]);
127 let settings = BatcherSettings::new(
128 Duration::from_millis(100),
129 NonZeroUsize::new(10000).unwrap(),
130 NonZeroUsize::new(2).unwrap(),
131 );
132 let batcher = Batcher::new(stream, settings.as_item_size_config(|x: &u32| *x as usize));
133 let batches: Vec<_> = batcher.collect().await;
134 assert_eq!(batches, vec![vec![1, 2], vec![3],]);
135 }
136
137 #[tokio::test]
138 async fn size_limit() {
139 let batcher = Batcher::new(
140 stream::iter([1, 2, 3, 4, 5, 6, 2, 3, 1]),
141 BatcherSettings::new(
142 Duration::from_millis(100),
143 NonZeroUsize::new(5).unwrap(),
144 NonZeroUsize::new(100).unwrap(),
145 )
146 .as_item_size_config(|x: &u32| *x as usize),
147 );
148 let batches: Vec<_> = batcher.collect().await;
149 assert_eq!(
150 batches,
151 vec![
152 vec![1, 2],
153 vec![3],
154 vec![4],
155 vec![5],
156 vec![6],
157 vec![2, 3],
158 vec![1],
159 ]
160 );
161 }
162
163 #[tokio::test]
164 async fn timeout_limit() {
165 tokio::time::pause();
166
167 let timeout = Duration::from_millis(100);
168 let stream = stream::iter([1, 2]).chain(stream::pending());
169 let batcher = Batcher::new(
170 stream,
171 BatcherSettings::new(
172 timeout,
173 NonZeroUsize::new(5).unwrap(),
174 NonZeroUsize::new(100).unwrap(),
175 )
176 .as_item_size_config(|x: &u32| *x as usize),
177 );
178
179 tokio::pin!(batcher);
180 let mut next = batcher.next();
181 assert_eq!(futures::poll!(&mut next), Poll::Pending);
182 tokio::time::advance(timeout).await;
183 let batch = next.await;
184 assert_eq!(batch, Some(vec![1, 2]));
185 }
186}