1use std::{
2 collections::HashMap,
3 hash::{BuildHasherDefault, Hash},
4 num::NonZeroUsize,
5 pin::Pin,
6 task::{ready, Context, Poll},
7 time::Duration,
8};
9
10use futures::stream::{Fuse, Stream, StreamExt};
11use pin_project::pin_project;
12use tokio_util::time::{delay_queue::Key, DelayQueue};
13use twox_hash::XxHash64;
14use vector_common::byte_size_of::ByteSizeOf;
15use vector_core::{partition::Partitioner, time::KeyedTimer};
16
17use crate::batcher::{
18 config::BatchConfigParts,
19 data::BatchData,
20 limiter::{ByteSizeOfItemSize, ItemBatchSize, SizeLimit},
21 BatchConfig,
22};
23
24pub struct ExpirationQueue<K> {
26 timeout: Duration,
28 expirations: DelayQueue<K>,
30 expiration_map: HashMap<K, Key>,
32}
33
34impl<K> ExpirationQueue<K> {
35 pub fn new(timeout: Duration) -> Self {
39 Self {
40 timeout,
41 expirations: DelayQueue::new(),
42 expiration_map: HashMap::default(),
43 }
44 }
45
46 pub fn len(&self) -> usize {
51 self.expirations.len()
52 }
53
54 pub fn is_empty(&self) -> bool {
56 self.len() == 0
57 }
58}
59
60impl<K> KeyedTimer<K> for ExpirationQueue<K>
61where
62 K: Eq + Hash + Clone,
63{
64 fn clear(&mut self) {
65 self.expirations.clear();
66 self.expiration_map.clear();
67 }
68
69 fn insert(&mut self, item_key: K) {
70 if let Some(expiration_key) = self.expiration_map.get(&item_key) {
71 self.expirations.reset(expiration_key, self.timeout);
74 } else {
75 let expiration_key = self.expirations.insert(item_key.clone(), self.timeout);
78 assert!(self
79 .expiration_map
80 .insert(item_key, expiration_key)
81 .is_none());
82 }
83 }
84
85 fn remove(&mut self, item_key: &K) {
86 if let Some(expiration_key) = self.expiration_map.remove(item_key) {
87 self.expirations.remove(&expiration_key);
88 }
89 }
90
91 fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<K>> {
92 match ready!(self.expirations.poll_expired(cx)) {
93 None => Poll::Ready(None),
95 Some(expiration) => {
96 assert!(self.expiration_map.remove(expiration.get_ref()).is_some());
99 Poll::Ready(Some(expiration.into_inner()))
100 }
101 }
102 }
103}
104
/// Limits that control how items are grouped into batches.
///
/// The limits are stored as plain `usize`. `BatcherSettings::new` only accepts non-zero
/// values for them.
#[derive(Copy, Clone, Debug)]
pub struct BatcherSettings {
    /// How long a batch may stay open before it is flushed.
    pub timeout: Duration,
    /// Maximum accumulated size of a batch, as measured by the item-size calculator in use.
    pub size_limit: usize,
    /// Maximum number of items in a batch.
    pub item_limit: usize,
}
118
119impl BatcherSettings {
120 pub const fn new(
121 timeout: Duration,
122 size_limit: NonZeroUsize,
123 item_limit: NonZeroUsize,
124 ) -> Self {
125 BatcherSettings {
126 timeout,
127 size_limit: size_limit.get(),
128 item_limit: item_limit.get(),
129 }
130 }
131
132 pub fn as_byte_size_config<T: ByteSizeOf>(
135 &self,
136 ) -> BatchConfigParts<SizeLimit<ByteSizeOfItemSize>, Vec<T>> {
137 self.as_item_size_config(ByteSizeOfItemSize)
138 }
139
140 pub fn as_item_size_config<T, I>(&self, item_size: I) -> BatchConfigParts<SizeLimit<I>, Vec<T>>
143 where
144 I: ItemBatchSize<T>,
145 {
146 BatchConfigParts {
147 batch_limiter: SizeLimit {
148 batch_size_limit: self.size_limit,
149 batch_item_limit: self.item_limit,
150 current_size: 0,
151 item_size_calculator: item_size,
152 },
153 batch_data: vec![],
154 timeout: self.timeout,
155 }
156 }
157
158 pub fn as_reducer_config<I, T, B>(
161 &self,
162 item_size: I,
163 reducer: B,
164 ) -> BatchConfigParts<SizeLimit<I>, B>
165 where
166 I: ItemBatchSize<T>,
167 B: BatchData<T>,
168 {
169 BatchConfigParts {
170 batch_limiter: SizeLimit {
171 batch_size_limit: self.size_limit,
172 batch_item_limit: self.item_limit,
173 current_size: 0,
174 item_size_calculator: item_size,
175 },
176 batch_data: reducer,
177 timeout: self.timeout,
178 }
179 }
180}
181
182#[pin_project]
183pub struct PartitionedBatcher<St, Prt, KT, C, F, B>
184where
185 Prt: Partitioner,
186{
187 state: F,
190 batches: HashMap<Prt::Key, C, BuildHasherDefault<XxHash64>>,
193 closed_batches: Vec<(Prt::Key, B)>,
197 timer: KT,
199 partitioner: Prt,
201 #[pin]
202 stream: Fuse<St>,
204}
205
206impl<St, Prt, C, F, B> PartitionedBatcher<St, Prt, ExpirationQueue<Prt::Key>, C, F, B>
207where
208 St: Stream<Item = Prt::Item>,
209 Prt: Partitioner + Unpin,
210 Prt::Key: Eq + Hash + Clone,
211 Prt::Item: ByteSizeOf,
212 C: BatchConfig<Prt::Item>,
213 F: Fn() -> C + Send,
214{
215 pub fn new(stream: St, partitioner: Prt, settings: F) -> Self {
216 let timeout = settings().timeout();
217 Self {
218 state: settings,
219 batches: HashMap::default(),
220 closed_batches: Vec::default(),
221 timer: ExpirationQueue::new(timeout),
222 partitioner,
223 stream: stream.fuse(),
224 }
225 }
226}
227
#[cfg(test)]
impl<St, Prt, KT, C, F, B> PartitionedBatcher<St, Prt, KT, C, F, B>
where
    St: Stream<Item = Prt::Item>,
    Prt: Partitioner + Unpin,
    Prt::Key: Eq + Hash + Clone,
    Prt::Item: ByteSizeOf,
    C: BatchConfig<Prt::Item>,
    F: Fn() -> C + Send,
{
    /// Test-only constructor that takes the keyed timer from the caller.
    ///
    /// Tests can use it to inject a scripted timer in place of the default `ExpirationQueue`.
    pub fn with_timer(stream: St, partitioner: Prt, timer: KT, settings: F) -> Self {
        let stream = stream.fuse();
        Self {
            timer,
            partitioner,
            stream,
            state: settings,
            batches: HashMap::default(),
            closed_batches: Vec::new(),
        }
    }
}
249
250impl<St, Prt, KT, C, F, B> Stream for PartitionedBatcher<St, Prt, KT, C, F, B>
251where
252 St: Stream<Item = Prt::Item>,
253 Prt: Partitioner + Unpin,
254 Prt::Key: Eq + Hash + Clone,
255 Prt::Item: ByteSizeOf,
256 KT: KeyedTimer<Prt::Key>,
257 C: BatchConfig<Prt::Item, Batch = B>,
258 F: Fn() -> C + Send,
259{
260 type Item = (Prt::Key, B);
261
262 fn size_hint(&self) -> (usize, Option<usize>) {
263 self.stream.size_hint()
264 }
265
266 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
267 let mut this = self.project();
268 loop {
269 if !this.closed_batches.is_empty() {
270 return Poll::Ready(this.closed_batches.pop());
271 }
272 match this.stream.as_mut().poll_next(cx) {
273 Poll::Pending => match this.timer.poll_expired(cx) {
274 Poll::Pending | Poll::Ready(None) => return Poll::Pending,
277 Poll::Ready(Some(item_key)) => {
278 let mut batch = this
279 .batches
280 .remove(&item_key)
281 .expect("batch should exist if it is set to expire");
282 this.closed_batches.push((item_key, batch.take_batch()));
283 }
284 },
285 Poll::Ready(None) => {
286 if !this.batches.is_empty() {
292 this.timer.clear();
293 this.closed_batches.extend(
294 this.batches
295 .drain()
296 .map(|(key, mut batch)| (key, batch.take_batch())),
297 );
298 continue;
299 }
300 return Poll::Ready(None);
301 }
302 Poll::Ready(Some(item)) => {
303 let item_key = this.partitioner.partition(&item);
304
305 let batch = if let Some(batch) = this.batches.get_mut(&item_key) {
307 batch
308 } else {
309 let batch = (this.state)();
310 this.batches.insert(item_key.clone(), batch);
311 this.timer.insert(item_key.clone());
312 this.batches
313 .get_mut(&item_key)
314 .expect("batch has just been inserted so should exist")
315 };
316
317 let (fits, metadata) = batch.item_fits_in_batch(&item);
318 if !fits {
319 this.closed_batches
323 .push((item_key.clone(), batch.take_batch()));
324
325 this.timer.insert(item_key.clone());
329 }
330
331 batch.push(item, metadata);
333 if batch.is_batch_full() {
334 this.closed_batches
337 .push((item_key.clone(), batch.take_batch()));
338 this.batches.remove(&item_key);
339 this.timer.remove(&item_key);
340 }
341 }
342 }
343 }
344 }
345}
346
#[allow(clippy::cast_sign_loss)]
#[cfg(test)]
mod test {
    use std::{
        collections::{HashMap, HashSet},
        num::{NonZeroU8, NonZeroUsize},
        pin::Pin,
        task::{Context, Poll},
        time::Duration,
    };

