vector/test_util/
mod.rs

1#![allow(missing_docs)]
2use std::{
3    collections::HashMap,
4    convert::Infallible,
5    fs::File,
6    future::{Future, ready},
7    io::Read,
8    iter,
9    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
10    path::{Path, PathBuf},
11    pin::Pin,
12    sync::{
13        Arc,
14        atomic::{AtomicUsize, Ordering},
15    },
16    task::{Context, Poll, ready},
17};
18
19use chrono::{DateTime, SubsecRound, Utc};
20use flate2::read::MultiGzDecoder;
21use futures::{FutureExt, SinkExt, Stream, StreamExt, TryStreamExt, stream, task::noop_waker_ref};
22use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
23use portpicker::pick_unused_port;
24use rand::{Rng, rng};
25use rand_distr::Alphanumeric;
26use tokio::{
27    io::{AsyncRead, AsyncWrite, AsyncWriteExt, Result as IoResult},
28    net::{TcpListener, TcpStream, ToSocketAddrs},
29    runtime,
30    sync::oneshot,
31    task::JoinHandle,
32    time::{Duration, Instant, sleep},
33};
34use tokio_stream::wrappers::TcpListenerStream;
35#[cfg(unix)]
36use tokio_stream::wrappers::UnixListenerStream;
37use tokio_util::codec::{Encoder, FramedRead, FramedWrite, LinesCodec};
38use vector_lib::{
39    buffers::topology::channel::LimitedReceiver,
40    event::{
41        BatchNotifier, BatchStatusReceiver, Event, EventArray, LogEvent, Metric, MetricKind,
42        MetricTags, MetricValue,
43    },
44};
45#[cfg(test)]
46use zstd::Decoder as ZstdDecoder;
47
48use crate::{
49    config::{Config, GenerateConfig},
50    topology::{RunningTopology, ShutdownErrorReceiver},
51    trace,
52};
53
/// Default overall timeout, in seconds, applied by `wait_for`.
const WAIT_FOR_SECS: u64 = 5;
/// Initial (and smallest) pause, in milliseconds, between retries in `wait_for_duration`.
const WAIT_FOR_MIN_MILLIS: u64 = 5;
/// Upper bound, in milliseconds, on the pause between retries in `wait_for_duration`.
const WAIT_FOR_MAX_MILLIS: u64 = 500;
57
58#[cfg(any(test, feature = "test-utils"))]
59pub mod components;
60
61#[cfg(test)]
62pub mod http;
63
64#[cfg(test)]
65pub mod metrics;
66
67#[cfg(test)]
68pub mod mock;
69
70pub mod compression;
71pub mod stats;
72
73#[cfg(test)]
74pub mod integration;
75
/// Asserts that `$e` downcasts to `$t` and that the downcast value matches pattern `$v`.
///
/// `$e` may be any expression with a `downcast_ref::<T>()` method (for example a
/// `dyn Error` or `dyn Any` value). On failure the macro panics and prints the `Option`
/// returned by `downcast_ref`: `None` means the type did not match, and `Some(..)` means
/// the type matched but the pattern did not.
#[macro_export]
macro_rules! assert_downcast_matches {
    ($e:expr_2021, $t:ty, $v:pat) => {{
        // A `match` keeps any temporaries created by `$e` alive for the whole comparison.
        match $e.downcast_ref::<$t>() {
            Some($v) => (),
            got => panic!("Assertion failed: got wrong error variant {:?}", got),
        }
    }};
}
85
/// Builds an `Event::Log` starting from `LogEvent::default()` and inserting each
/// `key => value` pair in the order given. A trailing comma is accepted.
///
/// The macro expands to a block expression that evaluates to the constructed event.
#[macro_export]
macro_rules! log_event {
    ($($key:expr_2021 => $value:expr_2021),*  $(,)?) => {
        // With zero pairs, `log` below is never used; this silences that warning.
        #[allow(unused_variables)]
        {
            let mut event = $crate::event::Event::Log($crate::event::LogEvent::default());
            let log = event.as_mut_log();
            $(
                log.insert($key, $value);
            )*
            event
        }
    };
}
100
101pub fn test_generate_config<T>()
102where
103    for<'de> T: GenerateConfig + serde::Deserialize<'de>,
104{
105    let cfg = toml::to_string(&T::generate_config()).unwrap();
106
107    toml::from_str::<T>(&cfg)
108        .unwrap_or_else(|e| panic!("Invalid config generated from string:\n\n{e}\n'{cfg}'"));
109}
110
111pub fn open_fixture(path: impl AsRef<Path>) -> crate::Result<serde_json::Value> {
112    let test_file = match File::open(path) {
113        Ok(file) => file,
114        Err(e) => return Err(e.into()),
115    };
116    let value: serde_json::Value = serde_json::from_reader(test_file)?;
117    Ok(value)
118}
119
120pub fn next_addr_for_ip(ip: IpAddr) -> SocketAddr {
121    let port = pick_unused_port(ip);
122    SocketAddr::new(ip, port)
123}
124
125pub fn next_addr() -> SocketAddr {
126    next_addr_for_ip(IpAddr::V4(Ipv4Addr::LOCALHOST))
127}
128
129pub fn next_addr_any() -> SocketAddr {
130    next_addr_for_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED))
131}
132
133pub fn next_addr_v6() -> SocketAddr {
134    next_addr_for_ip(IpAddr::V6(Ipv6Addr::LOCALHOST))
135}
136
137pub fn trace_init() {
138    #[cfg(unix)]
139    let color = {
140        use std::io::IsTerminal;
141        std::io::stdout().is_terminal()
142            || std::env::var("NEXTEST")
143                .ok()
144                .and(Some(true))
145                .unwrap_or(false)
146    };
147    // Windows: ANSI colors are not supported by cmd.exe
148    // Color is false for everything except unix.
149    #[cfg(not(unix))]
150    let color = false;
151
152    let levels = std::env::var("VECTOR_LOG").unwrap_or_else(|_| "error".to_string());
153
154    trace::init(color, false, &levels, 10);
155
156    // Initialize metrics as well
157    vector_lib::metrics::init_test();
158}
159
160pub async fn send_lines(
161    addr: SocketAddr,
162    lines: impl IntoIterator<Item = String>,
163) -> Result<SocketAddr, Infallible> {
164    send_encodable(addr, LinesCodec::new(), lines).await
165}
166
167pub async fn send_encodable<I, E: From<std::io::Error> + std::fmt::Debug>(
168    addr: SocketAddr,
169    encoder: impl Encoder<I, Error = E>,
170    lines: impl IntoIterator<Item = I>,
171) -> Result<SocketAddr, Infallible> {
172    let stream = TcpStream::connect(&addr).await.unwrap();
173
174    let local_addr = stream.local_addr().unwrap();
175
176    let mut sink = FramedWrite::new(stream, encoder);
177
178    let mut lines = stream::iter(lines.into_iter()).map(Ok);
179    sink.send_all(&mut lines).await.unwrap();
180
181    let stream = sink.get_mut();
182    stream.shutdown().await.unwrap();
183
184    Ok(local_addr)
185}
186
187pub async fn send_lines_tls(
188    addr: SocketAddr,
189    host: String,
190    lines: impl Iterator<Item = String>,
191    ca: impl Into<Option<&Path>>,
192    client_cert: impl Into<Option<&Path>>,
193    client_key: impl Into<Option<&Path>>,
194) -> Result<SocketAddr, Infallible> {
195    let stream = TcpStream::connect(&addr).await.unwrap();
196
197    let local_addr = stream.local_addr().unwrap();
198
199    let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
200    if let Some(ca) = ca.into() {
201        connector.set_ca_file(ca).unwrap();
202    } else {
203        connector.set_verify(SslVerifyMode::NONE);
204    }
205
206    if let Some(cert_file) = client_cert.into() {
207        connector.set_certificate_chain_file(cert_file).unwrap();
208    }
209
210    if let Some(key_file) = client_key.into() {
211        connector
212            .set_private_key_file(key_file, SslFiletype::PEM)
213            .unwrap();
214    }
215
216    let ssl = connector
217        .build()
218        .configure()
219        .unwrap()
220        .into_ssl(&host)
221        .unwrap();
222
223    let mut stream = tokio_openssl::SslStream::new(ssl, stream).unwrap();
224    Pin::new(&mut stream).connect().await.unwrap();
225    let mut sink = FramedWrite::new(stream, LinesCodec::new());
226
227    let mut lines = stream::iter(lines).map(Ok);
228    sink.send_all(&mut lines).await.unwrap();
229
230    let stream = sink.get_mut().get_mut();
231    stream.shutdown().await.unwrap();
232
233    Ok(local_addr)
234}
235
236pub fn temp_file() -> PathBuf {
237    let path = std::env::temp_dir();
238    let file_name = random_string(16);
239    path.join(file_name + ".log")
240}
241
242pub fn temp_dir() -> PathBuf {
243    let path = std::env::temp_dir();
244    let dir_name = random_string(16);
245    path.join(dir_name)
246}
247
248pub fn random_table_name() -> String {
249    format!("test_{}", random_string(10).to_lowercase())
250}
251
252pub fn map_event_batch_stream(
253    stream: impl Stream<Item = Event>,
254    batch: Option<BatchNotifier>,
255) -> impl Stream<Item = EventArray> {
256    stream.map(move |event| event.with_batch_notifier_option(&batch).into())
257}
258
259// TODO refactor to have a single implementation for `Event`, `LogEvent` and `Metric`.
260fn map_batch_stream(
261    stream: impl Stream<Item = LogEvent>,
262    batch: Option<BatchNotifier>,
263) -> impl Stream<Item = EventArray> {
264    stream.map(move |log| vec![log.with_batch_notifier_option(&batch)].into())
265}
266
267pub fn generate_lines_with_stream<Gen: FnMut(usize) -> String>(
268    generator: Gen,
269    count: usize,
270    batch: Option<BatchNotifier>,
271) -> (Vec<String>, impl Stream<Item = EventArray>) {
272    let lines = (0..count).map(generator).collect::<Vec<_>>();
273    let stream = map_batch_stream(
274        stream::iter(lines.clone()).map(LogEvent::from_str_legacy),
275        batch,
276    );
277    (lines, stream)
278}
279
280pub fn random_lines_with_stream(
281    len: usize,
282    count: usize,
283    batch: Option<BatchNotifier>,
284) -> (Vec<String>, impl Stream<Item = EventArray>) {
285    let generator = move |_| random_string(len);
286    generate_lines_with_stream(generator, count, batch)
287}
288
289pub fn generate_events_with_stream<Gen: FnMut(usize) -> Event>(
290    generator: Gen,
291    count: usize,
292    batch: Option<BatchNotifier>,
293) -> (Vec<Event>, impl Stream<Item = EventArray>) {
294    let events = (0..count).map(generator).collect::<Vec<_>>();
295    let stream = map_batch_stream(
296        stream::iter(events.clone()).map(|event| event.into_log()),
297        batch,
298    );
299    (events, stream)
300}
301
302pub fn random_metrics_with_stream(
303    count: usize,
304    batch: Option<BatchNotifier>,
305    tags: Option<MetricTags>,
306) -> (Vec<Event>, impl Stream<Item = EventArray>) {
307    random_metrics_with_stream_timestamp(
308        count,
309        batch,
310        tags,
311        Utc::now().trunc_subsecs(3),
312        std::time::Duration::from_secs(2),
313    )
314}
315
316/// Generates event metrics with the provided tags and timestamp.
317///
318/// # Parameters
319/// - `count`: the number of metrics to generate
320/// - `batch`: the batch notifier to use with the stream
321/// - `tags`: the tags to apply to each metric event
322/// - `timestamp`: the timestamp to use for each metric event
323/// - `timestamp_offset`: the offset from the `timestamp` to use for each additional metric
324///
325/// # Returns
326/// A tuple of the generated metric events and the stream of the generated events
327pub fn random_metrics_with_stream_timestamp(
328    count: usize,
329    batch: Option<BatchNotifier>,
330    tags: Option<MetricTags>,
331    timestamp: DateTime<Utc>,
332    timestamp_offset: std::time::Duration,
333) -> (Vec<Event>, impl Stream<Item = EventArray>) {
334    let events: Vec<_> = (0..count)
335        .map(|index| {
336            let ts = timestamp + (timestamp_offset * index as u32);
337            Event::Metric(
338                Metric::new(
339                    format!("counter_{}", rng().random::<u32>()),
340                    MetricKind::Incremental,
341                    MetricValue::Counter {
342                        value: index as f64,
343                    },
344                )
345                .with_timestamp(Some(ts))
346                .with_tags(tags.clone()),
347            )
348            // this ensures we get Origin Metadata, with an undefined service but that's ok.
349            .with_source_type("a_source_like_none_other")
350        })
351        .collect();
352
353    let stream = map_event_batch_stream(stream::iter(events.clone()), batch);
354    (events, stream)
355}
356
357pub fn random_events_with_stream(
358    len: usize,
359    count: usize,
360    batch: Option<BatchNotifier>,
361) -> (Vec<Event>, impl Stream<Item = EventArray>) {
362    let events = (0..count)
363        .map(|_| Event::from(LogEvent::from_str_legacy(random_string(len))))
364        .collect::<Vec<_>>();
365    let stream = map_batch_stream(
366        stream::iter(events.clone()).map(|event| event.into_log()),
367        batch,
368    );
369    (events, stream)
370}
371
372pub fn random_updated_events_with_stream<F>(
373    len: usize,
374    count: usize,
375    batch: Option<BatchNotifier>,
376    update_fn: F,
377) -> (Vec<Event>, impl Stream<Item = EventArray>)
378where
379    F: Fn((usize, LogEvent)) -> LogEvent,
380{
381    let events = (0..count)
382        .map(|_| LogEvent::from_str_legacy(random_string(len)))
383        .enumerate()
384        .map(update_fn)
385        .map(Event::Log)
386        .collect::<Vec<_>>();
387    let stream = map_batch_stream(
388        stream::iter(events.clone()).map(|event| event.into_log()),
389        batch,
390    );
391    (events, stream)
392}
393
394pub fn create_events_batch_with_fn<F: Fn() -> Event>(
395    create_event_fn: F,
396    num_events: usize,
397) -> (Vec<Event>, BatchStatusReceiver) {
398    let mut events = (0..num_events)
399        .map(|_| create_event_fn())
400        .collect::<Vec<_>>();
401    let receiver = BatchNotifier::apply_to(&mut events);
402    (events, receiver)
403}
404
405pub fn random_string(len: usize) -> String {
406    rng()
407        .sample_iter(&Alphanumeric)
408        .take(len)
409        .map(char::from)
410        .collect::<String>()
411}
412
413pub fn random_lines(len: usize) -> impl Iterator<Item = String> {
414    iter::repeat_with(move || random_string(len))
415}
416
417pub fn random_map(max_size: usize, field_len: usize) -> HashMap<String, String> {
418    let size = rng().random_range(0..max_size);
419
420    (0..size)
421        .map(move |_| (random_string(field_len), random_string(field_len)))
422        .collect()
423}
424
425pub fn random_maps(
426    max_size: usize,
427    field_len: usize,
428) -> impl Iterator<Item = HashMap<String, String>> {
429    iter::repeat_with(move || random_map(max_size, field_len))
430}
431
432pub async fn collect_n<S>(rx: S, n: usize) -> Vec<S::Item>
433where
434    S: Stream,
435{
436    rx.take(n).collect().await
437}
438
439pub async fn collect_n_stream<T, S: Stream<Item = T> + Unpin>(stream: &mut S, n: usize) -> Vec<T> {
440    let mut events = Vec::with_capacity(n);
441
442    while events.len() < n {
443        let e = stream.next().await.unwrap();
444        events.push(e);
445    }
446    events
447}
448
449pub async fn collect_ready<S>(mut rx: S) -> Vec<S::Item>
450where
451    S: Stream + Unpin,
452{
453    let waker = noop_waker_ref();
454    let mut cx = Context::from_waker(waker);
455
456    let mut vec = Vec::new();
457    loop {
458        match rx.poll_next_unpin(&mut cx) {
459            Poll::Ready(Some(item)) => vec.push(item),
460            Poll::Ready(None) | Poll::Pending => return vec,
461        }
462    }
463}
464
465pub async fn collect_limited<T: Send + 'static>(mut rx: LimitedReceiver<T>) -> Vec<T> {
466    let mut items = Vec::new();
467    while let Some(item) = rx.next().await {
468        items.push(item);
469    }
470    items
471}
472
473pub async fn collect_n_limited<T: Send + 'static>(mut rx: LimitedReceiver<T>, n: usize) -> Vec<T> {
474    let mut items = Vec::new();
475    while items.len() < n {
476        match rx.next().await {
477            Some(item) => items.push(item),
478            None => break,
479        }
480    }
481    items
482}
483
484pub fn lines_from_file<P: AsRef<Path>>(path: P) -> Vec<String> {
485    trace!(message = "Reading file.", path = %path.as_ref().display());
486    let mut file = File::open(path).unwrap();
487    let mut output = String::new();
488    file.read_to_string(&mut output).unwrap();
489    output.lines().map(|s| s.to_owned()).collect()
490}
491
492pub fn lines_from_gzip_file<P: AsRef<Path>>(path: P) -> Vec<String> {
493    trace!(message = "Reading gzip file.", path = %path.as_ref().display());
494    let mut file = File::open(path).unwrap();
495    let mut gzip_bytes = Vec::new();
496    file.read_to_end(&mut gzip_bytes).unwrap();
497    let mut output = String::new();
498    MultiGzDecoder::new(&gzip_bytes[..])
499        .read_to_string(&mut output)
500        .unwrap();
501    output.lines().map(|s| s.to_owned()).collect()
502}
503
/// Decompresses the zstd file at `path` and returns the decoded text split into lines.
///
/// # Panics
/// Panics if the file cannot be opened, is not valid zstd, or does not decode to UTF-8.
#[cfg(test)]
pub fn lines_from_zstd_file<P: AsRef<Path>>(path: P) -> Vec<String> {
    trace!(message = "Reading zstd file.", path = %path.as_ref().display());
    let compressed = File::open(path).unwrap();
    let mut text = String::new();
    let mut decoder = ZstdDecoder::new(compressed).unwrap();
    decoder.read_to_string(&mut text).unwrap();
    text.lines().map(String::from).collect()
}
515
516pub fn runtime() -> runtime::Runtime {
517    runtime::Builder::new_multi_thread()
518        .enable_all()
519        .build()
520        .unwrap()
521}
522
523// Wait for a Future to resolve, or the duration to elapse (will panic)
524pub async fn wait_for_duration<F, Fut>(mut f: F, duration: Duration)
525where
526    F: FnMut() -> Fut,
527    Fut: Future<Output = bool> + Send + 'static,
528{
529    let started = Instant::now();
530    let mut delay = WAIT_FOR_MIN_MILLIS;
531    while !f().await {
532        sleep(Duration::from_millis(delay)).await;
533        if started.elapsed() > duration {
534            panic!("Timed out while waiting");
535        }
536        // quadratic backoff up to a maximum delay
537        delay = (delay * 2).min(WAIT_FOR_MAX_MILLIS);
538    }
539}
540
541// Wait for 5 seconds
542pub async fn wait_for<F, Fut>(f: F)
543where
544    F: FnMut() -> Fut,
545    Fut: Future<Output = bool> + Send + 'static,
546{
547    wait_for_duration(f, Duration::from_secs(WAIT_FOR_SECS)).await
548}
549
550// Wait (for 5 secs) for a TCP socket to be reachable
551pub async fn wait_for_tcp<A>(addr: A)
552where
553    A: ToSocketAddrs + Clone + Send + 'static,
554{
555    wait_for(move || {
556        let addr = addr.clone();
557        async move { TcpStream::connect(addr).await.is_ok() }
558    })
559    .await
560}
561
562// Allows specifying a custom duration to wait for a TCP socket to be reachable
563pub async fn wait_for_tcp_duration(addr: SocketAddr, duration: Duration) {
564    wait_for_duration(
565        || async move { TcpStream::connect(addr).await.is_ok() },
566        duration,
567    )
568    .await
569}
570
571pub async fn wait_for_atomic_usize<T, F>(value: T, unblock: F)
572where
573    T: AsRef<AtomicUsize>,
574    F: Fn(usize) -> bool,
575{
576    let value = value.as_ref();
577    wait_for(|| ready(unblock(value.load(Ordering::SeqCst)))).await
578}
579
580// Retries a func every `retry` duration until given an Ok(T); panics after `until` elapses
581pub async fn retry_until<'a, F, Fut, T, E>(mut f: F, retry: Duration, until: Duration) -> T
582where
583    F: FnMut() -> Fut,
584    Fut: Future<Output = Result<T, E>> + Send + 'a,
585{
586    let started = Instant::now();
587    while started.elapsed() < until {
588        match f().await {
589            Ok(res) => return res,
590            Err(_) => tokio::time::sleep(retry).await,
591        }
592    }
593    panic!("Timeout")
594}
595
/// Test helper that collects items produced by a spawned background task while tracking
/// how many it has seen.
///
/// Awaiting a `CountReceiver` fires `trigger` on its first poll, then resolves to the
/// `Vec<T>` returned by the task.
pub struct CountReceiver<T> {
    /// Number of items observed so far; the same `Arc` is handed to the background task.
    count: Arc<AtomicUsize>,
    /// Shutdown signal for the background task; taken (set to `None`) once fired.
    trigger: Option<oneshot::Sender<()>>,
    /// Resolves when the background task reports a connection; taken by `connected()`.
    connected: Option<oneshot::Receiver<()>>,
    /// Handle to the spawned task that yields the collected items.
    handle: JoinHandle<Vec<T>>,
}
602
603impl<T: Send + 'static> CountReceiver<T> {
604    pub fn count(&self) -> usize {
605        self.count.load(Ordering::Relaxed)
606    }
607
608    /// Succeeds once first connection has been made.
609    pub async fn connected(&mut self) {
610        if let Some(tripwire) = self.connected.take() {
611            tripwire.await.unwrap();
612        }
613    }
614
615    fn new<F, Fut>(make_fut: F) -> CountReceiver<T>
616    where
617        F: FnOnce(Arc<AtomicUsize>, oneshot::Receiver<()>, oneshot::Sender<()>) -> Fut,
618        Fut: Future<Output = Vec<T>> + Send + 'static,
619    {
620        let count = Arc::new(AtomicUsize::new(0));
621        let (trigger, tripwire) = oneshot::channel();
622        let (trigger_connected, connected) = oneshot::channel();
623
624        CountReceiver {
625            count: Arc::clone(&count),
626            trigger: Some(trigger),
627            connected: Some(connected),
628            handle: tokio::spawn(make_fut(count, tripwire, trigger_connected)),
629        }
630    }
631
632    pub fn receive_items_stream<S, F, Fut>(make_stream: F) -> CountReceiver<T>
633    where
634        S: Stream<Item = T> + Send + 'static,
635        F: FnOnce(oneshot::Receiver<()>, oneshot::Sender<()>) -> Fut + Send + 'static,
636        Fut: Future<Output = S> + Send + 'static,
637    {
638        CountReceiver::new(|count, tripwire, connected| async move {
639            let stream = make_stream(tripwire, connected).await;
640            stream
641                .inspect(move |_| {
642                    count.fetch_add(1, Ordering::Relaxed);
643                })
644                .collect::<Vec<T>>()
645                .await
646        })
647    }
648}
649
650impl<T> Future for CountReceiver<T> {
651    type Output = Vec<T>;
652
653    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
654        let this = self.get_mut();
655        if let Some(trigger) = this.trigger.take() {
656            _ = trigger.send(());
657        }
658
659        let result = ready!(this.handle.poll_unpin(cx));
660        Poll::Ready(result.unwrap())
661    }
662}
663
impl CountReceiver<String> {
    /// Binds a TCP listener on `addr` and collects newline-delimited lines from the
    /// connections it accepts.
    ///
    /// The bind happens inside the spawned task, so it may not have completed when this
    /// returns. NOTE(review): callers presumably wait (e.g. `wait_for_tcp`) before sending.
    pub fn receive_lines(addr: SocketAddr) -> CountReceiver<String> {
        CountReceiver::new(|count, tripwire, connected| async move {
            let listener = TcpListener::bind(addr).await.unwrap();
            CountReceiver::receive_lines_stream(
                TcpListenerStream::new(listener),
                count,
                tripwire,
                Some(connected),
            )
            .await
        })
    }

