1use std::{
2 borrow::Cow,
3 cmp,
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10use futures::FutureExt;
11use tokio::time::{sleep, Sleep};
12use tower::{retry::Policy, timeout::error::Elapsed};
13use vector_lib::configurable::configurable_component;
14
15use crate::Error;
16
/// What a retry policy should do with a successful (`Ok`) response, as decided by
/// [`RetryLogic::should_retry_response`].
///
/// `Request` is the request type; it is only used by [`RetryAction::RetryPartial`].
pub enum RetryAction<Request = ()> {
    /// Retry the same request. The string is the reason, included in the retry log message.
    Retry(Cow<'static, str>),

    /// Retry with a rebuilt request: the closure receives a clone of the current request and
    /// returns the request to send on the next attempt.
    RetryPartial(Box<dyn Fn(Request) -> Request + Send + Sync>),

    /// Do not retry; the request is dropped. The string is the reason, included in the log
    /// message.
    DontRetry(Cow<'static, str>),

    /// The response succeeded; nothing needs to be retried.
    Successful,
}
27
/// Service-specific rules that tell [`FibonacciRetryPolicy`] which failures and responses
/// should be retried.
pub trait RetryLogic: Clone + Send + Sync + 'static {
    /// The error type this logic classifies. [`FibonacciRetryPolicy`] only passes errors that
    /// downcast to this type to [`RetryLogic::is_retriable_error`].
    type Error: std::error::Error + Send + Sync + 'static;
    /// The request type handled by the wrapped service.
    type Request;
    /// The response type inspected by [`RetryLogic::should_retry_response`].
    type Response;

    /// Returns `true` if `error` should be retried.
    fn is_retriable_error(&self, error: &Self::Error) -> bool;

    /// Inspects a successful (`Ok`) response and decides whether it should still be retried.
    ///
    /// The default treats every `Ok` response as [`RetryAction::Successful`].
    fn should_retry_response(&self, _response: &Self::Response) -> RetryAction<Self::Request> {
        RetryAction::Successful
    }

    /// Hook called by [`FibonacciRetryPolicy`] with an error it has decided to retry, before
    /// the retry is scheduled. Does nothing by default.
    fn on_retriable_error(&self, _error: &Self::Error) {}
}
51
/// The jitter mode to use for retry backoff behavior.
#[configurable_component]
#[derive(Clone, Copy, Debug, Default)]
pub enum JitterMode {
    /// No jitter.
    ///
    /// Each retry waits exactly the computed Fibonacci backoff delay.
    None,

    /// Full jitter.
    ///
    /// Each retry waits a random duration no longer than the computed Fibonacci backoff delay.
    #[default]
    Full,
}
70
/// A `tower` retry policy whose delays grow along a Fibonacci sequence, optionally jittered,
/// with each newly computed delay clamped to a maximum.
#[derive(Debug, Clone)]
pub struct FibonacciRetryPolicy<L> {
    /// Retries still allowed; once this is zero no further retry is scheduled.
    remaining_attempts: usize,
    /// The preceding Fibonacci term; added to `current_duration` to produce the next delay.
    previous_duration: Duration,
    /// The current un-jittered backoff delay.
    current_duration: Duration,
    /// Chooses whether `current_duration` or `current_jitter_duration` is the effective delay.
    jitter_mode: JitterMode,
    /// A randomized copy of `current_duration`, re-rolled every time that value changes.
    current_jitter_duration: Duration,
    /// Upper bound applied to each newly computed delay (the initial backoff is not clamped).
    max_duration: Duration,
    /// Service-specific classification of errors and responses.
    logic: L,
}
81
/// Future returned by [`FibonacciRetryPolicy`] that resolves once the backoff delay elapses.
pub struct RetryPolicyFuture {
    /// The backoff timer. `Pin<Box<Sleep>>` is itself `Unpin`, so it can be polled with
    /// `poll_unpin`.
    delay: Pin<Box<Sleep>>,
}
85
86impl<L: RetryLogic> FibonacciRetryPolicy<L> {
87 pub fn new(
88 remaining_attempts: usize,
89 initial_backoff: Duration,
90 max_duration: Duration,
91 logic: L,
92 jitter_mode: JitterMode,
93 ) -> Self {
94 FibonacciRetryPolicy {
95 remaining_attempts,
96 previous_duration: Duration::from_secs(0),
97 current_duration: initial_backoff,
98 jitter_mode,
99 current_jitter_duration: Self::add_full_jitter(initial_backoff),
100 max_duration,
101 logic,
102 }
103 }
104
105 fn add_full_jitter(d: Duration) -> Duration {
106 let jitter = (rand::random::<u64>() % (d.as_millis() as u64)) + 1;
107 Duration::from_millis(jitter)
108 }
109
110 const fn backoff(&self) -> Duration {
111 match self.jitter_mode {
112 JitterMode::None => self.current_duration,
113 JitterMode::Full => self.current_jitter_duration,
114 }
115 }
116
117 fn advance(&mut self) {
118 let sum = self
119 .previous_duration
120 .checked_add(self.current_duration)
121 .unwrap_or(Duration::MAX);
122 let next_duration = cmp::min(sum, self.max_duration);
123 self.remaining_attempts = self.remaining_attempts.saturating_sub(1);
124 self.previous_duration = self.current_duration;
125 self.current_duration = next_duration;
126 self.current_jitter_duration = Self::add_full_jitter(next_duration);
127 }
128
129 fn build_retry(&mut self) -> RetryPolicyFuture {
130 self.advance();
131 let delay = Box::pin(sleep(self.backoff()));
132
133 debug!(message = "Retrying request.", delay_ms = %self.backoff().as_millis());
134 RetryPolicyFuture { delay }
135 }
136}
137
138impl<Req, Res, L> Policy<Req, Res, Error> for FibonacciRetryPolicy<L>
139where
140 Req: Clone + Send + 'static,
141 L: RetryLogic<Request = Req, Response = Res>,
142{
143 type Future = RetryPolicyFuture;
144
145 fn retry(&mut self, req: &mut Req, result: &mut Result<Res, Error>) -> Option<Self::Future> {
148 match result {
149 Ok(response) => match self.logic.should_retry_response(response) {
150 RetryAction::Retry(reason) => {
151 if self.remaining_attempts == 0 {
152 error!(
153 message = "OK/retry response but retries exhausted; dropping the request.",
154 reason = ?reason,
155 internal_log_rate_limit = true,
156 );
157 return None;
158 }
159
160 warn!(message = "Retrying after response.", reason = %reason, internal_log_rate_limit = true);
161 Some(self.build_retry())
162 }
163 RetryAction::RetryPartial(modify_request) => {
164 if self.remaining_attempts == 0 {
165 error!(
166 message =
167 "OK/retry response but retries exhausted; dropping the request.",
168 internal_log_rate_limit = true,
169 );
170 return None;
171 }
172 *req = modify_request(req.clone());
173 error!(
174 message = "OK/retrying partial after response.",
175 internal_log_rate_limit = true
176 );
177 Some(self.build_retry())
178 }
179 RetryAction::DontRetry(reason) => {
180 error!(message = "Not retriable; dropping the request.", reason = ?reason, internal_log_rate_limit = true);
181 None
182 }
183
184 RetryAction::Successful => None,
185 },
186 Err(error) => {
187 if self.remaining_attempts == 0 {
188 error!(message = "Retries exhausted; dropping the request.", %error, internal_log_rate_limit = true);
189 return None;
190 }
191
192 if let Some(expected) = error.downcast_ref::<L::Error>() {
193 if self.logic.is_retriable_error(expected) {
194 self.logic.on_retriable_error(expected);
195 warn!(message = "Retrying after error.", error = %expected, internal_log_rate_limit = true);
196 Some(self.build_retry())
197 } else {
198 error!(
199 message = "Non-retriable error; dropping the request.",
200 %error,
201 internal_log_rate_limit = true,
202 );
203 None
204 }
205 } else if error.downcast_ref::<Elapsed>().is_some() {
206 warn!(
207 message = "Request timed out. If this happens often while the events are actually reaching their destination, try decreasing `batch.max_bytes` and/or using `compression` if applicable. Alternatively `request.timeout_secs` can be increased.",
208 internal_log_rate_limit = true
209 );
210 Some(self.build_retry())
211 } else {
212 error!(
213 message = "Unexpected error type; dropping the request.",
214 %error,
215 internal_log_rate_limit = true
216 );
217 None
218 }
219 }
220 }
221 }
222
223 fn clone_request(&mut self, request: &Req) -> Option<Req> {
224 Some(request.clone())
225 }
226}
227
// Sound because the only field, `delay: Pin<Box<Sleep>>`, is itself `Unpin`. `poll` relies on
// this to reach `self.delay` mutably through `Pin<&mut Self>`.
impl Unpin for RetryPolicyFuture {}
231
232impl Future for RetryPolicyFuture {
233 type Output = ();
234
235 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
236 std::task::ready!(self.delay.poll_unpin(cx));
237 Poll::Ready(())
238 }
239}
240
241impl<Request> RetryAction<Request> {
242 pub const fn is_retryable(&self) -> bool {
243 matches!(self, RetryAction::Retry(_) | RetryAction::RetryPartial(_))
244 }
245
246 pub const fn is_not_retryable(&self) -> bool {
247 matches!(self, RetryAction::DontRetry(_))
248 }
249
250 pub const fn is_successful(&self) -> bool {
251 matches!(self, RetryAction::Successful)
252 }
253}
254
#[cfg(test)]
mod tests {
    use std::{fmt, time::Duration};

    use tokio::time;
    use tokio_test::{assert_pending, assert_ready_err, assert_ready_ok, task};
    use tower::retry::RetryLayer;
    use tower_test::{assert_request_eq, mock};

    use super::*;
    use crate::test_util::trace_init;

    /// Builds the policy every test uses: 1s initial backoff capped at 10s, driven by
    /// `SvcRetryLogic`.
    fn test_policy(
        attempts: usize,
        jitter_mode: JitterMode,
    ) -> FibonacciRetryPolicy<SvcRetryLogic> {
        FibonacciRetryPolicy::new(
            attempts,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            jitter_mode,
        )
    }

    #[tokio::test]
    async fn service_error_retry() {
        trace_init();
        time::pause();

        let (mut svc, mut handle) =
            mock::spawn_layer(RetryLayer::new(test_policy(5, JitterMode::None)));
        assert_ready_ok!(svc.poll_ready());

        let mut fut = task::spawn(svc.call("hello"));

        // A retriable error schedules a retry instead of failing the call.
        assert_request_eq!(handle, "hello").send_error(Error(true));
        assert_pending!(fut.poll());

        // Move past the first (1s) backoff so the retried request goes out.
        time::advance(Duration::from_secs(2)).await;
        assert_pending!(fut.poll());

        assert_request_eq!(handle, "hello").send_response("world");
        assert_eq!(fut.await.unwrap(), "world");
    }

    #[tokio::test]
    async fn service_error_no_retry() {
        trace_init();

        let (mut svc, mut handle) =
            mock::spawn_layer(RetryLayer::new(test_policy(5, JitterMode::None)));
        assert_ready_ok!(svc.poll_ready());

        let mut fut = task::spawn(svc.call("hello"));

        // A non-retriable error is handed straight back to the caller.
        assert_request_eq!(handle, "hello").send_error(Error(false));
        assert_ready_err!(fut.poll());
    }

    #[tokio::test]
    async fn timeout_error() {
        trace_init();
        time::pause();

        let (mut svc, mut handle) =
            mock::spawn_layer(RetryLayer::new(test_policy(5, JitterMode::None)));
        assert_ready_ok!(svc.poll_ready());

        let mut fut = task::spawn(svc.call("hello"));

        // Timeouts are retried even though they are not `SvcRetryLogic` errors.
        assert_request_eq!(handle, "hello").send_error(Elapsed::new());
        assert_pending!(fut.poll());

        time::advance(Duration::from_secs(2)).await;
        assert_pending!(fut.poll());

        assert_request_eq!(handle, "hello").send_response("world");
        assert_eq!(fut.await.unwrap(), "world");
    }

    #[test]
    fn backoff_grows_to_max() {
        let mut policy = test_policy(10, JitterMode::None);

        // Fibonacci terms seeded with 1s, then pinned at the 10s cap.
        let expected_secs = [1, 1, 2, 3, 5, 8, 10, 10];
        for (step, &secs) in expected_secs.iter().enumerate() {
            if step > 0 {
                policy.advance();
            }
            assert_eq!(Duration::from_secs(secs), policy.backoff());
        }
    }

    #[test]
    fn backoff_grows_to_max_with_jitter() {
        let max_duration = Duration::from_secs(10);
        let mut policy = test_policy(10, JitterMode::Full);

        // Below the cap, each jittered delay is bounded by its Fibonacci term.
        for (i, &exp_fib_secs) in [1, 1, 2, 3, 5, 8].iter().enumerate() {
            let backoff = policy.backoff();
            let upper_bound = Duration::from_secs(exp_fib_secs);

            assert!(
                !backoff.is_zero() && backoff <= upper_bound,
                "Attempt {}: Expected backoff to be within 0 and {:?}, got {:?}",
                i + 1,
                upper_bound,
                backoff
            );

            policy.advance();
        }

        // Once capped, every jittered delay stays within the maximum.
        for _ in 0..4 {
            let backoff = policy.backoff();
            assert!(
                !backoff.is_zero() && backoff <= max_duration,
                "Expected backoff to not exceed {max_duration:?}, got {backoff:?}"
            );

            policy.advance();
        }
    }

    /// Retry logic whose errors carry their own retriability flag.
    #[derive(Debug, Clone)]
    struct SvcRetryLogic;

    impl RetryLogic for SvcRetryLogic {
        type Error = Error;
        type Request = &'static str;
        type Response = &'static str;

        fn is_retriable_error(&self, error: &Self::Error) -> bool {
            error.0
        }
    }

    /// Test error; the flag says whether `SvcRetryLogic` should retry it.
    #[derive(Debug)]
    struct Error(bool);

    impl fmt::Display for Error {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "error")
        }
    }

    impl std::error::Error for Error {}
}