    use futures::{stream, Stream};
    use pin_project::pin_project;
    use proptest::prelude::*;
    use tokio::{pin, time::advance};
    use vector_core::{partition::Partitioner, time::KeyedTimer};

    use crate::{
        partitioned_batcher::{ExpirationQueue, PartitionedBatcher},
        BatcherSettings,
    };

    /// A scripted `KeyedTimer` for property tests.
    ///
    /// `poll_expired` pops the next canned response off the end of `responses`. A key is
    /// only reported as expired if it is currently registered; otherwise the timer reports
    /// `Ready(None)`.
    #[derive(Debug)]
    struct TestTimer {
        responses: Vec<Poll<Option<u8>>>,
        valid_keys: HashSet<u8>,
    }

    impl TestTimer {
        fn new(responses: Vec<Poll<Option<u8>>>) -> Self {
            Self {
                responses,
                valid_keys: HashSet::new(),
            }
        }
    }

    impl KeyedTimer<u8> for TestTimer {
        fn clear(&mut self) {
            self.valid_keys.clear();
        }

        fn insert(&mut self, item_key: u8) {
            self.valid_keys.insert(item_key);
        }

        fn remove(&mut self, item_key: &u8) {
            self.valid_keys.remove(item_key);
        }

        fn poll_expired(&mut self, _cx: &mut Context) -> Poll<Option<u8>> {
            match self.responses.pop() {
                // `arb_timer` never scripts a `Pending` response.
                Some(Poll::Pending) => unreachable!(),
                None | Some(Poll::Ready(None)) => Poll::Ready(None),
                Some(Poll::Ready(Some(k))) => {
                    if self.valid_keys.contains(&k) {
                        Poll::Ready(Some(k))
                    } else {
                        Poll::Ready(None)
                    }
                }
            }
        }
    }

