vector/sinks/util/retries.rs

use std::{
    borrow::Cow,
    cmp,
    future::Future,
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};

use futures::FutureExt;
use tokio::time::{sleep, Sleep};
use tower::{retry::Policy, timeout::error::Elapsed};
use vector_lib::configurable::configurable_component;

use crate::Error;

pub enum RetryAction<Request = ()> {
    /// Indicate that this request should be retried with a reason
    Retry(Cow<'static, str>),
    /// Indicate that a portion of this request should be retried with a generic function
    RetryPartial(Box<dyn Fn(Request) -> Request + Send + Sync>),
    /// Indicate that this request should not be retried with a reason
    DontRetry(Cow<'static, str>),
    /// Indicate that this request should not be retried but the request was successful
    Successful,
}

pub trait RetryLogic: Clone + Send + Sync + 'static {
    type Error: std::error::Error + Send + Sync + 'static;
    type Request;
    type Response;

    /// When the Service call returns an `Err` response, this function allows
    /// implementors to specify what kinds of errors can be retried.
    fn is_retriable_error(&self, error: &Self::Error) -> bool;

    /// When the Service call returns an `Ok` response, this function allows
    /// implementors to specify additional logic to determine if the success response
    /// is actually an error. This is particularly useful when the downstream service
    /// of a sink returns a transport protocol layer success but error data in the
    /// response body. For example, an HTTP 200 status, but the body of the response
    /// contains a list of errors encountered while processing.
    fn should_retry_response(&self, _response: &Self::Response) -> RetryAction<Self::Request> {
        // By default, treat the response as successful.
        RetryAction::Successful
    }

    /// Optional hook run when an error is determined to be retriable.
    fn on_retriable_error(&self, _error: &Self::Error) {}
}
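// Illustrative sketch (not part of the original module): one way a sink's
// `RetryLogic` could classify an HTTP-style response whose transport layer
// succeeded but whose body reports errors, as described in the doc comment
// above. `ExampleRetryLogic`, `ExampleResponse`, and their fields are
// hypothetical names used only for this example.
#[allow(dead_code)]
#[derive(Clone)]
struct ExampleRetryLogic;

#[allow(dead_code)]
struct ExampleResponse {
    status: u16,
    body_errors: Vec<String>,
}

impl RetryLogic for ExampleRetryLogic {
    type Error = std::io::Error;
    type Request = ();
    type Response = ExampleResponse;

    fn is_retriable_error(&self, error: &Self::Error) -> bool {
        // Retry transient transport failures such as timeouts.
        error.kind() == std::io::ErrorKind::TimedOut
    }

    fn should_retry_response(&self, response: &Self::Response) -> RetryAction<Self::Request> {
        match response.status {
            // Server-side errors are worth retrying.
            500..=599 => RetryAction::Retry("server error".into()),
            // Transport success, but the body reports per-event failures.
            200..=299 if !response.body_errors.is_empty() => {
                RetryAction::Retry("errors reported in response body".into())
            }
            200..=299 => RetryAction::Successful,
            // Anything else (for example a 4xx) is treated as non-retriable here.
            _ => RetryAction::DontRetry("unexpected status".into()),
        }
    }
}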

/// The jitter mode to use for retry backoff behavior.
#[configurable_component]
#[derive(Clone, Copy, Debug, Default)]
pub enum JitterMode {
    /// No jitter.
    None,

    /// Full jitter.
    ///
    /// The random delay is anywhere from 0 up to the maximum current delay calculated by the backoff
    /// strategy.
    ///
    /// Incorporating full jitter into your backoff strategy can greatly reduce the likelihood
    /// of creating accidental denial of service (DoS) conditions against your own systems when
    /// many clients are recovering from a failure state.
    #[default]
    Full,
}

#[derive(Debug, Clone)]
pub struct FibonacciRetryPolicy<L> {
    remaining_attempts: usize,
    previous_duration: Duration,
    current_duration: Duration,
    jitter_mode: JitterMode,
    current_jitter_duration: Duration,
    max_duration: Duration,
    logic: L,
}

pub struct RetryPolicyFuture {
    delay: Pin<Box<Sleep>>,
}

impl<L: RetryLogic> FibonacciRetryPolicy<L> {
    pub fn new(
        remaining_attempts: usize,
        initial_backoff: Duration,
        max_duration: Duration,
        logic: L,
        jitter_mode: JitterMode,
    ) -> Self {
        FibonacciRetryPolicy {
            remaining_attempts,
            previous_duration: Duration::from_secs(0),
            current_duration: initial_backoff,
            jitter_mode,
            current_jitter_duration: Self::add_full_jitter(initial_backoff),
            max_duration,
            logic,
        }
    }

    /// Picks a random delay in the range `1..=d` milliseconds (assumes `d` is
    /// at least one millisecond, otherwise the modulo below is by zero).
    fn add_full_jitter(d: Duration) -> Duration {
        let jitter = (rand::random::<u64>() % (d.as_millis() as u64)) + 1;
        Duration::from_millis(jitter)
    }

    const fn backoff(&self) -> Duration {
        match self.jitter_mode {
            JitterMode::None => self.current_duration,
            JitterMode::Full => self.current_jitter_duration,
        }
    }

    fn advance(&mut self) {
        let sum = self
            .previous_duration
            .checked_add(self.current_duration)
            .unwrap_or(Duration::MAX);
        let next_duration = cmp::min(sum, self.max_duration);
        self.remaining_attempts = self.remaining_attempts.saturating_sub(1);
        self.previous_duration = self.current_duration;
        self.current_duration = next_duration;
        self.current_jitter_duration = Self::add_full_jitter(next_duration);
    }

    fn build_retry(&mut self) -> RetryPolicyFuture {
        self.advance();
        let delay = Box::pin(sleep(self.backoff()));

        debug!(message = "Retrying request.", delay_ms = %self.backoff().as_millis());
        RetryPolicyFuture { delay }
    }
}
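// Illustrative sketch (not part of the original module): the sequence of delays
// produced by `backoff`/`advance` for a policy with a one-second initial backoff
// and a ten-second cap, with jitter disabled so the values are deterministic.
// The function name is hypothetical; it works with any `RetryLogic` the caller
// supplies.
#[allow(dead_code)]
fn example_backoff_progression<L: RetryLogic>(logic: L) -> Vec<Duration> {
    let mut policy = FibonacciRetryPolicy::new(
        8,
        Duration::from_secs(1),
        Duration::from_secs(10),
        logic,
        JitterMode::None,
    );

    // Yields 1, 1, 2, 3, 5, 8, 10, 10 seconds: Fibonacci growth capped at the
    // configured maximum.
    (0..8)
        .map(|_| {
            let delay = policy.backoff();
            policy.advance();
            delay
        })
        .collect()
}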

impl<Req, Res, L> Policy<Req, Res, Error> for FibonacciRetryPolicy<L>
where
    Req: Clone + Send + 'static,
    L: RetryLogic<Request = Req, Response = Res>,
{
    type Future = RetryPolicyFuture;

    // NOTE: in the error cases, the `Error` and `EventsDropped` internal events are emitted by
    // the driver, so we only need to log here.
    fn retry(&mut self, req: &mut Req, result: &mut Result<Res, Error>) -> Option<Self::Future> {
        match result {
            Ok(response) => match self.logic.should_retry_response(response) {
                RetryAction::Retry(reason) => {
                    if self.remaining_attempts == 0 {
                        error!(
                            message = "OK/retry response but retries exhausted; dropping the request.",
                            reason = ?reason,
                            internal_log_rate_limit = true,
                        );
                        return None;
                    }

                    warn!(message = "Retrying after response.", reason = %reason, internal_log_rate_limit = true);
                    Some(self.build_retry())
                }
                RetryAction::RetryPartial(modify_request) => {
                    if self.remaining_attempts == 0 {
                        error!(
                            message =
                                "OK/retry response but retries exhausted; dropping the request.",
                            internal_log_rate_limit = true,
                        );
                        return None;
                    }
                    *req = modify_request(req.clone());
                    error!(
                        message = "OK/retrying partial after response.",
                        internal_log_rate_limit = true
                    );
                    Some(self.build_retry())
                }
                RetryAction::DontRetry(reason) => {
                    error!(message = "Not retriable; dropping the request.", reason = ?reason, internal_log_rate_limit = true);
                    None
                }

                RetryAction::Successful => None,
            },
            Err(error) => {
                if self.remaining_attempts == 0 {
                    error!(message = "Retries exhausted; dropping the request.", %error, internal_log_rate_limit = true);
                    return None;
                }

                if let Some(expected) = error.downcast_ref::<L::Error>() {
                    if self.logic.is_retriable_error(expected) {
                        self.logic.on_retriable_error(expected);
                        warn!(message = "Retrying after error.", error = %expected, internal_log_rate_limit = true);
                        Some(self.build_retry())
                    } else {
                        error!(
                            message = "Non-retriable error; dropping the request.",
                            %error,
                            internal_log_rate_limit = true,
                        );
                        None
                    }
                } else if error.downcast_ref::<Elapsed>().is_some() {
                    warn!(
                        message = "Request timed out. If this happens often while the events are actually reaching their destination, try decreasing `batch.max_bytes` and/or using `compression` if applicable. Alternatively `request.timeout_secs` can be increased.",
                        internal_log_rate_limit = true
                    );
                    Some(self.build_retry())
                } else {
                    error!(
                        message = "Unexpected error type; dropping the request.",
                        %error,
                        internal_log_rate_limit = true
                    );
                    None
                }
            }
        }
    }

    fn clone_request(&mut self, request: &Req) -> Option<Req> {
        Some(request.clone())
    }
}
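// Illustrative sketch (not part of the original module): a policy like this is
// typically installed in front of a sink's inner service through `tower`'s retry
// middleware, as the tests below do with `RetryLayer`. The function name, the
// attempt count, and the backoff bounds here are placeholder values chosen for
// the example.
#[allow(dead_code)]
fn example_with_retries<S, L>(
    inner: S,
    logic: L,
) -> tower::retry::Retry<FibonacciRetryPolicy<L>, S>
where
    L: RetryLogic,
{
    use tower::Layer;

    let policy = FibonacciRetryPolicy::new(
        5,
        Duration::from_secs(1),
        Duration::from_secs(30),
        logic,
        JitterMode::Full,
    );

    // `RetryLayer` re-issues failed calls according to `Policy::retry` above.
    tower::retry::RetryLayer::new(policy).layer(inner)
}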

// Safety: the `Sleep` is held behind `Pin<Box<_>>` and we use no unsafe pin
// projections, therefore this is safe.
impl Unpin for RetryPolicyFuture {}

impl Future for RetryPolicyFuture {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        std::task::ready!(self.delay.poll_unpin(cx));
        Poll::Ready(())
    }
}

impl<Request> RetryAction<Request> {
    pub const fn is_retryable(&self) -> bool {
        matches!(self, RetryAction::Retry(_) | RetryAction::RetryPartial(_))
    }

    pub const fn is_not_retryable(&self) -> bool {
        matches!(self, RetryAction::DontRetry(_))
    }

    pub const fn is_successful(&self) -> bool {
        matches!(self, RetryAction::Successful)
    }
}

#[cfg(test)]
mod tests {
    use std::{fmt, time::Duration};

    use tokio::time;
    use tokio_test::{assert_pending, assert_ready_err, assert_ready_ok, task};
    use tower::retry::RetryLayer;
    use tower_test::{assert_request_eq, mock};

    use super::*;
    use crate::test_util::trace_init;

    #[tokio::test]
    async fn service_error_retry() {
        trace_init();

        time::pause();

        let policy = FibonacciRetryPolicy::new(
            5,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );

        let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy));

        assert_ready_ok!(svc.poll_ready());

        let fut = svc.call("hello");
        let mut fut = task::spawn(fut);

        assert_request_eq!(handle, "hello").send_error(Error(true));

        assert_pending!(fut.poll());

        time::advance(Duration::from_secs(2)).await;
        assert_pending!(fut.poll());

        assert_request_eq!(handle, "hello").send_response("world");
        assert_eq!(fut.await.unwrap(), "world");
    }

    #[tokio::test]
    async fn service_error_no_retry() {
        trace_init();

        let policy = FibonacciRetryPolicy::new(
            5,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );

        let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy));

        assert_ready_ok!(svc.poll_ready());

        let mut fut = task::spawn(svc.call("hello"));
        assert_request_eq!(handle, "hello").send_error(Error(false));
        assert_ready_err!(fut.poll());
    }

    #[tokio::test]
    async fn timeout_error() {
        trace_init();

        time::pause();

        let policy = FibonacciRetryPolicy::new(
            5,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );

        let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy));

        assert_ready_ok!(svc.poll_ready());

        let mut fut = task::spawn(svc.call("hello"));
        assert_request_eq!(handle, "hello").send_error(Elapsed::new());
        assert_pending!(fut.poll());

        time::advance(Duration::from_secs(2)).await;
        assert_pending!(fut.poll());

        assert_request_eq!(handle, "hello").send_response("world");
        assert_eq!(fut.await.unwrap(), "world");
    }

    #[test]
    fn backoff_grows_to_max() {
        let mut policy = FibonacciRetryPolicy::new(
            10,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );
        assert_eq!(Duration::from_secs(1), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(1), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(2), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(3), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(5), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(8), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(10), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(10), policy.backoff());
    }

    #[test]
    fn backoff_grows_to_max_with_jitter() {
        let max_duration = Duration::from_secs(10);
        let mut policy = FibonacciRetryPolicy::new(
            10,
            Duration::from_secs(1),
            max_duration,
            SvcRetryLogic,
            JitterMode::Full,
        );

        let expected_fib = [1, 1, 2, 3, 5, 8];

        for (i, &exp_fib_secs) in expected_fib.iter().enumerate() {
            let backoff = policy.backoff();
            let upper_bound = Duration::from_secs(exp_fib_secs);

            // Check if the backoff is within the expected range, considering the jitter
            assert!(
                !backoff.is_zero() && backoff <= upper_bound,
                "Attempt {}: Expected backoff to be within 0 and {:?}, got {:?}",
                i + 1,
                upper_bound,
                backoff
            );

            policy.advance();
        }

        // Once the max backoff is reached, it should not exceed the max backoff.
        for _ in 0..4 {
            let backoff = policy.backoff();
            assert!(
                !backoff.is_zero() && backoff <= max_duration,
                "Expected backoff to not exceed {max_duration:?}, got {backoff:?}"
            );

            policy.advance();
        }
    }

    #[derive(Debug, Clone)]
    struct SvcRetryLogic;

    impl RetryLogic for SvcRetryLogic {
        type Error = Error;
        type Request = &'static str;
        type Response = &'static str;

        fn is_retriable_error(&self, error: &Self::Error) -> bool {
            error.0
        }
    }

    #[derive(Debug)]
    struct Error(bool);

    impl fmt::Display for Error {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "error")
        }
    }

    impl std::error::Error for Error {}
}