vector/test_util/mod.rs

#![allow(missing_docs)]

use std::{
    collections::HashMap,
    convert::Infallible,
    fs::File,
    future::{Future, ready},
    io::Read,
    iter,
    net::SocketAddr,
    path::{Path, PathBuf},
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    task::{Context, Poll, ready},
};

use chrono::{DateTime, SubsecRound, Utc};
use flate2::read::MultiGzDecoder;
use futures::{FutureExt, SinkExt, Stream, StreamExt, TryStreamExt, stream, task::noop_waker_ref};
use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
use rand::{Rng, rng};
use rand_distr::Alphanumeric;
use tokio::{
    io::{AsyncRead, AsyncWrite, AsyncWriteExt, Result as IoResult},
    net::{TcpListener, TcpStream, ToSocketAddrs},
    runtime,
    sync::oneshot,
    task::JoinHandle,
    time::{Duration, Instant, sleep},
};
use tokio_stream::wrappers::TcpListenerStream;
#[cfg(unix)]
use tokio_stream::wrappers::UnixListenerStream;
use tokio_util::codec::{Encoder, FramedRead, FramedWrite, LinesCodec};
use vector_lib::{
    buffers::topology::channel::LimitedReceiver,
    event::{
        BatchNotifier, BatchStatusReceiver, Event, EventArray, LogEvent, Metric, MetricKind,
        MetricTags, MetricValue,
    },
};
#[cfg(test)]
use zstd::Decoder as ZstdDecoder;

use crate::{
    config::{Config, GenerateConfig},
    topology::{RunningTopology, ShutdownErrorReceiver},
    trace,
};

const WAIT_FOR_SECS: u64 = 5; // The default time to wait in `wait_for`
const WAIT_FOR_MIN_MILLIS: u64 = 5; // The minimum time to pause before retrying
const WAIT_FOR_MAX_MILLIS: u64 = 500; // The maximum time to pause before retrying

pub mod addr;
pub mod compression;
pub mod stats;

#[cfg(any(test, feature = "test-utils"))]
pub mod components;
#[cfg(test)]
pub mod http;
#[cfg(test)]
pub mod integration;
#[cfg(test)]
pub mod metrics;
#[cfg(test)]
pub mod mock;

#[macro_export]
macro_rules! assert_downcast_matches {
    ($e:expr_2021, $t:ty, $v:pat) => {{
        match $e.downcast_ref::<$t>() {
            Some($v) => (),
            got => panic!("Assertion failed: got wrong error variant {:?}", got),
        }
    }};
}

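/// Builds an `Event::Log` from `key => value` pairs.
///
/// Illustrative usage (a sketch, not exercised as a doctest):
/// ```ignore
/// let event = log_event!("message" => "hello world", "code" => 200);
/// ```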
#[macro_export]
macro_rules! log_event {
    ($($key:expr_2021 => $value:expr_2021),*  $(,)?) => {
        #[allow(unused_variables)]
        {
            let mut event = $crate::event::Event::Log($crate::event::LogEvent::default());
            let log = event.as_mut_log();
            $(
                log.insert($key, $value);
            )*
            event
        }
    };
}

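/// Asserts that a component's `GenerateConfig` example round-trips through TOML.
///
/// Illustrative usage (a sketch; `MySourceConfig` is a hypothetical config type):
/// ```ignore
/// test_generate_config::<MySourceConfig>();
/// ```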
pub fn test_generate_config<T>()
where
    for<'de> T: GenerateConfig + serde::Deserialize<'de>,
{
    let cfg = toml::to_string(&T::generate_config()).unwrap();

    toml::from_str::<T>(&cfg)
        .unwrap_or_else(|e| panic!("Invalid config generated from string:\n\n{e}\n'{cfg}'"));
}

pub fn open_fixture(path: impl AsRef<Path>) -> crate::Result<serde_json::Value> {
    let test_file = match File::open(path) {
        Ok(file) => file,
        Err(e) => return Err(e.into()),
    };
    let value: serde_json::Value = serde_json::from_reader(test_file)?;
    Ok(value)
}

pub fn trace_init() {
    #[cfg(unix)]
    let color = {
        use std::io::IsTerminal;
        std::io::stdout().is_terminal()
            || std::env::var("NEXTEST")
                .ok()
                .and(Some(true))
                .unwrap_or(false)
    };
    // Windows: ANSI colors are not supported by cmd.exe
    // Color is false for everything except unix.
    #[cfg(not(unix))]
    let color = false;

    let levels = std::env::var("VECTOR_LOG").unwrap_or_else(|_| "error".to_string());

    trace::init(color, false, &levels, 10);

    // Initialize metrics as well
    vector_lib::metrics::init_test();
}

pub async fn send_lines(
    addr: SocketAddr,
    lines: impl IntoIterator<Item = String>,
) -> Result<SocketAddr, Infallible> {
    send_encodable(addr, LinesCodec::new(), lines).await
}

pub async fn send_encodable<I, E: From<std::io::Error> + std::fmt::Debug>(
    addr: SocketAddr,
    encoder: impl Encoder<I, Error = E>,
    lines: impl IntoIterator<Item = I>,
) -> Result<SocketAddr, Infallible> {
    let stream = TcpStream::connect(&addr).await.unwrap();

    let local_addr = stream.local_addr().unwrap();

    let mut sink = FramedWrite::new(stream, encoder);

    let mut lines = stream::iter(lines.into_iter()).map(Ok);
    sink.send_all(&mut lines).await.unwrap();

    let stream = sink.get_mut();
    stream.shutdown().await.unwrap();

    Ok(local_addr)
}

pub async fn send_lines_tls(
    addr: SocketAddr,
    host: String,
    lines: impl Iterator<Item = String>,
    ca: impl Into<Option<&Path>>,
    client_cert: impl Into<Option<&Path>>,
    client_key: impl Into<Option<&Path>>,
) -> Result<SocketAddr, Infallible> {
    let stream = TcpStream::connect(&addr).await.unwrap();

    let local_addr = stream.local_addr().unwrap();

    let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
    if let Some(ca) = ca.into() {
        connector.set_ca_file(ca).unwrap();
    } else {
        connector.set_verify(SslVerifyMode::NONE);
    }

    if let Some(cert_file) = client_cert.into() {
        connector.set_certificate_chain_file(cert_file).unwrap();
    }

    if let Some(key_file) = client_key.into() {
        connector
            .set_private_key_file(key_file, SslFiletype::PEM)
            .unwrap();
    }

    let ssl = connector
        .build()
        .configure()
        .unwrap()
        .into_ssl(&host)
        .unwrap();

    let mut stream = tokio_openssl::SslStream::new(ssl, stream).unwrap();
    Pin::new(&mut stream).connect().await.unwrap();
    let mut sink = FramedWrite::new(stream, LinesCodec::new());

    let mut lines = stream::iter(lines).map(Ok);
    sink.send_all(&mut lines).await.unwrap();

    let stream = sink.get_mut().get_mut();
    stream.shutdown().await.unwrap();

    Ok(local_addr)
}

pub fn temp_file() -> PathBuf {
    let path = std::env::temp_dir();
    let file_name = random_string(16);
    path.join(file_name + ".log")
}

pub fn temp_dir() -> PathBuf {
    let path = std::env::temp_dir();
    let dir_name = random_string(16);
    path.join(dir_name)
}

pub fn random_table_name() -> String {
    format!("test_{}", random_string(10).to_lowercase())
}

pub fn map_event_batch_stream(
    stream: impl Stream<Item = Event>,
    batch: Option<BatchNotifier>,
) -> impl Stream<Item = EventArray> {
    stream.map(move |event| event.with_batch_notifier_option(&batch).into())
}

// TODO refactor to have a single implementation for `Event`, `LogEvent` and `Metric`.
fn map_batch_stream(
    stream: impl Stream<Item = LogEvent>,
    batch: Option<BatchNotifier>,
) -> impl Stream<Item = EventArray> {
    stream.map(move |log| vec![log.with_batch_notifier_option(&batch)].into())
}

pub fn generate_lines_with_stream<Gen: FnMut(usize) -> String>(
    generator: Gen,
    count: usize,
    batch: Option<BatchNotifier>,
) -> (Vec<String>, impl Stream<Item = EventArray>) {
    let lines = (0..count).map(generator).collect::<Vec<_>>();
    let stream = map_batch_stream(
        stream::iter(lines.clone()).map(LogEvent::from_str_legacy),
        batch,
    );
    (lines, stream)
}

pub fn random_lines_with_stream(
    len: usize,
    count: usize,
    batch: Option<BatchNotifier>,
) -> (Vec<String>, impl Stream<Item = EventArray>) {
    let generator = move |_| random_string(len);
    generate_lines_with_stream(generator, count, batch)
}

pub fn generate_events_with_stream<Gen: FnMut(usize) -> Event>(
    generator: Gen,
    count: usize,
    batch: Option<BatchNotifier>,
) -> (Vec<Event>, impl Stream<Item = EventArray>) {
    let events = (0..count).map(generator).collect::<Vec<_>>();
    let stream = map_batch_stream(
        stream::iter(events.clone()).map(|event| event.into_log()),
        batch,
    );
    (events, stream)
}

pub fn random_metrics_with_stream(
    count: usize,
    batch: Option<BatchNotifier>,
    tags: Option<MetricTags>,
) -> (Vec<Event>, impl Stream<Item = EventArray>) {
    random_metrics_with_stream_timestamp(
        count,
        batch,
        tags,
        Utc::now().trunc_subsecs(3),
        std::time::Duration::from_secs(2),
    )
}

/// Generates metric events with the provided tags and timestamp.
///
/// # Parameters
/// - `count`: the number of metrics to generate
/// - `batch`: the batch notifier to use with the stream
/// - `tags`: the tags to apply to each metric event
/// - `timestamp`: the timestamp to use for each metric event
/// - `timestamp_offset`: the offset from the `timestamp` to use for each additional metric
///
/// # Returns
/// A tuple of the generated metric events and the stream of the generated events
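///
/// # Example
///
/// An illustrative sketch (not exercised as a doctest):
/// ```ignore
/// let (events, stream) = random_metrics_with_stream_timestamp(
///     10,                                 // count
///     None,                               // no batch notifier
///     None,                               // no extra tags
///     Utc::now(),                         // base timestamp
///     std::time::Duration::from_secs(1),  // each metric is offset 1s from the previous
/// );
/// ```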
pub fn random_metrics_with_stream_timestamp(
    count: usize,
    batch: Option<BatchNotifier>,
    tags: Option<MetricTags>,
    timestamp: DateTime<Utc>,
    timestamp_offset: std::time::Duration,
) -> (Vec<Event>, impl Stream<Item = EventArray>) {
    let events: Vec<_> = (0..count)
        .map(|index| {
            let ts = timestamp + (timestamp_offset * index as u32);
            Event::Metric(
                Metric::new(
                    format!("counter_{}", rng().random::<u32>()),
                    MetricKind::Incremental,
                    MetricValue::Counter {
                        value: index as f64,
                    },
                )
                .with_timestamp(Some(ts))
                .with_tags(tags.clone()),
            )
            // This ensures we get Origin Metadata; the service is undefined, but that's OK.
            .with_source_type("a_source_like_none_other")
        })
        .collect();

    let stream = map_event_batch_stream(stream::iter(events.clone()), batch);
    (events, stream)
}

pub fn random_events_with_stream(
    len: usize,
    count: usize,
    batch: Option<BatchNotifier>,
) -> (Vec<Event>, impl Stream<Item = EventArray>) {
    let events = (0..count)
        .map(|_| Event::from(LogEvent::from_str_legacy(random_string(len))))
        .collect::<Vec<_>>();
    let stream = map_batch_stream(
        stream::iter(events.clone()).map(|event| event.into_log()),
        batch,
    );
    (events, stream)
}

pub fn random_updated_events_with_stream<F>(
    len: usize,
    count: usize,
    batch: Option<BatchNotifier>,
    update_fn: F,
) -> (Vec<Event>, impl Stream<Item = EventArray>)
where
    F: Fn((usize, LogEvent)) -> LogEvent,
{
    let events = (0..count)
        .map(|_| LogEvent::from_str_legacy(random_string(len)))
        .enumerate()
        .map(update_fn)
        .map(Event::Log)
        .collect::<Vec<_>>();
    let stream = map_batch_stream(
        stream::iter(events.clone()).map(|event| event.into_log()),
        batch,
    );
    (events, stream)
}

pub fn create_events_batch_with_fn<F: Fn() -> Event>(
    create_event_fn: F,
    num_events: usize,
) -> (Vec<Event>, BatchStatusReceiver) {
    let mut events = (0..num_events)
        .map(|_| create_event_fn())
        .collect::<Vec<_>>();
    let receiver = BatchNotifier::apply_to(&mut events);
    (events, receiver)
}

pub fn random_string(len: usize) -> String {
    rng()
        .sample_iter(&Alphanumeric)
        .take(len)
        .map(char::from)
        .collect::<String>()
}

pub fn random_lines(len: usize) -> impl Iterator<Item = String> {
    iter::repeat_with(move || random_string(len))
}

pub fn random_map(max_size: usize, field_len: usize) -> HashMap<String, String> {
    let size = rng().random_range(0..max_size);

    (0..size)
        .map(move |_| (random_string(field_len), random_string(field_len)))
        .collect()
}

pub fn random_maps(
    max_size: usize,
    field_len: usize,
) -> impl Iterator<Item = HashMap<String, String>> {
    iter::repeat_with(move || random_map(max_size, field_len))
}

pub async fn collect_n<S>(rx: S, n: usize) -> Vec<S::Item>
where
    S: Stream,
{
    rx.take(n).collect().await
}

pub async fn collect_n_stream<T, S: Stream<Item = T> + Unpin>(stream: &mut S, n: usize) -> Vec<T> {
    let mut events = Vec::with_capacity(n);

    while events.len() < n {
        let e = stream.next().await.unwrap();
        events.push(e);
    }
    events
}

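/// Drains and returns the items already buffered in `rx`, without waiting for more
/// (polling stops at the first `Pending`).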
pub async fn collect_ready<S>(mut rx: S) -> Vec<S::Item>
where
    S: Stream + Unpin,
{
    let waker = noop_waker_ref();
    let mut cx = Context::from_waker(waker);

    let mut vec = Vec::new();
    loop {
        match rx.poll_next_unpin(&mut cx) {
            Poll::Ready(Some(item)) => vec.push(item),
            Poll::Ready(None) | Poll::Pending => return vec,
        }
    }
}

pub async fn collect_limited<T: Send + 'static>(mut rx: LimitedReceiver<T>) -> Vec<T> {
    let mut items = Vec::new();
    while let Some(item) = rx.next().await {
        items.push(item);
    }
    items
}

pub async fn collect_n_limited<T: Send + 'static>(mut rx: LimitedReceiver<T>, n: usize) -> Vec<T> {
    let mut items = Vec::new();
    while items.len() < n {
        match rx.next().await {
            Some(item) => items.push(item),
            None => break,
        }
    }
    items
}

pub fn lines_from_file<P: AsRef<Path>>(path: P) -> Vec<String> {
    trace!(message = "Reading file.", path = %path.as_ref().display());
    let mut file = File::open(path).unwrap();
    let mut output = String::new();
    file.read_to_string(&mut output).unwrap();
    output.lines().map(|s| s.to_owned()).collect()
}

pub fn lines_from_gzip_file<P: AsRef<Path>>(path: P) -> Vec<String> {
    trace!(message = "Reading gzip file.", path = %path.as_ref().display());
    let mut file = File::open(path).unwrap();
    let mut gzip_bytes = Vec::new();
    file.read_to_end(&mut gzip_bytes).unwrap();
    let mut output = String::new();
    MultiGzDecoder::new(&gzip_bytes[..])
        .read_to_string(&mut output)
        .unwrap();
    output.lines().map(|s| s.to_owned()).collect()
}

#[cfg(test)]
pub fn lines_from_zstd_file<P: AsRef<Path>>(path: P) -> Vec<String> {
    trace!(message = "Reading zstd file.", path = %path.as_ref().display());
    let file = File::open(path).unwrap();
    let mut output = String::new();
    ZstdDecoder::new(file)
        .unwrap()
        .read_to_string(&mut output)
        .unwrap();
    output.lines().map(|s| s.to_owned()).collect()
}

pub fn runtime() -> runtime::Runtime {
    runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
}

// Wait for the future returned by `f` to resolve to `true`, or panic once `duration` elapses.
pub async fn wait_for_duration<F, Fut>(mut f: F, duration: Duration)
where
    F: FnMut() -> Fut,
    Fut: Future<Output = bool> + Send + 'static,
{
    let started = Instant::now();
    let mut delay = WAIT_FOR_MIN_MILLIS;
    while !f().await {
        sleep(Duration::from_millis(delay)).await;
        if started.elapsed() > duration {
            panic!("Timed out while waiting");
        }
        // exponential backoff up to a maximum delay
        delay = (delay * 2).min(WAIT_FOR_MAX_MILLIS);
    }
}

// Wait up to `WAIT_FOR_SECS` (5 seconds) for `f` to resolve to `true`
pub async fn wait_for<F, Fut>(f: F)
where
    F: FnMut() -> Fut,
    Fut: Future<Output = bool> + Send + 'static,
{
    wait_for_duration(f, Duration::from_secs(WAIT_FOR_SECS)).await
}
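
// Illustrative usage of `wait_for` (a sketch; `flag` is a hypothetical
// `Arc<AtomicBool>` owned by the test):
//
//     wait_for(move || {
//         let flag = Arc::clone(&flag);
//         async move { flag.load(Ordering::SeqCst) }
//     })
//     .await;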

// Wait (for 5 secs) for a TCP socket to be reachable
pub async fn wait_for_tcp<A>(addr: A)
where
    A: ToSocketAddrs + Clone + Send + 'static,
{
    wait_for(move || {
        let addr = addr.clone();
        async move { TcpStream::connect(addr).await.is_ok() }
    })
    .await
}

// Allows specifying a custom duration to wait for a TCP socket to be reachable
pub async fn wait_for_tcp_duration(addr: SocketAddr, duration: Duration) {
    wait_for_duration(
        || async move { TcpStream::connect(addr).await.is_ok() },
        duration,
    )
    .await
}

pub async fn wait_for_atomic_usize<T, F>(value: T, unblock: F)
where
    T: AsRef<AtomicUsize>,
    F: Fn(usize) -> bool,
{
    let value = value.as_ref();
    wait_for(|| ready(unblock(value.load(Ordering::SeqCst)))).await
}

// Retries `f` every `retry` duration until it returns `Ok(T)`; panics once `until` has elapsed
pub async fn retry_until<'a, F, Fut, T, E>(mut f: F, retry: Duration, until: Duration) -> T
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>> + Send + 'a,
{
    let started = Instant::now();
    while started.elapsed() < until {
        match f().await {
            Ok(res) => return res,
            Err(_) => tokio::time::sleep(retry).await,
        }
    }
    panic!("Timeout")
}

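/// Spawns a background task that collects (and counts) the items received during a test.
///
/// Typical flow (an illustrative sketch, not exercised as a doctest):
/// ```ignore
/// let mut receiver = CountReceiver::receive_lines(addr);
/// receiver.connected().await;  // wait until the first connection is accepted
/// // ... run the component under test so that it sends data ...
/// let lines = receiver.await;  // stops collecting and yields the lines received so far
/// ```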
pub struct CountReceiver<T> {
    count: Arc<AtomicUsize>,
    trigger: Option<oneshot::Sender<()>>,
    connected: Option<oneshot::Receiver<()>>,
    handle: JoinHandle<Vec<T>>,
}

impl<T: Send + 'static> CountReceiver<T> {
    pub fn count(&self) -> usize {
        self.count.load(Ordering::Relaxed)
    }

    /// Succeeds once the first connection has been made.
    pub async fn connected(&mut self) {
        if let Some(tripwire) = self.connected.take() {
            tripwire.await.unwrap();
        }
    }

    fn new<F, Fut>(make_fut: F) -> CountReceiver<T>
    where
        F: FnOnce(Arc<AtomicUsize>, oneshot::Receiver<()>, oneshot::Sender<()>) -> Fut,
        Fut: Future<Output = Vec<T>> + Send + 'static,
    {
        let count = Arc::new(AtomicUsize::new(0));
        let (trigger, tripwire) = oneshot::channel();
        let (trigger_connected, connected) = oneshot::channel();

        CountReceiver {
            count: Arc::clone(&count),
            trigger: Some(trigger),
            connected: Some(connected),
            handle: tokio::spawn(make_fut(count, tripwire, trigger_connected)),
        }
    }

    pub fn receive_items_stream<S, F, Fut>(make_stream: F) -> CountReceiver<T>
    where
        S: Stream<Item = T> + Send + 'static,
        F: FnOnce(oneshot::Receiver<()>, oneshot::Sender<()>) -> Fut + Send + 'static,
        Fut: Future<Output = S> + Send + 'static,
    {
        CountReceiver::new(|count, tripwire, connected| async move {
            let stream = make_stream(tripwire, connected).await;
            stream
                .inspect(move |_| {
                    count.fetch_add(1, Ordering::Relaxed);
                })
                .collect::<Vec<T>>()
                .await
        })
    }
}

impl<T> Future for CountReceiver<T> {
    type Output = Vec<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        if let Some(trigger) = this.trigger.take() {
            _ = trigger.send(());
        }

        let result = ready!(this.handle.poll_unpin(cx));
        Poll::Ready(result.unwrap())
    }
}

impl CountReceiver<String> {
    pub fn receive_lines(addr: SocketAddr) -> CountReceiver<String> {
        CountReceiver::new(|count, tripwire, connected| async move {
            let listener = TcpListener::bind(addr).await.unwrap();
            CountReceiver::receive_lines_stream(
                TcpListenerStream::new(listener),
                count,
                tripwire,
                Some(connected),
            )
            .await
        })
    }

    #[cfg(unix)]
    pub fn receive_lines_unix<P>(path: P) -> CountReceiver<String>
    where
        P: AsRef<Path> + Send + 'static,
    {
        CountReceiver::new(|count, tripwire, connected| async move {
            let listener = tokio::net::UnixListener::bind(path).unwrap();
            CountReceiver::receive_lines_stream(
                UnixListenerStream::new(listener),
                count,
                tripwire,
                Some(connected),
            )
            .await
        })
    }

    async fn receive_lines_stream<S, T>(
        stream: S,
        count: Arc<AtomicUsize>,
        tripwire: oneshot::Receiver<()>,
        mut connected: Option<oneshot::Sender<()>>,
    ) -> Vec<String>
    where
        S: Stream<Item = IoResult<T>>,
        T: AsyncWrite + AsyncRead,
    {
        stream
            .take_until(tripwire)
            .map_ok(|socket| FramedRead::new(socket, LinesCodec::new()))
            .map(|x| {
                connected.take().map(|trigger| trigger.send(()));
                x.unwrap()
            })
            .flatten()
            .map(|x| x.unwrap())
            .inspect(move |_| {
                count.fetch_add(1, Ordering::Relaxed);
            })
            .collect::<Vec<String>>()
            .await
    }
}

impl CountReceiver<Event> {
    pub fn receive_events<S>(stream: S) -> CountReceiver<Event>
    where
        S: Stream<Item = Event> + Send + 'static,
    {
        CountReceiver::new(|count, tripwire, connected| async move {
            connected.send(()).unwrap();
            stream
                .take_until(tripwire)
                .inspect(move |_| {
                    count.fetch_add(1, Ordering::Relaxed);
                })
                .collect::<Vec<Event>>()
                .await
        })
    }
}

pub async fn start_topology(
    mut config: Config,
    require_healthy: impl Into<Option<bool>>,
) -> (RunningTopology, ShutdownErrorReceiver) {
    config.healthchecks.set_require_healthy(require_healthy);
    RunningTopology::start_init_validated(config, Default::default())
        .await
        .unwrap()
}

/// Collect the first `n` events from a stream while a future is spawned
/// in the background. This is used for tests where the collection has to
/// happen concurrently with the sending process (i.e. the stream is
/// handling finalization, which is required for the future to receive
/// an acknowledgement).
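///
/// Illustrative usage (a sketch; `run_sink` and `rx` are hypothetical placeholders for
/// the component under test and its output stream):
/// ```ignore
/// let output = spawn_collect_n(async move { run_sink(input).await }, rx, 100).await;
/// ```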
pub async fn spawn_collect_n<F, S>(future: F, stream: S, n: usize) -> Vec<Event>
where
    F: Future<Output = ()> + Send + 'static,
    S: Stream<Item = Event>,
{
    // TODO: Switch to using `select!` so that we can drive `future` to completion while also driving `collect_n`,
    // such that if `future` panics, we break out and don't continue driving `collect_n`. In most cases, `future`
    // completing successfully is what actually drives events into `stream`, so continuing to wait for all N events when
    // the catalyst has failed is almost never the desired behavior.
    let sender = tokio::spawn(future);
    let events = collect_n(stream, n).await;
    sender.await.expect("Failed to send data");
    events
}

/// Collect all the ready events from a stream after spawning a future
/// in the background and letting it run for a given interval. This is
/// used for tests where the collection has to happen concurrently with
/// the sending process (i.e. the stream is handling finalization, which
/// is required for the future to receive an acknowledgement).
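///
/// Illustrative usage (a sketch; `run_sink` and `rx` are hypothetical placeholders):
/// ```ignore
/// // Let the spawned future run for one second, then drain whatever is ready.
/// let output = spawn_collect_ready(async move { run_sink(input).await }, rx, 1).await;
/// ```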
pub async fn spawn_collect_ready<F, S>(future: F, stream: S, sleep: u64) -> Vec<Event>
where
    F: Future<Output = ()> + Send + 'static,
    S: Stream<Item = Event> + Unpin,
{
    let sender = tokio::spawn(future);
    tokio::time::sleep(Duration::from_secs(sleep)).await;
    let events = collect_ready(stream).await;
    sender.await.expect("Failed to send data");
    events
}

#[cfg(test)]
mod tests {
    use std::{
        sync::{Arc, RwLock},
        time::Duration,
    };

    use super::retry_until;

    // Helper that errors on the first three calls and succeeds on the fourth.
    async fn retry_until_helper(count: Arc<RwLock<i32>>) -> Result<(), ()> {
        if *count.read().unwrap() < 3 {
            let mut c = count.write().unwrap();
            *c += 1;
            return Err(());
        }
        Ok(())
    }

    #[tokio::test]
    async fn retry_until_before_timeout() {
        let count = Arc::new(RwLock::new(0));
        let func = || {
            let count = Arc::clone(&count);
            retry_until_helper(count)
        };

        retry_until(func, Duration::from_millis(10), Duration::from_secs(1)).await;
    }
}