    /// Generates a `TestTimer` whose script randomly mixes "key `k` expired" and
    /// "nothing due" responses.
    fn arb_timer() -> impl Strategy<Value = TestTimer> {
        Vec::<(bool, u8)>::arbitrary()
            .prop_map(|v| {
                v.into_iter()
                    .map(|(b, k)| {
                        if b {
                            Poll::Ready(Some(k))
                        } else {
                            Poll::Ready(None)
                        }
                    })
                    .collect()
            })
            .prop_map(TestTimer::new)
    }

    /// Partitions `u64` items by `item % key_space`.
    #[pin_project]
    #[derive(Debug)]
    struct TestPartitioner {
        key_space: NonZeroU8,
    }

    impl Partitioner for TestPartitioner {
        type Item = u64;
        type Key = u8;

        // The remainder is below `key_space`, which is a `u8`, so the cast cannot truncate.
        #[allow(clippy::cast_possible_truncation)]
        fn partition(&self, item: &Self::Item) -> Self::Key {
            let key = *item % u64::from(self.key_space.get());
            key as Self::Key
        }
    }

    /// Generates a `TestPartitioner` with a key space in `1..u8::MAX`.
    fn arb_partitioner() -> impl Strategy<Value = TestPartitioner> {
        (1..u8::MAX,).prop_map(|(ks,)| TestPartitioner {
            key_space: NonZeroU8::new(ks).unwrap(),
        })
    }

