1use std::{
2 collections::HashMap,
3 hash::{BuildHasherDefault, Hash},
4 num::NonZeroUsize,
5 pin::Pin,
6 task::{Context, Poll, ready},
7 time::Duration,
8};
9
10use futures::stream::{Fuse, Stream, StreamExt};
11use pin_project::pin_project;
12use tokio_util::time::{DelayQueue, delay_queue::Key};
13use twox_hash::XxHash64;
14use vector_common::byte_size_of::ByteSizeOf;
15use vector_core::{partition::Partitioner, time::KeyedTimer};
16
17use crate::batcher::{
18 BatchConfig,
19 config::BatchConfigParts,
20 data::BatchData,
21 limiter::{ByteSizeOfItemSize, ItemBatchSize, SizeLimit},
22};
23
24pub struct ExpirationQueue<K> {
26 timeout: Duration,
28 expirations: DelayQueue<K>,
30 expiration_map: HashMap<K, Key>,
32}
33
34impl<K> ExpirationQueue<K> {
35 pub fn new(timeout: Duration) -> Self {
39 Self {
40 timeout,
41 expirations: DelayQueue::new(),
42 expiration_map: HashMap::default(),
43 }
44 }
45
46 pub fn len(&self) -> usize {
51 self.expirations.len()
52 }
53
54 pub fn is_empty(&self) -> bool {
56 self.len() == 0
57 }
58}
59
60impl<K> KeyedTimer<K> for ExpirationQueue<K>
61where
62 K: Eq + Hash + Clone,
63{
64 fn clear(&mut self) {
65 self.expirations.clear();
66 self.expiration_map.clear();
67 }
68
69 fn insert(&mut self, item_key: K) {
70 if let Some(expiration_key) = self.expiration_map.get(&item_key) {
71 self.expirations.reset(expiration_key, self.timeout);
74 } else {
75 let expiration_key = self.expirations.insert(item_key.clone(), self.timeout);
78 assert!(
79 self.expiration_map
80 .insert(item_key, expiration_key)
81 .is_none()
82 );
83 }
84 }
85
86 fn remove(&mut self, item_key: &K) {
87 if let Some(expiration_key) = self.expiration_map.remove(item_key) {
88 self.expirations.remove(&expiration_key);
89 }
90 }
91
92 fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<K>> {
93 match ready!(self.expirations.poll_expired(cx)) {
94 None => Poll::Ready(None),
96 Some(expiration) => {
97 assert!(self.expiration_map.remove(expiration.get_ref()).is_some());
100 Poll::Ready(Some(expiration.into_inner()))
101 }
102 }
103 }
104}
105
106#[derive(Copy, Clone, Debug)]
114pub struct BatcherSettings {
115 pub timeout: Duration,
116 pub size_limit: usize,
117 pub item_limit: usize,
118}
119
120impl BatcherSettings {
121 pub const fn new(
122 timeout: Duration,
123 size_limit: NonZeroUsize,
124 item_limit: NonZeroUsize,
125 ) -> Self {
126 BatcherSettings {
127 timeout,
128 size_limit: size_limit.get(),
129 item_limit: item_limit.get(),
130 }
131 }
132
133 pub fn as_byte_size_config<T: ByteSizeOf>(
136 &self,
137 ) -> BatchConfigParts<SizeLimit<ByteSizeOfItemSize>, Vec<T>> {
138 self.as_item_size_config(ByteSizeOfItemSize)
139 }
140
141 pub fn as_item_size_config<T, I>(&self, item_size: I) -> BatchConfigParts<SizeLimit<I>, Vec<T>>
144 where
145 I: ItemBatchSize<T>,
146 {
147 BatchConfigParts {
148 batch_limiter: SizeLimit {
149 batch_size_limit: self.size_limit,
150 batch_item_limit: self.item_limit,
151 current_size: 0,
152 item_size_calculator: item_size,
153 },
154 batch_data: vec![],
155 timeout: self.timeout,
156 }
157 }
158
159 pub fn as_reducer_config<I, T, B>(
162 &self,
163 item_size: I,
164 reducer: B,
165 ) -> BatchConfigParts<SizeLimit<I>, B>
166 where
167 I: ItemBatchSize<T>,
168 B: BatchData<T>,
169 {
170 BatchConfigParts {
171 batch_limiter: SizeLimit {
172 batch_size_limit: self.size_limit,
173 batch_item_limit: self.item_limit,
174 current_size: 0,
175 item_size_calculator: item_size,
176 },
177 batch_data: reducer,
178 timeout: self.timeout,
179 }
180 }
181}
182
183#[pin_project]
184pub struct PartitionedBatcher<St, Prt, KT, C, F, B>
185where
186 Prt: Partitioner,
187{
188 state: F,
191 batches: HashMap<Prt::Key, C, BuildHasherDefault<XxHash64>>,
194 closed_batches: Vec<(Prt::Key, B)>,
198 timer: KT,
200 partitioner: Prt,
202 #[pin]
203 stream: Fuse<St>,
205}
206
207impl<St, Prt, C, F, B> PartitionedBatcher<St, Prt, ExpirationQueue<Prt::Key>, C, F, B>
208where
209 St: Stream<Item = Prt::Item>,
210 Prt: Partitioner + Unpin,
211 Prt::Key: Eq + Hash + Clone,
212 Prt::Item: ByteSizeOf,
213 C: BatchConfig<Prt::Item>,
214 F: Fn() -> C + Send,
215{
216 pub fn new(stream: St, partitioner: Prt, settings: F) -> Self {
217 let timeout = settings().timeout();
218 Self {
219 state: settings,
220 batches: HashMap::default(),
221 closed_batches: Vec::default(),
222 timer: ExpirationQueue::new(timeout),
223 partitioner,
224 stream: stream.fuse(),
225 }
226 }
227}
228
#[cfg(test)]
impl<St, Prt, KT, C, F, B> PartitionedBatcher<St, Prt, KT, C, F, B>
where
    St: Stream<Item = Prt::Item>,
    Prt: Partitioner + Unpin,
    Prt::Key: Eq + Hash + Clone,
    Prt::Item: ByteSizeOf,
    C: BatchConfig<Prt::Item>,
    F: Fn() -> C + Send,
{
    /// Test-only constructor that takes the keyed timer from the caller
    /// instead of building an `ExpirationQueue`.
    pub fn with_timer(stream: St, partitioner: Prt, timer: KT, settings: F) -> Self {
        let stream = stream.fuse();
        Self {
            state: settings,
            batches: HashMap::default(),
            closed_batches: Vec::new(),
            timer,
            partitioner,
            stream,
        }
    }
}
250
251impl<St, Prt, KT, C, F, B> Stream for PartitionedBatcher<St, Prt, KT, C, F, B>
252where
253 St: Stream<Item = Prt::Item>,
254 Prt: Partitioner + Unpin,
255 Prt::Key: Eq + Hash + Clone,
256 Prt::Item: ByteSizeOf,
257 KT: KeyedTimer<Prt::Key>,
258 C: BatchConfig<Prt::Item, Batch = B>,
259 F: Fn() -> C + Send,
260{
261 type Item = (Prt::Key, B);
262
263 fn size_hint(&self) -> (usize, Option<usize>) {
264 self.stream.size_hint()
265 }
266
267 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
268 let mut this = self.project();
269 loop {
270 if !this.closed_batches.is_empty() {
271 return Poll::Ready(this.closed_batches.pop());
272 }
273 match this.stream.as_mut().poll_next(cx) {
274 Poll::Pending => match this.timer.poll_expired(cx) {
275 Poll::Pending | Poll::Ready(None) => return Poll::Pending,
278 Poll::Ready(Some(item_key)) => {
279 let mut batch = this
280 .batches
281 .remove(&item_key)
282 .expect("batch should exist if it is set to expire");
283 this.closed_batches.push((item_key, batch.take_batch()));
284 }
285 },
286 Poll::Ready(None) => {
287 if !this.batches.is_empty() {
293 this.timer.clear();
294 this.closed_batches.extend(
295 this.batches
296 .drain()
297 .map(|(key, mut batch)| (key, batch.take_batch())),
298 );
299 continue;
300 }
301 return Poll::Ready(None);
302 }
303 Poll::Ready(Some(item)) => {
304 let item_key = this.partitioner.partition(&item);
305
306 let batch = if let Some(batch) = this.batches.get_mut(&item_key) {
308 batch
309 } else {
310 let batch = (this.state)();
311 this.batches.insert(item_key.clone(), batch);
312 this.timer.insert(item_key.clone());
313 this.batches
314 .get_mut(&item_key)
315 .expect("batch has just been inserted so should exist")
316 };
317
318 let (fits, metadata) = batch.item_fits_in_batch(&item);
319 if !fits {
320 this.closed_batches
324 .push((item_key.clone(), batch.take_batch()));
325
326 this.timer.insert(item_key.clone());
330 }
331
332 batch.push(item, metadata);
334 if batch.is_batch_full() {
335 this.closed_batches
338 .push((item_key.clone(), batch.take_batch()));
339 this.batches.remove(&item_key);
340 this.timer.remove(&item_key);
341 }
342 }
343 }
344 }
345 }
346}
347
348#[allow(clippy::cast_sign_loss)]
349#[cfg(test)]
350mod test {
351 use std::{
352 collections::{HashMap, HashSet},
353 num::{NonZeroU8, NonZeroUsize},
354 pin::Pin,
355 task::{Context, Poll},
356 time::Duration,
357 };
358
359 use futures::{Stream, stream};
360 use pin_project::pin_project;
361 use proptest::prelude::*;
362 use tokio::{pin, time::advance};
363 use vector_core::{partition::Partitioner, time::KeyedTimer};
364
365 use crate::{
366 BatcherSettings,
367 partitioned_batcher::{ExpirationQueue, PartitionedBatcher},
368 };
369
/// A scripted `KeyedTimer` stand-in for tests.
#[derive(Debug)]
struct TestTimer {
    /// Pre-recorded poll results; `poll_expired` consumes them from the
    /// back.
    responses: Vec<Poll<Option<u8>>>,
    /// Keys that have been inserted and not yet removed or cleared.
    valid_keys: HashSet<u8>,
}

impl TestTimer {
    /// Creates a timer that replays `responses` and tracks no keys yet.
    fn new(responses: Vec<Poll<Option<u8>>>) -> Self {
        let valid_keys = HashSet::default();
        Self {
            responses,
            valid_keys,
        }
    }
}
390
391 impl KeyedTimer<u8> for TestTimer {
392 fn clear(&mut self) {
393 self.valid_keys.clear();
394 }
395
396 fn insert(&mut self, item_key: u8) {
397 self.valid_keys.insert(item_key);
398 }
399
400 fn remove(&mut self, item_key: &u8) {
401 self.valid_keys.remove(item_key);
402 }
403
404 fn poll_expired(&mut self, _cx: &mut Context) -> Poll<Option<u8>> {
405 match self.responses.pop() {
406 Some(Poll::Pending) => unreachable!(),
407 None | Some(Poll::Ready(None)) => Poll::Ready(None),
408 Some(Poll::Ready(Some(k))) => {
409 if self.valid_keys.contains(&k) {
410 Poll::Ready(Some(k))
411 } else {
412 Poll::Ready(None)
413 }
414 }
415 }
416 }
417 }
418
419 fn arb_timer() -> impl Strategy<Value = TestTimer> {
420 Vec::<(bool, u8)>::arbitrary()
422 .prop_map(|v| {
423 v.into_iter()
424 .map(|(b, k)| {
425 if b {
426 Poll::Ready(Some(k))
427 } else {
428 Poll::Ready(None)
429 }
430 })
431 .collect()
432 })
433 .prop_map(TestTimer::new)
434 }
435
436 #[pin_project]
441 #[derive(Debug)]
442 struct TestPartitioner {
443 key_space: NonZeroU8,
444 }
445
446 impl Partitioner for TestPartitioner {
447 type Item = u64;
448 type Key = u8;
449
450 #[allow(clippy::cast_possible_truncation)]
451 fn partition(&self, item: &Self::Item) -> Self::Key {
452 let key = *item % u64::from(self.key_space.get());
453 key as Self::Key
454 }
455 }
456
457 fn arb_partitioner() -> impl Strategy<Value = TestPartitioner> {
458 (1..u8::MAX,).prop_map(|(ks,)| TestPartitioner {
459 key_space: NonZeroU8::new(ks).unwrap(),
460 })
461 }
462
463 proptest! {
464 #[test]
465 fn size_hint_eq(stream: Vec<u64>,
466 item_limit in 1..u16::MAX,
467 allocation_limit in 8..128,
468 partitioner in arb_partitioner(),
469 timer in arb_timer()) {
470 let mut stream = stream::iter(stream.into_iter());
476 let stream_size_hint = stream.size_hint();
477 let item_limit = NonZeroUsize::new(item_limit as usize).unwrap();
478 let allocation_limit = NonZeroUsize::new(allocation_limit as usize).unwrap();
479 let batch_settings = BatcherSettings::new(Duration::from_secs(1), allocation_limit, item_limit);
480
481 let batcher = PartitionedBatcher::with_timer(&mut stream, partitioner, timer,
482 Box::new(move || batch_settings.as_byte_size_config()));
483 let batcher_size_hint = batcher.size_hint();
484
485 assert_eq!(stream_size_hint, batcher_size_hint);
486 }
487 }
488
489 proptest! {
490 #[test]
491 fn batch_item_size_leq_limit(stream: Vec<u64>,
492 item_limit in 1..u16::MAX,
493 allocation_limit in 8..128,
494 partitioner in arb_partitioner(),
495 timer in arb_timer()) {
496 let noop_waker = futures::task::noop_waker();
499 let mut cx = Context::from_waker(&noop_waker);
500
501 let mut stream = stream::iter(stream.into_iter());
502 let item_limit = NonZeroUsize::new(item_limit as usize).unwrap();
503 let allocation_limit = NonZeroUsize::new(allocation_limit as usize).unwrap();
504 let batch_settings = BatcherSettings::new(Duration::from_secs(1), allocation_limit, item_limit);
505 let mut batcher = PartitionedBatcher::with_timer(&mut stream, partitioner,
506 timer, Box::new(move || batch_settings.as_byte_size_config()));
507 let mut batcher = Pin::new(&mut batcher);
508
509 loop {
510 match batcher.as_mut().poll_next(&mut cx) {
511 Poll::Pending => {}
512 Poll::Ready(None) => {
513 break;
514 }
515 Poll::Ready(Some((_, batch))) => {
516 debug_assert!(
517 batch.len() <= item_limit.get(),
518 "{} < {}",
519 batch.len(),
520 item_limit.get()
521 );
522 }
523 }
524 }
525 }
526 }
527
528 fn separate_partitions(
534 stream: Vec<u64>,
535 partitioner: &TestPartitioner,
536 ) -> HashMap<u8, Vec<u64>> {
537 let mut map = stream
538 .into_iter()
539 .map(|item| {
540 let key = partitioner.partition(&item);
541 (key, item)
542 })
543 .fold(
544 HashMap::default(),
545 |mut acc: HashMap<u8, Vec<u64>>, (key, item)| {
546 let arr: &mut Vec<u64> = acc.entry(key).or_default();
547 arr.push(item);
548 acc
549 },
550 );
551 for part in map.values_mut() {
552 part.reverse();
553 }
554 map
555 }
556
557 proptest! {
558 #[test]
559 fn batch_does_not_reorder(stream: Vec<u64>,
560 item_limit in 1..u16::MAX,
561 allocation_limit in 8..128,
562 partitioner in arb_partitioner(),
563 timer in arb_timer()) {
564 let noop_waker = futures::task::noop_waker();
568 let mut cx = Context::from_waker(&noop_waker);
569
570 let mut partitions = separate_partitions(stream.clone(), &partitioner);
571
572 let mut stream = stream::iter(stream.into_iter());
573 let item_limit = NonZeroUsize::new(item_limit as usize).unwrap();
574 let allocation_limit = NonZeroUsize::new(allocation_limit as usize).unwrap();
575 let batch_settings = BatcherSettings::new(Duration::from_secs(1), allocation_limit, item_limit);
576 let mut batcher = PartitionedBatcher::with_timer(&mut stream, partitioner,
577 timer, Box::new(move || batch_settings.as_byte_size_config()));
578 let mut batcher = Pin::new(&mut batcher);
579
580 loop {
581 match batcher.as_mut().poll_next(&mut cx) {
582 Poll::Pending => {}
583 Poll::Ready(None) => {
584 break;
585 }
586 Poll::Ready(Some((key, actual_batch))) => {
587 let expected_partition = partitions
588 .get_mut(&key)
589 .expect("impossible situation");
590
591 for item in actual_batch {
592 assert_eq!(item, expected_partition.pop().unwrap());
593 }
594 }
595 }
596 }
597 for v in partitions.values() {
598 assert!(v.is_empty());
599 }
600 }
601 }
602
603 proptest! {
604 #[test]
605 fn batch_does_not_lose_items(stream: Vec<u64>,
606 item_limit in 1..u16::MAX,
607 allocation_limit in 8..128,
608 partitioner in arb_partitioner(),
609 timer in arb_timer()) {
610 let noop_waker = futures::task::noop_waker();
613 let mut cx = Context::from_waker(&noop_waker);
614
615 let total_items = stream.len();
616 let mut stream = stream::iter(stream.into_iter());
617 let item_limit = NonZeroUsize::new(item_limit as usize).unwrap();
618 let allocation_limit = NonZeroUsize::new(allocation_limit as usize).unwrap();
619 let batch_settings = BatcherSettings::new(Duration::from_secs(1), allocation_limit, item_limit);
620 let mut batcher = PartitionedBatcher::with_timer(&mut stream, partitioner,
621 timer, Box::new(move || batch_settings.as_byte_size_config()));
622 let mut batcher = Pin::new(&mut batcher);
623
624 let mut observed_items = 0;
625 loop {
626 match batcher.as_mut().poll_next(&mut cx) {
627 Poll::Pending => {}
628 Poll::Ready(None) => {
629 assert_eq!(observed_items, total_items);
632 break;
633 }
634 Poll::Ready(Some((_, batch))) => {
635 observed_items += batch.len();
636 assert!(observed_items <= total_items);
637 }
638 }
639 }
640 }
641 }
642
643 #[tokio::test(start_paused = true)]
644 #[allow(clippy::semicolon_if_nothing_returned)] async fn expiration_queue_impl_keyed_timer() {
646 let timeout = Duration::from_millis(100); let mut expiration_queue: ExpirationQueue<u8> = ExpirationQueue::new(timeout);
652
653 assert_eq!(0, expiration_queue.len());
656 let result = single_poll(|cx| expiration_queue.poll_expired(cx));
657 assert_eq!(result, Poll::Ready(None));
658
659 expiration_queue.insert(128);
663 assert_eq!(1, expiration_queue.len());
664
665 let result = single_poll(|cx| expiration_queue.poll_expired(cx));
666 assert_eq!(result, Poll::Pending);
667
668 advance(timeout + Duration::from_nanos(1)).await;
669 let result = single_poll(|cx| expiration_queue.poll_expired(cx));
670 assert_eq!(result, Poll::Ready(Some(128)));
671 let result = single_poll(|cx| expiration_queue.poll_expired(cx));
672 assert_eq!(result, Poll::Ready(None));
673
674 assert_eq!(0, expiration_queue.len());
676 let result = single_poll(|cx| expiration_queue.poll_expired(cx));
677 assert_eq!(result, Poll::Ready(None));
678
679 expiration_queue.insert(128);
682 expiration_queue.insert(64);
683 expiration_queue.insert(32);
684 assert_eq!(3, expiration_queue.len());
685 expiration_queue.clear();
686 assert_eq!(0, expiration_queue.len());
687 let result = single_poll(|cx| expiration_queue.poll_expired(cx));
688 assert_eq!(result, Poll::Ready(None));
689 }
690
691 fn single_poll<T, F>(mut f: F) -> Poll<T>
692 where
693 F: FnMut(&mut Context<'_>) -> Poll<T>,
694 {
695 let noop_waker = futures::task::noop_waker();
696 let mut cx = Context::from_waker(&noop_waker);
697
698 f(&mut cx)
699 }
700}