1#![allow(missing_docs)]
2use std::{
3 collections::HashMap,
4 convert::Infallible,
5 fs::File,
6 future::{Future, ready},
7 io::Read,
8 iter,
9 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
10 path::{Path, PathBuf},
11 pin::Pin,
12 sync::{
13 Arc,
14 atomic::{AtomicUsize, Ordering},
15 },
16 task::{Context, Poll, ready},
17};
18
19use chrono::{DateTime, SubsecRound, Utc};
20use flate2::read::MultiGzDecoder;
21use futures::{FutureExt, SinkExt, Stream, StreamExt, TryStreamExt, stream, task::noop_waker_ref};
22use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
23use portpicker::pick_unused_port;
24use rand::{Rng, rng};
25use rand_distr::Alphanumeric;
26use tokio::{
27 io::{AsyncRead, AsyncWrite, AsyncWriteExt, Result as IoResult},
28 net::{TcpListener, TcpStream, ToSocketAddrs},
29 runtime,
30 sync::oneshot,
31 task::JoinHandle,
32 time::{Duration, Instant, sleep},
33};
34use tokio_stream::wrappers::TcpListenerStream;
35#[cfg(unix)]
36use tokio_stream::wrappers::UnixListenerStream;
37use tokio_util::codec::{Encoder, FramedRead, FramedWrite, LinesCodec};
38use vector_lib::{
39 buffers::topology::channel::LimitedReceiver,
40 event::{
41 BatchNotifier, BatchStatusReceiver, Event, EventArray, LogEvent, Metric, MetricKind,
42 MetricTags, MetricValue,
43 },
44};
45#[cfg(test)]
46use zstd::Decoder as ZstdDecoder;
47
48use crate::{
49 config::{Config, GenerateConfig},
50 topology::{RunningTopology, ShutdownErrorReceiver},
51 trace,
52};
53
/// Default overall timeout, in seconds, applied by [`wait_for`].
const WAIT_FOR_SECS: u64 = 5;
/// Initial back-off delay, in milliseconds, between polls in [`wait_for_duration`].
const WAIT_FOR_MIN_MILLIS: u64 = 5;
/// Upper bound, in milliseconds, on the doubling back-off delay in [`wait_for_duration`].
const WAIT_FOR_MAX_MILLIS: u64 = 500;

#[cfg(any(test, feature = "test-utils"))]
pub mod components;

#[cfg(test)]
pub mod http;

#[cfg(test)]
pub mod metrics;

#[cfg(test)]
pub mod mock;

pub mod compression;
pub mod stats;