    proptest! {
        #[test]
        fn size_hint_eq(stream: Vec<u64>,
                        item_limit in 1..u16::MAX,
                        allocation_limit in 8..128,
                        partitioner in arb_partitioner(),
                        timer in arb_timer()) {
            let mut stream = stream::iter(stream.into_iter());
            let stream_size_hint = stream.size_hint();
            let item_limit = NonZeroUsize::new(item_limit as usize).unwrap();
            let allocation_limit = NonZeroUsize::new(allocation_limit as usize).unwrap();
            let batch_settings = BatcherSettings::new(Duration::from_secs(1), allocation_limit, item_limit);

            let batcher = PartitionedBatcher::with_timer(&mut stream, partitioner, timer,
                                                         Box::new(move || batch_settings.as_byte_size_config()));
            let batcher_size_hint = batcher.size_hint();

            assert_eq!(stream_size_hint, batcher_size_hint);
        }
    }

    proptest! {
        #[test]
        fn batch_item_size_leq_limit(stream: Vec<u64>,
                                     item_limit in 1..u16::MAX,
                                     allocation_limit in 8..128,
                                     partitioner in arb_partitioner(),
                                     timer in arb_timer()) {
            let noop_waker = futures::task::noop_waker();
            let mut cx = Context::from_waker(&noop_waker);

            let mut stream = stream::iter(stream.into_iter());
            let item_limit = NonZeroUsize::new(item_limit as usize).unwrap();
            let allocation_limit = NonZeroUsize::new(allocation_limit as usize).unwrap();
            let batch_settings = BatcherSettings::new(Duration::from_secs(1), allocation_limit, item_limit);
            let mut batcher = PartitionedBatcher::with_timer(&mut stream, partitioner,
                                                             timer, Box::new(move || batch_settings.as_byte_size_config()));
            let mut batcher = Pin::new(&mut batcher);

            loop {
                match batcher.as_mut().poll_next(&mut cx) {
                    Poll::Pending => {}
                    Poll::Ready(None) => {
                        break;
                    }
                    Poll::Ready(Some((_, batch))) => {
                        // `prop_assert!` (unlike `debug_assert!`) is also checked in release
                        // test builds, and it lets proptest shrink the failing case.
                        prop_assert!(
                            batch.len() <= item_limit.get(),
                            "batch of {} items exceeds the item limit of {}",
                            batch.len(),
                            item_limit.get()
                        );
                    }
                }
            }
        }
    }

    /// Groups `stream` by partition key. Each group is reversed so that `Vec::pop` returns
    /// that partition's items in their original stream order.
    fn separate_partitions(
        stream: Vec<u64>,
        partitioner: &TestPartitioner,
    ) -> HashMap<u8, Vec<u64>> {
        let mut map = stream
            .into_iter()
            .map(|item| {
                let key = partitioner.partition(&item);
                (key, item)
            })
            .fold(
                HashMap::default(),
                |mut acc: HashMap<u8, Vec<u64>>, (key, item)| {
                    let arr: &mut Vec<u64> = acc.entry(key).or_default();
                    arr.push(item);
                    acc
                },
            );
        for part in map.values_mut() {
            part.reverse();
        }
        map
    }

