// vector/test_util/mod.rs

1#![allow(missing_docs)]
2use std::{
3    collections::HashMap,
4    convert::Infallible,
5    fs::File,
6    future::{ready, Future},
7    io::Read,
8    iter,
9    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
10    path::{Path, PathBuf},
11    pin::Pin,
12    sync::{
13        atomic::{AtomicUsize, Ordering},
14        Arc,
15    },
16    task::{ready, Context, Poll},
17};
18
19use chrono::{DateTime, SubsecRound, Utc};
20use flate2::read::MultiGzDecoder;
21use futures::{stream, task::noop_waker_ref, FutureExt, SinkExt, Stream, StreamExt, TryStreamExt};
22use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
23use portpicker::pick_unused_port;
24use rand::{rng, Rng};
25use rand_distr::Alphanumeric;
26use tokio::{
27    io::{AsyncRead, AsyncWrite, AsyncWriteExt, Result as IoResult},
28    net::{TcpListener, TcpStream, ToSocketAddrs},
29    runtime,
30    sync::oneshot,
31    task::JoinHandle,
32    time::{sleep, Duration, Instant},
33};
34use tokio_stream::wrappers::TcpListenerStream;
35#[cfg(unix)]
36use tokio_stream::wrappers::UnixListenerStream;
37use tokio_util::codec::{Encoder, FramedRead, FramedWrite, LinesCodec};
38use vector_lib::event::{
39    BatchNotifier, BatchStatusReceiver, Event, EventArray, LogEvent, MetricTags, MetricValue,
40};
41use vector_lib::{
42    buffers::topology::channel::LimitedReceiver,
43    event::{Metric, MetricKind},
44};
45#[cfg(test)]
46use zstd::Decoder as ZstdDecoder;
47
48use crate::{
49    config::{Config, GenerateConfig},
50    topology::{RunningTopology, ShutdownErrorReceiver},
51    trace,
52};
53
/// Default overall timeout, in seconds, used by [`wait_for`].
const WAIT_FOR_SECS: u64 = 5;
/// Initial pause, in milliseconds, between attempts in [`wait_for_duration`].
const WAIT_FOR_MIN_MILLIS: u64 = 5;
/// Upper bound, in milliseconds, on the pause between attempts in [`wait_for_duration`].
const WAIT_FOR_MAX_MILLIS: u64 = 500;
57
// Test-support submodules. `components` is available to this crate's tests and to any
// build with the `test-utils` feature enabled.
#[cfg(any(test, feature = "test-utils"))]
pub mod components;

// Compiled only for this crate's own tests.
#[cfg(test)]
pub mod http;

#[cfg(test)]
pub mod metrics;

#[cfg(test)]
pub mod mock;

// Always compiled (no `cfg` gate).
pub mod compression;
pub mod stats;