    /// Unix-socket counterpart of `receive_lines`: binds a `UnixListener` at `path` and
    /// collects newline-delimited lines from accepted connections.
    #[cfg(unix)]
    pub fn receive_lines_unix<P>(path: P) -> CountReceiver<String>
    where
        P: AsRef<Path> + Send + 'static,
    {
        CountReceiver::new(|count, tripwire, connected| async move {
            let listener = tokio::net::UnixListener::bind(path).unwrap();
            CountReceiver::receive_lines_stream(
                UnixListenerStream::new(listener),
                count,
                tripwire,
                Some(connected),
            )
            .await
        })
    }

    /// Reads lines from every connection yielded by `stream`, counting each line.
    ///
    /// - `tripwire` stops the *accepting* of new connections. It is applied before
    ///   `flatten`, so a connection that is already being read is drained until it closes.
    /// - `connected` is signalled when the first accept result arrives, even if that result
    ///   is an error.
    /// - `flatten` reads connections one after another: each connection's lines are
    ///   exhausted before the next accepted connection is read.
    ///
    /// Panics on any accept error or `LinesCodec` decode error.
    async fn receive_lines_stream<S, T>(
        stream: S,
        count: Arc<AtomicUsize>,
        tripwire: oneshot::Receiver<()>,
        mut connected: Option<oneshot::Sender<()>>,
    ) -> Vec<String>
    where
        S: Stream<Item = IoResult<T>>,
        T: AsyncWrite + AsyncRead,
    {
        stream
            .take_until(tripwire)
            .map_ok(|socket| FramedRead::new(socket, LinesCodec::new()))
            .map(|x| {
                // Only the first accept result sends the signal; the send result is ignored.
                connected.take().map(|trigger| trigger.send(()));
                x.unwrap()
            })
            .flatten()
            .map(|x| x.unwrap())
            .inspect(move |_| {
                count.fetch_add(1, Ordering::Relaxed);
            })
            .collect::<Vec<String>>()
            .await
    }
}
721
722impl CountReceiver<Event> {
723    pub fn receive_events<S>(stream: S) -> CountReceiver<Event>
724    where
725        S: Stream<Item = Event> + Send + 'static,
726    {
727        CountReceiver::new(|count, tripwire, connected| async move {
728            connected.send(()).unwrap();
729            stream
730                .take_until(tripwire)
731                .inspect(move |_| {
732                    count.fetch_add(1, Ordering::Relaxed);
733                })
734                .collect::<Vec<Event>>()
735                .await
736        })
737    }
738}
739
740pub async fn start_topology(
741    mut config: Config,
742    require_healthy: impl Into<Option<bool>>,
743) -> (RunningTopology, ShutdownErrorReceiver) {
744    config.healthchecks.set_require_healthy(require_healthy);
745    RunningTopology::start_init_validated(config, Default::default())
746        .await
747        .unwrap()
748}
749
750/// Collect the first `n` events from a stream while a future is spawned
751/// in the background. This is used for tests where the collect has to
752/// happen concurrent with the sending process (ie the stream is
753/// handling finalization, which is required for the future to receive
754/// an acknowledgement).
755pub async fn spawn_collect_n<F, S>(future: F, stream: S, n: usize) -> Vec<Event>
756where
757    F: Future<Output = ()> + Send + 'static,
758    S: Stream<Item = Event>,
759{
760    // TODO: Switch to using `select!` so that we can drive `future` to completion while also driving `collect_n`,
761    // such that if `future` panics, we break out and don't continue driving `collect_n`. In most cases, `future`
762    // completing successfully is what actually drives events into `stream`, so continuing to wait for all N events when
763    // the catalyst has failed is.... almost never the desired behavior.
764    let sender = tokio::spawn(future);
765    let events = collect_n(stream, n).await;
766    sender.await.expect("Failed to send data");
767    events
768}
769
770/// Collect all the ready events from a stream after spawning a future
771/// in the background and letting it run for a given interval. This is
772/// used for tests where the collect has to happen concurrent with the
773/// sending process (ie the stream is handling finalization, which is
774/// required for the future to receive an acknowledgement).
775pub async fn spawn_collect_ready<F, S>(future: F, stream: S, sleep: u64) -> Vec<Event>
776where
777    F: Future<Output = ()> + Send + 'static,
778    S: Stream<Item = Event> + Unpin,
779{
780    let sender = tokio::spawn(future);
781    tokio::time::sleep(Duration::from_secs(sleep)).await;
782    let events = collect_ready(stream).await;
783    sender.await.expect("Failed to send data");
784    events
785}
786
#[cfg(test)]
mod tests {
    use std::{
        sync::{Arc, RwLock},
        time::Duration,
    };

    use super::retry_until;

    /// Fails (incrementing `count`) until three failures have been recorded, then
    /// succeeds: calls 1–3 return `Err`, and call 4 onward returns `Ok`.
    async fn retry_until_helper(count: Arc<RwLock<i32>>) -> Result<(), ()> {
        let mut failures = count.write().unwrap();
        if *failures >= 3 {
            return Ok(());
        }
        *failures += 1;
        Err(())
    }

    #[tokio::test]
    async fn retry_until_before_timeout() {
        let count = Arc::new(RwLock::new(0));
        let attempt = || retry_until_helper(Arc::clone(&count));

        retry_until(attempt, Duration::from_millis(10), Duration::from_secs(1)).await;
    }
}