    proptest! {
        #[test]
        fn batch_does_not_reorder(stream: Vec<u64>,
                                  item_limit in 1..u16::MAX,
                                  allocation_limit in 8..128,
                                  partitioner in arb_partitioner(),
                                  timer in arb_timer()) {
            let noop_waker = futures::task::noop_waker();
            let mut cx = Context::from_waker(&noop_waker);

            let mut partitions = separate_partitions(stream.clone(), &partitioner);

            let mut stream = stream::iter(stream.into_iter());
            let item_limit = NonZeroUsize::new(item_limit as usize).unwrap();
            let allocation_limit = NonZeroUsize::new(allocation_limit as usize).unwrap();
            let batch_settings = BatcherSettings::new(Duration::from_secs(1), allocation_limit, item_limit);
            let mut batcher = PartitionedBatcher::with_timer(&mut stream, partitioner,
                                                             timer, Box::new(move || batch_settings.as_byte_size_config()));
            let mut batcher = Pin::new(&mut batcher);

            loop {
                match batcher.as_mut().poll_next(&mut cx) {
                    Poll::Pending => {}
                    Poll::Ready(None) => {
                        break;
                    }
                    Poll::Ready(Some((key, actual_batch))) => {
                        let expected_partition = partitions
                            .get_mut(&key)
                            .expect("impossible situation");

                        for item in actual_batch {
                            assert_eq!(item, expected_partition.pop().unwrap());
                        }
                    }
                }
            }
            for v in partitions.values() {
                assert!(v.is_empty());
            }
        }
    }

    proptest! {
        #[test]
        fn batch_does_not_lose_items(stream: Vec<u64>,
                                     item_limit in 1..u16::MAX,
                                     allocation_limit in 8..128,
                                     partitioner in arb_partitioner(),
                                     timer in arb_timer()) {
            let noop_waker = futures::task::noop_waker();
            let mut cx = Context::from_waker(&noop_waker);

            let total_items = stream.len();
            let mut stream = stream::iter(stream.into_iter());
            let item_limit = NonZeroUsize::new(item_limit as usize).unwrap();
            let allocation_limit = NonZeroUsize::new(allocation_limit as usize).unwrap();
            let batch_settings = BatcherSettings::new(Duration::from_secs(1), allocation_limit, item_limit);
            let mut batcher = PartitionedBatcher::with_timer(&mut stream, partitioner,
                                                             timer, Box::new(move || batch_settings.as_byte_size_config()));
            let mut batcher = Pin::new(&mut batcher);

            let mut observed_items = 0;
            loop {
                match batcher.as_mut().poll_next(&mut cx) {
                    Poll::Pending => {}
                    Poll::Ready(None) => {
                        assert_eq!(observed_items, total_items);
                        break;
                    }
                    Poll::Ready(Some((_, batch))) => {
                        observed_items += batch.len();
                        assert!(observed_items <= total_items);
                    }
                }
            }
        }
    }

    #[tokio::test(start_paused = true)]
    #[allow(clippy::semicolon_if_nothing_returned)]
    async fn expiration_queue_impl_keyed_timer() {
        let timeout = Duration::from_millis(100);
        let mut expiration_queue: ExpirationQueue<u8> = ExpirationQueue::new(timeout);

        // An empty queue has nothing to expire.
        assert_eq!(0, expiration_queue.len());
        let result = single_poll(|cx| expiration_queue.poll_expired(cx));
        assert_eq!(result, Poll::Ready(None));

        // A freshly inserted key is pending until the timeout elapses.
        expiration_queue.insert(128);
        assert_eq!(1, expiration_queue.len());

        let result = single_poll(|cx| expiration_queue.poll_expired(cx));
        assert_eq!(result, Poll::Pending);

        advance(timeout + Duration::from_nanos(1)).await;
        let result = single_poll(|cx| expiration_queue.poll_expired(cx));
        assert_eq!(result, Poll::Ready(Some(128)));
        let result = single_poll(|cx| expiration_queue.poll_expired(cx));
        assert_eq!(result, Poll::Ready(None));

        assert_eq!(0, expiration_queue.len());
        let result = single_poll(|cx| expiration_queue.poll_expired(cx));
        assert_eq!(result, Poll::Ready(None));

        // `clear` drops every pending key.
        expiration_queue.insert(128);
        expiration_queue.insert(64);
        expiration_queue.insert(32);
        assert_eq!(3, expiration_queue.len());
        expiration_queue.clear();
        assert_eq!(0, expiration_queue.len());
        let result = single_poll(|cx| expiration_queue.poll_expired(cx));
        assert_eq!(result, Poll::Ready(None));
    }

    /// Polls `f` exactly once with a no-op waker.
    fn single_poll<T, F>(mut f: F) -> Poll<T>
    where
        F: FnMut(&mut Context<'_>) -> Poll<T>,
    {
        let noop_waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&noop_waker);

        f(&mut cx)
    }
}