#[cfg(test)]
pub mod integration;
75
/// Asserts that `$e` downcasts to `$t` and that the result matches the pattern `$v`.
///
/// Panics, printing the downcast result (`None` when `$e` is not a `$t`), on a mismatch.
#[macro_export]
macro_rules! assert_downcast_matches {
    ($e:expr_2021, $t:ty, $v:pat) => {{
        // Match on the expression directly so any temporaries inside `$e` outlive the borrow.
        match $e.downcast_ref::<$t>() {
            Some($v) => {}
            unexpected => {
                panic!("Assertion failed: got wrong error variant {:?}", unexpected)
            }
        }
    }};
}
85
/// Builds an `Event::Log` from `key => value` pairs, inserting them in the order given.
#[macro_export]
macro_rules! log_event {
    ($($key:expr_2021 => $value:expr_2021),*  $(,)?) => {
        // With no pairs, `log` is never mutated.
        #[allow(unused_mut)]
        {
            let mut log = $crate::event::LogEvent::default();
            $(
                log.insert($key, $value);
            )*
            $crate::event::Event::Log(log)
        }
    };
}
100
101pub fn test_generate_config<T>()
102where
103    for<'de> T: GenerateConfig + serde::Deserialize<'de>,
104{
105    let cfg = toml::to_string(&T::generate_config()).unwrap();
106
107    toml::from_str::<T>(&cfg)
108        .unwrap_or_else(|e| panic!("Invalid config generated from string:\n\n{e}\n'{cfg}'"));
109}
110
111pub fn open_fixture(path: impl AsRef<Path>) -> crate::Result<serde_json::Value> {
112    let test_file = match File::open(path) {
113        Ok(file) => file,
114        Err(e) => return Err(e.into()),
115    };
116    let value: serde_json::Value = serde_json::from_reader(test_file)?;
117    Ok(value)
118}
119
120pub fn next_addr_for_ip(ip: IpAddr) -> SocketAddr {
121    let port = pick_unused_port(ip);
122    SocketAddr::new(ip, port)
123}
124
125pub fn next_addr() -> SocketAddr {
126    next_addr_for_ip(IpAddr::V4(Ipv4Addr::LOCALHOST))
127}
128
129pub fn next_addr_any() -> SocketAddr {
130    next_addr_for_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED))
131}
132
133pub fn next_addr_v6() -> SocketAddr {
134    next_addr_for_ip(IpAddr::V6(Ipv6Addr::LOCALHOST))
135}
136
137pub fn trace_init() {
138    #[cfg(unix)]
139    let color = {
140        use std::io::IsTerminal;
141        std::io::stdout().is_terminal()
142    };
143    // Windows: ANSI colors are not supported by cmd.exe
144    // Color is false for everything except unix.
145    #[cfg(not(unix))]
146    let color = false;
147
148    let levels = std::env::var("TEST_LOG").unwrap_or_else(|_| "error".to_string());
149
150    trace::init(color, false, &levels, 10);
151
152    // Initialize metrics as well
153    vector_lib::metrics::init_test();
154}
155
156pub async fn send_lines(
157    addr: SocketAddr,
158    lines: impl IntoIterator<Item = String>,
159) -> Result<SocketAddr, Infallible> {
160    send_encodable(addr, LinesCodec::new(), lines).await
161}
162
163pub async fn send_encodable<I, E: From<std::io::Error> + std::fmt::Debug>(
164    addr: SocketAddr,
165    encoder: impl Encoder<I, Error = E>,
166    lines: impl IntoIterator<Item = I>,
167) -> Result<SocketAddr, Infallible> {
168    let stream = TcpStream::connect(&addr).await.unwrap();
169
170    let local_addr = stream.local_addr().unwrap();
171
172    let mut sink = FramedWrite::new(stream, encoder);
173
174    let mut lines = stream::iter(lines.into_iter()).map(Ok);
175    sink.send_all(&mut lines).await.unwrap();
176
177    let stream = sink.get_mut();
178    stream.shutdown().await.unwrap();
179
180    Ok(local_addr)
181}
182
183pub async fn send_lines_tls(
184    addr: SocketAddr,
185    host: String,
186    lines: impl Iterator<Item = String>,
187    ca: impl Into<Option<&Path>>,
188    client_cert: impl Into<Option<&Path>>,
189    client_key: impl Into<Option<&Path>>,
190) -> Result<SocketAddr, Infallible> {
191    let stream = TcpStream::connect(&addr).await.unwrap();
192
193    let local_addr = stream.local_addr().unwrap();
194
195    let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
196    if let Some(ca) = ca.into() {
197        connector.set_ca_file(ca).unwrap();
198    } else {
199        connector.set_verify(SslVerifyMode::NONE);
200    }
201
202    if let Some(cert_file) = client_cert.into() {
203        connector.set_certificate_chain_file(cert_file).unwrap();
204    }
205
206    if let Some(key_file) = client_key.into() {
207        connector
208            .set_private_key_file(key_file, SslFiletype::PEM)
209            .unwrap();
210    }
211
212    let ssl = connector
213        .build()
214        .configure()
215        .unwrap()
216        .into_ssl(&host)
217        .unwrap();
218
219    let mut stream = tokio_openssl::SslStream::new(ssl, stream).unwrap();
220    Pin::new(&mut stream).connect().await.unwrap();
221    let mut sink = FramedWrite::new(stream, LinesCodec::new());
222
223    let mut lines = stream::iter(lines).map(Ok);
224    sink.send_all(&mut lines).await.unwrap();
225
226    let stream = sink.get_mut().get_mut();
227    stream.shutdown().await.unwrap();
228
229    Ok(local_addr)
230}
231
232pub fn temp_file() -> PathBuf {
233    let path = std::env::temp_dir();
234    let file_name = random_string(16);
235    path.join(file_name + ".log")
236}
237
238pub fn temp_dir() -> PathBuf {
239    let path = std::env::temp_dir();
240    let dir_name = random_string(16);
241    path.join(dir_name)
242}
243
244pub fn random_table_name() -> String {
245    format!("test_{}", random_string(10).to_lowercase())
246}
247
248pub fn map_event_batch_stream(
249    stream: impl Stream<Item = Event>,
250    batch: Option<BatchNotifier>,
251) -> impl Stream<Item = EventArray> {
252    stream.map(move |event| event.with_batch_notifier_option(&batch).into())
253}
254
255// TODO refactor to have a single implementation for `Event`, `LogEvent` and `Metric`.
256fn map_batch_stream(
257    stream: impl Stream<Item = LogEvent>,
258    batch: Option<BatchNotifier>,
259) -> impl Stream<Item = EventArray> {
260    stream.map(move |log| vec![log.with_batch_notifier_option(&batch)].into())
261}
262
263pub fn generate_lines_with_stream<Gen: FnMut(usize) -> String>(
264    generator: Gen,
265    count: usize,
266    batch: Option<BatchNotifier>,
267) -> (Vec<String>, impl Stream<Item = EventArray>) {
268    let lines = (0..count).map(generator).collect::<Vec<_>>();
269    let stream = map_batch_stream(
270        stream::iter(lines.clone()).map(LogEvent::from_str_legacy),
271        batch,
272    );
273    (lines, stream)
274}
275
276pub fn random_lines_with_stream(
277    len: usize,
278    count: usize,
279    batch: Option<BatchNotifier>,
280) -> (Vec<String>, impl Stream<Item = EventArray>) {
281    let generator = move |_| random_string(len);
282    generate_lines_with_stream(generator, count, batch)
283}
284
285pub fn generate_events_with_stream<Gen: FnMut(usize) -> Event>(
286    generator: Gen,
287    count: usize,
288    batch: Option<BatchNotifier>,
289) -> (Vec<Event>, impl Stream<Item = EventArray>) {
290    let events = (0..count).map(generator).collect::<Vec<_>>();
291    let stream = map_batch_stream(
292        stream::iter(events.clone()).map(|event| event.into_log()),
293        batch,
294    );
295    (events, stream)
296}
297
298pub fn random_metrics_with_stream(
299    count: usize,
300    batch: Option<BatchNotifier>,
301    tags: Option<MetricTags>,
302) -> (Vec<Event>, impl Stream<Item = EventArray>) {
303    random_metrics_with_stream_timestamp(
304        count,
305        batch,
306        tags,
307        Utc::now().trunc_subsecs(3),
308        std::time::Duration::from_secs(2),
309    )
310}
311
312/// Generates event metrics with the provided tags and timestamp.
313///
314/// # Parameters
315/// - `count`: the number of metrics to generate
316/// - `batch`: the batch notifier to use with the stream
317/// - `tags`: the tags to apply to each metric event
318/// - `timestamp`: the timestamp to use for each metric event
319/// - `timestamp_offset`: the offset from the `timestamp` to use for each additional metric
320///
321/// # Returns
322/// A tuple of the generated metric events and the stream of the generated events
323pub fn random_metrics_with_stream_timestamp(
324    count: usize,
325    batch: Option<BatchNotifier>,
326    tags: Option<MetricTags>,
327    timestamp: DateTime<Utc>,
328    timestamp_offset: std::time::Duration,
329) -> (Vec<Event>, impl Stream<Item = EventArray>) {
330    let events: Vec<_> = (0..count)
331        .map(|index| {
332            let ts = timestamp + (timestamp_offset * index as u32);
333            Event::Metric(
334                Metric::new(
335                    format!("counter_{}", rng().random::<u32>()),
336                    MetricKind::Incremental,
337                    MetricValue::Counter {
338                        value: index as f64,
339                    },
340                )
341                .with_timestamp(Some(ts))
342                .with_tags(tags.clone()),
343            )
344            // this ensures we get Origin Metadata, with an undefined service but that's ok.
345            .with_source_type("a_source_like_none_other")
346        })
347        .collect();
348
349    let stream = map_event_batch_stream(stream::iter(events.clone()), batch);
350    (events, stream)
351}
352
353pub fn random_events_with_stream(
354    len: usize,
355    count: usize,
356    batch: Option<BatchNotifier>,
357) -> (Vec<Event>, impl Stream<Item = EventArray>) {
358    let events = (0..count)
359        .map(|_| Event::from(LogEvent::from_str_legacy(random_string(len))))
360        .collect::<Vec<_>>();
361    let stream = map_batch_stream(
362        stream::iter(events.clone()).map(|event| event.into_log()),
363        batch,
364    );
365    (events, stream)
366}
367
368pub fn random_updated_events_with_stream<F>(
369    len: usize,
370    count: usize,
371    batch: Option<BatchNotifier>,
372    update_fn: F,
373) -> (Vec<Event>, impl Stream<Item = EventArray>)
374where
375    F: Fn((usize, LogEvent)) -> LogEvent,
376{
377    let events = (0..count)
378        .map(|_| LogEvent::from_str_legacy(random_string(len)))
379        .enumerate()
380        .map(update_fn)
381        .map(Event::Log)
382        .collect::<Vec<_>>();
383    let stream = map_batch_stream(
384        stream::iter(events.clone()).map(|event| event.into_log()),
385        batch,
386    );
387    (events, stream)
388}
389
390pub fn create_events_batch_with_fn<F: Fn() -> Event>(
391    create_event_fn: F,
392    num_events: usize,
393) -> (Vec<Event>, BatchStatusReceiver) {
394    let mut events = (0..num_events)
395        .map(|_| create_event_fn())
396        .collect::<Vec<_>>();
397    let receiver = BatchNotifier::apply_to(&mut events);
398    (events, receiver)
399}
400
401pub fn random_string(len: usize) -> String {
402    rng()
403        .sample_iter(&Alphanumeric)
404        .take(len)
405        .map(char::from)
406        .collect::<String>()
407}
408
409pub fn random_lines(len: usize) -> impl Iterator<Item = String> {
410    iter::repeat_with(move || random_string(len))
411}
412
413pub fn random_map(max_size: usize, field_len: usize) -> HashMap<String, String> {
414    let size = rng().random_range(0..max_size);
415
416    (0..size)
417        .map(move |_| (random_string(field_len), random_string(field_len)))
418        .collect()
419}
420
421pub fn random_maps(
422    max_size: usize,
423    field_len: usize,
424) -> impl Iterator<Item = HashMap<String, String>> {
425    iter::repeat_with(move || random_map(max_size, field_len))
426}
427
428pub async fn collect_n<S>(rx: S, n: usize) -> Vec<S::Item>
429where
430    S: Stream,
431{
432    rx.take(n).collect().await
433}
434
435pub async fn collect_n_stream<T, S: Stream<Item = T> + Unpin>(stream: &mut S, n: usize) -> Vec<T> {
436    let mut events = Vec::with_capacity(n);
437
438    while events.len() < n {
439        let e = stream.next().await.unwrap();
440        events.push(e);
441    }
442    events
443}
444
445pub async fn collect_ready<S>(mut rx: S) -> Vec<S::Item>
446where
447    S: Stream + Unpin,
448{
449    let waker = noop_waker_ref();
450    let mut cx = Context::from_waker(waker);
451
452    let mut vec = Vec::new();
453    loop {
454        match rx.poll_next_unpin(&mut cx) {
455            Poll::Ready(Some(item)) => vec.push(item),
456            Poll::Ready(None) | Poll::Pending => return vec,
457        }
458    }
459}
460
461pub async fn collect_limited<T: Send + 'static>(mut rx: LimitedReceiver<T>) -> Vec<T> {
462    let mut items = Vec::new();
463    while let Some(item) = rx.next().await {
464        items.push(item);
465    }
466    items
467}
468
469pub async fn collect_n_limited<T: Send + 'static>(mut rx: LimitedReceiver<T>, n: usize) -> Vec<T> {
470    let mut items = Vec::new();
471    while items.len() < n {
472        match rx.next().await {
473            Some(item) => items.push(item),
474            None => break,
475        }
476    }
477    items
478}
479
480pub fn lines_from_file<P: AsRef<Path>>(path: P) -> Vec<String> {
481    trace!(message = "Reading file.", path = %path.as_ref().display());
482    let mut file = File::open(path).unwrap();
483    let mut output = String::new();
484    file.read_to_string(&mut output).unwrap();
485    output.lines().map(|s| s.to_owned()).collect()
486}
487
488pub fn lines_from_gzip_file<P: AsRef<Path>>(path: P) -> Vec<String> {
489    trace!(message = "Reading gzip file.", path = %path.as_ref().display());
490    let mut file = File::open(path).unwrap();
491    let mut gzip_bytes = Vec::new();
492    file.read_to_end(&mut gzip_bytes).unwrap();
493    let mut output = String::new();
494    MultiGzDecoder::new(&gzip_bytes[..])
495        .read_to_string(&mut output)
496        .unwrap();
497    output.lines().map(|s| s.to_owned()).collect()
498}
499
/// Decompresses the zstd file at `path` and returns the decoded text's lines, without line
/// terminators. Panics on any I/O, decoding, or UTF-8 error.
#[cfg(test)]
pub fn lines_from_zstd_file<P: AsRef<Path>>(path: P) -> Vec<String> {
    trace!(message = "Reading zstd file.", path = %path.as_ref().display());
    let mut decoder = ZstdDecoder::new(File::open(path).unwrap()).unwrap();
    let mut decoded = String::new();
    decoder.read_to_string(&mut decoded).unwrap();
    decoded.lines().map(String::from).collect()
}
511
512pub fn runtime() -> runtime::Runtime {
513    runtime::Builder::new_multi_thread()
514        .enable_all()
515        .build()
516        .unwrap()
517}
518
519// Wait for a Future to resolve, or the duration to elapse (will panic)
520pub async fn wait_for_duration<F, Fut>(mut f: F, duration: Duration)
521where
522    F: FnMut() -> Fut,
523    Fut: Future<Output = bool> + Send + 'static,
524{
525    let started = Instant::now();
526    let mut delay = WAIT_FOR_MIN_MILLIS;
527    while !f().await {
528        sleep(Duration::from_millis(delay)).await;
529        if started.elapsed() > duration {
530            panic!("Timed out while waiting");
531        }
532        // quadratic backoff up to a maximum delay
533        delay = (delay * 2).min(WAIT_FOR_MAX_MILLIS);
534    }
535}
536
537// Wait for 5 seconds
538pub async fn wait_for<F, Fut>(f: F)
539where
540    F: FnMut() -> Fut,
541    Fut: Future<Output = bool> + Send + 'static,
542{
543    wait_for_duration(f, Duration::from_secs(WAIT_FOR_SECS)).await
544}
545
546// Wait (for 5 secs) for a TCP socket to be reachable
547pub async fn wait_for_tcp<A>(addr: A)
548where
549    A: ToSocketAddrs + Clone + Send + 'static,
550{
551    wait_for(move || {
552        let addr = addr.clone();
553        async move { TcpStream::connect(addr).await.is_ok() }
554    })
555    .await
556}
557
558// Allows specifying a custom duration to wait for a TCP socket to be reachable
559pub async fn wait_for_tcp_duration(addr: SocketAddr, duration: Duration) {
560    wait_for_duration(
561        || async move { TcpStream::connect(addr).await.is_ok() },
562        duration,
563    )
564    .await
565}
566
567pub async fn wait_for_atomic_usize<T, F>(value: T, unblock: F)
568where
569    T: AsRef<AtomicUsize>,
570    F: Fn(usize) -> bool,
571{
572    let value = value.as_ref();
573    wait_for(|| ready(unblock(value.load(Ordering::SeqCst)))).await
574}
575
576// Retries a func every `retry` duration until given an Ok(T); panics after `until` elapses
577pub async fn retry_until<'a, F, Fut, T, E>(mut f: F, retry: Duration, until: Duration) -> T
578where
579    F: FnMut() -> Fut,
580    Fut: Future<Output = Result<T, E>> + Send + 'a,
581{
582    let started = Instant::now();
583    while started.elapsed() < until {
584        match f().await {
585            Ok(res) => return res,
586            Err(_) => tokio::time::sleep(retry).await,
587        }
588    }
589    panic!("Timeout")
590}
591
/// Handle to a background task, started with `tokio::spawn`, that collects items of type `T`
/// and counts them as they arrive.
///
/// Awaiting the handle first sends the stop signal to the task, then yields the `Vec<T>`
/// the task returns.
pub struct CountReceiver<T> {
    // Running number of received items, shared with the background task.
    count: Arc<AtomicUsize>,
    // Stop signal for the background task. Taken and sent on the first poll of the handle.
    trigger: Option<oneshot::Sender<()>>,
    // Fires when the background task reports it is connected/ready. Taken by `connected()`.
    connected: Option<oneshot::Receiver<()>>,
    // The spawned collection task. Its output is the collected items.
    handle: JoinHandle<Vec<T>>,
}
598
599impl<T: Send + 'static> CountReceiver<T> {
600    pub fn count(&self) -> usize {
601        self.count.load(Ordering::Relaxed)
602    }
603
604    /// Succeeds once first connection has been made.
605    pub async fn connected(&mut self) {
606        if let Some(tripwire) = self.connected.take() {
607            tripwire.await.unwrap();
608        }
609    }
610
611    fn new<F, Fut>(make_fut: F) -> CountReceiver<T>
612    where
613        F: FnOnce(Arc<AtomicUsize>, oneshot::Receiver<()>, oneshot::Sender<()>) -> Fut,
614        Fut: Future<Output = Vec<T>> + Send + 'static,
615    {
616        let count = Arc::new(AtomicUsize::new(0));
617        let (trigger, tripwire) = oneshot::channel();
618        let (trigger_connected, connected) = oneshot::channel();
619
620        CountReceiver {
621            count: Arc::clone(&count),
622            trigger: Some(trigger),
623            connected: Some(connected),
624            handle: tokio::spawn(make_fut(count, tripwire, trigger_connected)),
625        }
626    }
627
628    pub fn receive_items_stream<S, F, Fut>(make_stream: F) -> CountReceiver<T>
629    where
630        S: Stream<Item = T> + Send + 'static,
631        F: FnOnce(oneshot::Receiver<()>, oneshot::Sender<()>) -> Fut + Send + 'static,
632        Fut: Future<Output = S> + Send + 'static,
633    {
634        CountReceiver::new(|count, tripwire, connected| async move {
635            let stream = make_stream(tripwire, connected).await;
636            stream
637                .inspect(move |_| {
638                    count.fetch_add(1, Ordering::Relaxed);
639                })
640                .collect::<Vec<T>>()
641                .await
642        })
643    }
644}
645
646impl<T> Future for CountReceiver<T> {
647    type Output = Vec<T>;
648
649    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
650        let this = self.get_mut();
651        if let Some(trigger) = this.trigger.take() {
652            _ = trigger.send(());
653        }
654
655        let result = ready!(this.handle.poll_unpin(cx));
656        Poll::Ready(result.unwrap())
657    }
658}
659
660impl CountReceiver<String> {
661    pub fn receive_lines(addr: SocketAddr) -> CountReceiver<String> {
662        CountReceiver::new(|count, tripwire, connected| async move {
663            let listener = TcpListener::bind(addr).await.unwrap();
664            CountReceiver::receive_lines_stream(
665                TcpListenerStream::new(listener),
666                count,
667                tripwire,
668                Some(connected),
669            )
670            .await
671        })
672    }
673
674    #[cfg(unix)]
675    pub fn receive_lines_unix<P>(path: P) -> CountReceiver<String>
676    where
677        P: AsRef<Path> + Send + 'static,
678    {
679        CountReceiver::new(|count, tripwire, connected| async move {
680            let listener = tokio::net::UnixListener::bind(path).unwrap();
681            CountReceiver::receive_lines_stream(
682                UnixListenerStream::new(listener),
683                count,
684                tripwire,
685                Some(connected),
686            )
687            .await
688        })
689    }
690
691    async fn receive_lines_stream<S, T>(
692        stream: S,
693        count: Arc<AtomicUsize>,
694        tripwire: oneshot::Receiver<()>,
695        mut connected: Option<oneshot::Sender<()>>,
696    ) -> Vec<String>
697    where
698        S: Stream<Item = IoResult<T>>,
699        T: AsyncWrite + AsyncRead,
700    {
701        stream
702            .take_until(tripwire)
703            .map_ok(|socket| FramedRead::new(socket, LinesCodec::new()))
704            .map(|x| {
705                connected.take().map(|trigger| trigger.send(()));
706                x.unwrap()
707            })
708            .flatten()
709            .map(|x| x.unwrap())
710            .inspect(move |_| {
711                count.fetch_add(1, Ordering::Relaxed);
712            })
713            .collect::<Vec<String>>()
714            .await
715    }
716}
717
718impl CountReceiver<Event> {
719    pub fn receive_events<S>(stream: S) -> CountReceiver<Event>
720    where
721        S: Stream<Item = Event> + Send + 'static,
722    {
723        CountReceiver::new(|count, tripwire, connected| async move {
724            connected.send(()).unwrap();
725            stream
726                .take_until(tripwire)
727                .inspect(move |_| {
728                    count.fetch_add(1, Ordering::Relaxed);
729                })
730                .collect::<Vec<Event>>()
731                .await
732        })
733    }
734}
735
736pub async fn start_topology(
737    mut config: Config,
738    require_healthy: impl Into<Option<bool>>,
739) -> (RunningTopology, ShutdownErrorReceiver) {
740    config.healthchecks.set_require_healthy(require_healthy);
741    RunningTopology::start_init_validated(config, Default::default())
742        .await
743        .unwrap()
744}
745
746/// Collect the first `n` events from a stream while a future is spawned
747/// in the background. This is used for tests where the collect has to
748/// happen concurrent with the sending process (ie the stream is
749/// handling finalization, which is required for the future to receive
750/// an acknowledgement).
751pub async fn spawn_collect_n<F, S>(future: F, stream: S, n: usize) -> Vec<Event>
752where
753    F: Future<Output = ()> + Send + 'static,
754    S: Stream<Item = Event>,
755{
756    // TODO: Switch to using `select!` so that we can drive `future` to completion while also driving `collect_n`,
757    // such that if `future` panics, we break out and don't continue driving `collect_n`. In most cases, `future`
758    // completing successfully is what actually drives events into `stream`, so continuing to wait for all N events when
759    // the catalyst has failed is.... almost never the desired behavior.
760    let sender = tokio::spawn(future);
761    let events = collect_n(stream, n).await;
762    sender.await.expect("Failed to send data");
763    events
764}
765
766/// Collect all the ready events from a stream after spawning a future
767/// in the background and letting it run for a given interval. This is
768/// used for tests where the collect has to happen concurrent with the
769/// sending process (ie the stream is handling finalization, which is
770/// required for the future to receive an acknowledgement).
771pub async fn spawn_collect_ready<F, S>(future: F, stream: S, sleep: u64) -> Vec<Event>
772where
773    F: Future<Output = ()> + Send + 'static,
774    S: Stream<Item = Event> + Unpin,
775{
776    let sender = tokio::spawn(future);
777    tokio::time::sleep(Duration::from_secs(sleep)).await;
778    let events = collect_ready(stream).await;
779    sender.await.expect("Failed to send data");
780    events
781}
782
#[cfg(test)]
mod tests {
    use std::{
        sync::{Arc, RwLock},
        time::Duration,
    };

    use super::retry_until;

    // Helper that errors on the first three calls and succeeds on the fourth, recording the
    // number of failures in `count`.
    async fn retry_until_helper(count: Arc<RwLock<i32>>) -> Result<(), ()> {
        if *count.read().unwrap() < 3 {
            let mut c = count.write().unwrap();
            *c += 1;
            return Err(());
        }
        Ok(())
    }

    #[tokio::test]
    async fn retry_until_before_timeout() {
        let count = Arc::new(RwLock::new(0));
        let func = || {
            let count = Arc::clone(&count);
            retry_until_helper(count)
        };

        retry_until(func, Duration::from_millis(10), Duration::from_secs(1)).await;

        // Confirm it actually retried: three failed attempts, then the successful fourth.
        assert_eq!(*count.read().unwrap(), 3);
    }
}