#[cfg(test)]
pub mod integration;
75
/// Asserts that `$e` downcasts (via `downcast_ref`) to `$t` and that the
/// downcast value matches the pattern `$v`.
///
/// # Panics
///
/// Panics with the `Debug` form of the downcast result (an `Option<&$t>`)
/// when the downcast fails or the value does not match `$v`.
#[macro_export]
macro_rules! assert_downcast_matches {
    ($e:expr_2021, $t:ty, $v:pat) => {{
        let got = $e.downcast_ref::<$t>();
        if !matches!(got, Some($v)) {
            panic!("Assertion failed: got wrong error variant {:?}", got);
        }
    }};
}
85
/// Builds an `Event::Log` from `key => value` pairs.
///
/// Each pair is applied in order with `LogEvent::insert` on a default
/// `LogEvent`, which is then wrapped in `Event::Log`.
#[macro_export]
macro_rules! log_event {
    ($($key:expr_2021 => $value:expr_2021),* $(,)?) => {
        // With zero pairs `log` is never mutated, so silence the lints that would fire.
        #[allow(unused_variables, unused_mut)]
        {
            let mut log = $crate::event::LogEvent::default();
            $(
                log.insert($key, $value);
            )*
            $crate::event::Event::Log(log)
        }
    };
}
100
101pub fn test_generate_config<T>()
102where
103 for<'de> T: GenerateConfig + serde::Deserialize<'de>,
104{
105 let cfg = toml::to_string(&T::generate_config()).unwrap();
106
107 toml::from_str::<T>(&cfg)
108 .unwrap_or_else(|e| panic!("Invalid config generated from string:\n\n{e}\n'{cfg}'"));
109}
110
111pub fn open_fixture(path: impl AsRef<Path>) -> crate::Result<serde_json::Value> {
112 let test_file = match File::open(path) {
113 Ok(file) => file,
114 Err(e) => return Err(e.into()),
115 };
116 let value: serde_json::Value = serde_json::from_reader(test_file)?;
117 Ok(value)
118}
119
120pub fn next_addr_for_ip(ip: IpAddr) -> SocketAddr {
121 let port = pick_unused_port(ip);
122 SocketAddr::new(ip, port)
123}
124
125pub fn next_addr() -> SocketAddr {
126 next_addr_for_ip(IpAddr::V4(Ipv4Addr::LOCALHOST))
127}
128
129pub fn next_addr_any() -> SocketAddr {
130 next_addr_for_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED))
131}
132
133pub fn next_addr_v6() -> SocketAddr {
134 next_addr_for_ip(IpAddr::V6(Ipv6Addr::LOCALHOST))
135}
136
137pub fn trace_init() {
138 #[cfg(unix)]
139 let color = {
140 use std::io::IsTerminal;
141 std::io::stdout().is_terminal()
142 || std::env::var("NEXTEST")
143 .ok()
144 .and(Some(true))
145 .unwrap_or(false)
146 };
147 #[cfg(not(unix))]
150 let color = false;
151
152 let levels = std::env::var("VECTOR_LOG").unwrap_or_else(|_| "error".to_string());
153
154 trace::init(color, false, &levels, 10);
155
156 vector_lib::metrics::init_test();
158}
159
160pub async fn send_lines(
161 addr: SocketAddr,
162 lines: impl IntoIterator<Item = String>,
163) -> Result<SocketAddr, Infallible> {
164 send_encodable(addr, LinesCodec::new(), lines).await
165}
166
167pub async fn send_encodable<I, E: From<std::io::Error> + std::fmt::Debug>(
168 addr: SocketAddr,
169 encoder: impl Encoder<I, Error = E>,
170 lines: impl IntoIterator<Item = I>,
171) -> Result<SocketAddr, Infallible> {
172 let stream = TcpStream::connect(&addr).await.unwrap();
173
174 let local_addr = stream.local_addr().unwrap();
175
176 let mut sink = FramedWrite::new(stream, encoder);
177
178 let mut lines = stream::iter(lines.into_iter()).map(Ok);
179 sink.send_all(&mut lines).await.unwrap();
180
181 let stream = sink.get_mut();
182 stream.shutdown().await.unwrap();
183
184 Ok(local_addr)
185}
186
187pub async fn send_lines_tls(
188 addr: SocketAddr,
189 host: String,
190 lines: impl Iterator<Item = String>,
191 ca: impl Into<Option<&Path>>,
192 client_cert: impl Into<Option<&Path>>,
193 client_key: impl Into<Option<&Path>>,
194) -> Result<SocketAddr, Infallible> {
195 let stream = TcpStream::connect(&addr).await.unwrap();
196
197 let local_addr = stream.local_addr().unwrap();
198
199 let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
200 if let Some(ca) = ca.into() {
201 connector.set_ca_file(ca).unwrap();
202 } else {
203 connector.set_verify(SslVerifyMode::NONE);
204 }
205
206 if let Some(cert_file) = client_cert.into() {
207 connector.set_certificate_chain_file(cert_file).unwrap();
208 }
209
210 if let Some(key_file) = client_key.into() {
211 connector
212 .set_private_key_file(key_file, SslFiletype::PEM)
213 .unwrap();
214 }
215
216 let ssl = connector
217 .build()
218 .configure()
219 .unwrap()
220 .into_ssl(&host)
221 .unwrap();
222
223 let mut stream = tokio_openssl::SslStream::new(ssl, stream).unwrap();
224 Pin::new(&mut stream).connect().await.unwrap();
225 let mut sink = FramedWrite::new(stream, LinesCodec::new());
226
227 let mut lines = stream::iter(lines).map(Ok);
228 sink.send_all(&mut lines).await.unwrap();
229
230 let stream = sink.get_mut().get_mut();
231 stream.shutdown().await.unwrap();
232
233 Ok(local_addr)
234}
235
236pub fn temp_file() -> PathBuf {
237 let path = std::env::temp_dir();
238 let file_name = random_string(16);
239 path.join(file_name + ".log")
240}
241
242pub fn temp_dir() -> PathBuf {
243 let path = std::env::temp_dir();
244 let dir_name = random_string(16);
245 path.join(dir_name)
246}
247
248pub fn random_table_name() -> String {
249 format!("test_{}", random_string(10).to_lowercase())
250}
251
252pub fn map_event_batch_stream(
253 stream: impl Stream<Item = Event>,
254 batch: Option<BatchNotifier>,
255) -> impl Stream<Item = EventArray> {
256 stream.map(move |event| event.with_batch_notifier_option(&batch).into())
257}
258
259fn map_batch_stream(
261 stream: impl Stream<Item = LogEvent>,
262 batch: Option<BatchNotifier>,
263) -> impl Stream<Item = EventArray> {
264 stream.map(move |log| vec![log.with_batch_notifier_option(&batch)].into())
265}
266
267pub fn generate_lines_with_stream<Gen: FnMut(usize) -> String>(
268 generator: Gen,
269 count: usize,
270 batch: Option<BatchNotifier>,
271) -> (Vec<String>, impl Stream<Item = EventArray>) {
272 let lines = (0..count).map(generator).collect::<Vec<_>>();
273 let stream = map_batch_stream(
274 stream::iter(lines.clone()).map(LogEvent::from_str_legacy),
275 batch,
276 );
277 (lines, stream)
278}
279
280pub fn random_lines_with_stream(
281 len: usize,
282 count: usize,
283 batch: Option<BatchNotifier>,
284) -> (Vec<String>, impl Stream<Item = EventArray>) {
285 let generator = move |_| random_string(len);
286 generate_lines_with_stream(generator, count, batch)
287}
288
289pub fn generate_events_with_stream<Gen: FnMut(usize) -> Event>(
290 generator: Gen,
291 count: usize,
292 batch: Option<BatchNotifier>,
293) -> (Vec<Event>, impl Stream<Item = EventArray>) {
294 let events = (0..count).map(generator).collect::<Vec<_>>();
295 let stream = map_batch_stream(
296 stream::iter(events.clone()).map(|event| event.into_log()),
297 batch,
298 );
299 (events, stream)
300}
301
302pub fn random_metrics_with_stream(
303 count: usize,
304 batch: Option<BatchNotifier>,
305 tags: Option<MetricTags>,
306) -> (Vec<Event>, impl Stream<Item = EventArray>) {
307 random_metrics_with_stream_timestamp(
308 count,
309 batch,
310 tags,
311 Utc::now().trunc_subsecs(3),
312 std::time::Duration::from_secs(2),
313 )
314}
315
316pub fn random_metrics_with_stream_timestamp(
328 count: usize,
329 batch: Option<BatchNotifier>,
330 tags: Option<MetricTags>,
331 timestamp: DateTime<Utc>,
332 timestamp_offset: std::time::Duration,
333) -> (Vec<Event>, impl Stream<Item = EventArray>) {
334 let events: Vec<_> = (0..count)
335 .map(|index| {
336 let ts = timestamp + (timestamp_offset * index as u32);
337 Event::Metric(
338 Metric::new(
339 format!("counter_{}", rng().random::<u32>()),
340 MetricKind::Incremental,
341 MetricValue::Counter {
342 value: index as f64,
343 },
344 )
345 .with_timestamp(Some(ts))
346 .with_tags(tags.clone()),
347 )
348 .with_source_type("a_source_like_none_other")
350 })
351 .collect();
352
353 let stream = map_event_batch_stream(stream::iter(events.clone()), batch);
354 (events, stream)
355}
356
357pub fn random_events_with_stream(
358 len: usize,
359 count: usize,
360 batch: Option<BatchNotifier>,
361) -> (Vec<Event>, impl Stream<Item = EventArray>) {
362 let events = (0..count)
363 .map(|_| Event::from(LogEvent::from_str_legacy(random_string(len))))
364 .collect::<Vec<_>>();
365 let stream = map_batch_stream(
366 stream::iter(events.clone()).map(|event| event.into_log()),
367 batch,
368 );
369 (events, stream)
370}
371
372pub fn random_updated_events_with_stream<F>(
373 len: usize,
374 count: usize,
375 batch: Option<BatchNotifier>,
376 update_fn: F,
377) -> (Vec<Event>, impl Stream<Item = EventArray>)
378where
379 F: Fn((usize, LogEvent)) -> LogEvent,
380{
381 let events = (0..count)
382 .map(|_| LogEvent::from_str_legacy(random_string(len)))
383 .enumerate()
384 .map(update_fn)
385 .map(Event::Log)
386 .collect::<Vec<_>>();
387 let stream = map_batch_stream(
388 stream::iter(events.clone()).map(|event| event.into_log()),
389 batch,
390 );
391 (events, stream)
392}
393
394pub fn create_events_batch_with_fn<F: Fn() -> Event>(
395 create_event_fn: F,
396 num_events: usize,
397) -> (Vec<Event>, BatchStatusReceiver) {
398 let mut events = (0..num_events)
399 .map(|_| create_event_fn())
400 .collect::<Vec<_>>();
401 let receiver = BatchNotifier::apply_to(&mut events);
402 (events, receiver)
403}
404
405pub fn random_string(len: usize) -> String {
406 rng()
407 .sample_iter(&Alphanumeric)
408 .take(len)
409 .map(char::from)
410 .collect::<String>()
411}
412
413pub fn random_lines(len: usize) -> impl Iterator<Item = String> {
414 iter::repeat_with(move || random_string(len))
415}
416
417pub fn random_map(max_size: usize, field_len: usize) -> HashMap<String, String> {
418 let size = rng().random_range(0..max_size);
419
420 (0..size)
421 .map(move |_| (random_string(field_len), random_string(field_len)))
422 .collect()
423}
424
425pub fn random_maps(
426 max_size: usize,
427 field_len: usize,
428) -> impl Iterator<Item = HashMap<String, String>> {
429 iter::repeat_with(move || random_map(max_size, field_len))
430}
431
432pub async fn collect_n<S>(rx: S, n: usize) -> Vec<S::Item>
433where
434 S: Stream,
435{
436 rx.take(n).collect().await
437}
438
439pub async fn collect_n_stream<T, S: Stream<Item = T> + Unpin>(stream: &mut S, n: usize) -> Vec<T> {
440 let mut events = Vec::with_capacity(n);
441
442 while events.len() < n {
443 let e = stream.next().await.unwrap();
444 events.push(e);
445 }
446 events
447}
448
449pub async fn collect_ready<S>(mut rx: S) -> Vec<S::Item>
450where
451 S: Stream + Unpin,
452{
453 let waker = noop_waker_ref();
454 let mut cx = Context::from_waker(waker);
455
456 let mut vec = Vec::new();
457 loop {
458 match rx.poll_next_unpin(&mut cx) {
459 Poll::Ready(Some(item)) => vec.push(item),
460 Poll::Ready(None) | Poll::Pending => return vec,
461 }
462 }
463}
464
465pub async fn collect_limited<T: Send + 'static>(mut rx: LimitedReceiver<T>) -> Vec<T> {
466 let mut items = Vec::new();
467 while let Some(item) = rx.next().await {
468 items.push(item);
469 }
470 items
471}
472
473pub async fn collect_n_limited<T: Send + 'static>(mut rx: LimitedReceiver<T>, n: usize) -> Vec<T> {
474 let mut items = Vec::new();
475 while items.len() < n {
476 match rx.next().await {
477 Some(item) => items.push(item),
478 None => break,
479 }
480 }
481 items
482}
483
484pub fn lines_from_file<P: AsRef<Path>>(path: P) -> Vec<String> {
485 trace!(message = "Reading file.", path = %path.as_ref().display());
486 let mut file = File::open(path).unwrap();
487 let mut output = String::new();
488 file.read_to_string(&mut output).unwrap();
489 output.lines().map(|s| s.to_owned()).collect()
490}
491
492pub fn lines_from_gzip_file<P: AsRef<Path>>(path: P) -> Vec<String> {
493 trace!(message = "Reading gzip file.", path = %path.as_ref().display());
494 let mut file = File::open(path).unwrap();
495 let mut gzip_bytes = Vec::new();
496 file.read_to_end(&mut gzip_bytes).unwrap();
497 let mut output = String::new();
498 MultiGzDecoder::new(&gzip_bytes[..])
499 .read_to_string(&mut output)
500 .unwrap();
501 output.lines().map(|s| s.to_owned()).collect()
502}
503
504#[cfg(test)]
505pub fn lines_from_zstd_file<P: AsRef<Path>>(path: P) -> Vec<String> {
506 trace!(message = "Reading zstd file.", path = %path.as_ref().display());
507 let file = File::open(path).unwrap();
508 let mut output = String::new();
509 ZstdDecoder::new(file)
510 .unwrap()
511 .read_to_string(&mut output)
512 .unwrap();
513 output.lines().map(|s| s.to_owned()).collect()
514}
515
516pub fn runtime() -> runtime::Runtime {
517 runtime::Builder::new_multi_thread()
518 .enable_all()
519 .build()
520 .unwrap()
521}
522
523pub async fn wait_for_duration<F, Fut>(mut f: F, duration: Duration)
525where
526 F: FnMut() -> Fut,
527 Fut: Future<Output = bool> + Send + 'static,
528{
529 let started = Instant::now();
530 let mut delay = WAIT_FOR_MIN_MILLIS;
531 while !f().await {
532 sleep(Duration::from_millis(delay)).await;
533 if started.elapsed() > duration {
534 panic!("Timed out while waiting");
535 }
536 delay = (delay * 2).min(WAIT_FOR_MAX_MILLIS);
538 }
539}
540
541pub async fn wait_for<F, Fut>(f: F)
543where
544 F: FnMut() -> Fut,
545 Fut: Future<Output = bool> + Send + 'static,
546{
547 wait_for_duration(f, Duration::from_secs(WAIT_FOR_SECS)).await
548}
549
550pub async fn wait_for_tcp<A>(addr: A)
552where
553 A: ToSocketAddrs + Clone + Send + 'static,
554{
555 wait_for(move || {
556 let addr = addr.clone();
557 async move { TcpStream::connect(addr).await.is_ok() }
558 })
559 .await
560}
561
562pub async fn wait_for_tcp_duration(addr: SocketAddr, duration: Duration) {
564 wait_for_duration(
565 || async move { TcpStream::connect(addr).await.is_ok() },
566 duration,
567 )
568 .await
569}
570
571pub async fn wait_for_atomic_usize<T, F>(value: T, unblock: F)
572where
573 T: AsRef<AtomicUsize>,
574 F: Fn(usize) -> bool,
575{
576 let value = value.as_ref();
577 wait_for(|| ready(unblock(value.load(Ordering::SeqCst)))).await
578}
579
580pub async fn retry_until<'a, F, Fut, T, E>(mut f: F, retry: Duration, until: Duration) -> T
582where
583 F: FnMut() -> Fut,
584 Fut: Future<Output = Result<T, E>> + Send + 'a,
585{
586 let started = Instant::now();
587 while started.elapsed() < until {
588 match f().await {
589 Ok(res) => return res,
590 Err(_) => tokio::time::sleep(retry).await,
591 }
592 }
593 panic!("Timeout")
594}
595
596pub struct CountReceiver<T> {
597 count: Arc<AtomicUsize>,
598 trigger: Option<oneshot::Sender<()>>,
599 connected: Option<oneshot::Receiver<()>>,
600 handle: JoinHandle<Vec<T>>,
601}
602
603impl<T: Send + 'static> CountReceiver<T> {
604 pub fn count(&self) -> usize {
605 self.count.load(Ordering::Relaxed)
606 }
607
608 pub async fn connected(&mut self) {
610 if let Some(tripwire) = self.connected.take() {
611 tripwire.await.unwrap();
612 }
613 }
614
615 fn new<F, Fut>(make_fut: F) -> CountReceiver<T>
616 where
617 F: FnOnce(Arc<AtomicUsize>, oneshot::Receiver<()>, oneshot::Sender<()>) -> Fut,
618 Fut: Future<Output = Vec<T>> + Send + 'static,
619 {
620 let count = Arc::new(AtomicUsize::new(0));
621 let (trigger, tripwire) = oneshot::channel();
622 let (trigger_connected, connected) = oneshot::channel();
623
624 CountReceiver {
625 count: Arc::clone(&count),
626 trigger: Some(trigger),
627 connected: Some(connected),
628 handle: tokio::spawn(make_fut(count, tripwire, trigger_connected)),
629 }
630 }
631
632 pub fn receive_items_stream<S, F, Fut>(make_stream: F) -> CountReceiver<T>
633 where
634 S: Stream<Item = T> + Send + 'static,
635 F: FnOnce(oneshot::Receiver<()>, oneshot::Sender<()>) -> Fut + Send + 'static,
636 Fut: Future<Output = S> + Send + 'static,
637 {
638 CountReceiver::new(|count, tripwire, connected| async move {
639 let stream = make_stream(tripwire, connected).await;
640 stream
641 .inspect(move |_| {
642 count.fetch_add(1, Ordering::Relaxed);
643 })
644 .collect::<Vec<T>>()
645 .await
646 })
647 }
648}
649
650impl<T> Future for CountReceiver<T> {
651 type Output = Vec<T>;
652
653 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
654 let this = self.get_mut();
655 if let Some(trigger) = this.trigger.take() {
656 _ = trigger.send(());
657 }
658
659 let result = ready!(this.handle.poll_unpin(cx));
660 Poll::Ready(result.unwrap())
661 }
662}
663
664impl CountReceiver<String> {
665 pub fn receive_lines(addr: SocketAddr) -> CountReceiver<String> {
666 CountReceiver::new(|count, tripwire, connected| async move {
667 let listener = TcpListener::bind(addr).await.unwrap();
668 CountReceiver::receive_lines_stream(
669 TcpListenerStream::new(listener),
670 count,
671 tripwire,
672 Some(connected),
673 )
674 .await
675 })
676 }
677
678 #[cfg(unix)]
679 pub fn receive_lines_unix<P>(path: P) -> CountReceiver<String>
680 where
681 P: AsRef<Path> + Send + 'static,
682 {
683 CountReceiver::new(|count, tripwire, connected| async move {
684 let listener = tokio::net::UnixListener::bind(path).unwrap();
685 CountReceiver::receive_lines_stream(
686 UnixListenerStream::new(listener),
687 count,
688 tripwire,
689 Some(connected),
690 )
691 .await
692 })
693 }
694
695 async fn receive_lines_stream<S, T>(
696 stream: S,
697 count: Arc<AtomicUsize>,
698 tripwire: oneshot::Receiver<()>,
699 mut connected: Option<oneshot::Sender<()>>,
700 ) -> Vec<String>
701 where
702 S: Stream<Item = IoResult<T>>,
703 T: AsyncWrite + AsyncRead,
704 {
705 stream
706 .take_until(tripwire)
707 .map_ok(|socket| FramedRead::new(socket, LinesCodec::new()))
708 .map(|x| {
709 connected.take().map(|trigger| trigger.send(()));
710 x.unwrap()
711 })
712 .flatten()
713 .map(|x| x.unwrap())
714 .inspect(move |_| {
715 count.fetch_add(1, Ordering::Relaxed);
716 })
717 .collect::<Vec<String>>()
718 .await
719 }
720}
721
722impl CountReceiver<Event> {
723 pub fn receive_events<S>(stream: S) -> CountReceiver<Event>
724 where
725 S: Stream<Item = Event> + Send + 'static,
726 {
727 CountReceiver::new(|count, tripwire, connected| async move {
728 connected.send(()).unwrap();
729 stream
730 .take_until(tripwire)
731 .inspect(move |_| {
732 count.fetch_add(1, Ordering::Relaxed);
733 })
734 .collect::<Vec<Event>>()
735 .await
736 })
737 }
738}
739
740pub async fn start_topology(
741 mut config: Config,
742 require_healthy: impl Into<Option<bool>>,
743) -> (RunningTopology, ShutdownErrorReceiver) {
744 config.healthchecks.set_require_healthy(require_healthy);
745 RunningTopology::start_init_validated(config, Default::default())
746 .await
747 .unwrap()
748}
749
750pub async fn spawn_collect_n<F, S>(future: F, stream: S, n: usize) -> Vec<Event>
756where
757 F: Future<Output = ()> + Send + 'static,
758 S: Stream<Item = Event>,
759{
760 let sender = tokio::spawn(future);
765 let events = collect_n(stream, n).await;
766 sender.await.expect("Failed to send data");
767 events
768}
769
770pub async fn spawn_collect_ready<F, S>(future: F, stream: S, sleep: u64) -> Vec<Event>
776where
777 F: Future<Output = ()> + Send + 'static,
778 S: Stream<Item = Event> + Unpin,
779{
780 let sender = tokio::spawn(future);
781 tokio::time::sleep(Duration::from_secs(sleep)).await;
782 let events = collect_ready(stream).await;
783 sender.await.expect("Failed to send data");
784 events
785}
786
#[cfg(test)]
mod tests {
    use std::{
        sync::{Arc, RwLock},
        time::Duration,
    };

    use super::retry_until;

    /// Returns `Err(())` and increments `count` while it is below 3; afterwards
    /// returns `Ok(())`.
    async fn retry_until_helper(count: Arc<RwLock<i32>>) -> Result<(), ()> {
        // The read guard is a temporary of the `if` condition, so it is released
        // before the write lock below is acquired.
        if *count.read().unwrap() < 3 {
            let mut c = count.write().unwrap();
            *c += 1;
            return Err(());
        }
        Ok(())
    }

    #[tokio::test]
    async fn retry_until_before_timeout() {
        let count = Arc::new(RwLock::new(0));
        let func = || {
            let count = Arc::clone(&count);
            retry_until_helper(count)
        };

        retry_until(func, Duration::from_millis(10), Duration::from_secs(1)).await;

        // Three failed attempts bump the counter to 3, then the fourth succeeds.
        assert_eq!(*count.read().unwrap(), 3);
    }

    #[tokio::test]
    #[should_panic(expected = "Timeout")]
    async fn retry_until_panics_after_timeout() {
        retry_until(
            || async { Err::<(), ()>(()) },
            Duration::from_millis(10),
            Duration::from_millis(50),
        )
        .await;
    }
}