vector/sources/util/net/tcp/
request_limiter.rs

1use std::{
2    cmp::Ordering,
3    sync::{Arc, Mutex},
4};
5
6use tokio::sync::{OwnedSemaphorePermit, Semaphore};
7use vector_lib::stats::EwmaDefault;
8
/// Weight handed to `EwmaDefault::new` for the running average of events per request.
const EWMA_WEIGHT: f64 = 0.1;
/// Lower bound on the permit pool: the semaphore starts with this many permits and
/// `decrease_permits` never shrinks the pool below it.
const MINIMUM_PERMITS: usize = 2;
11
12pub struct RequestLimiterPermit {
13    semaphore_permit: Option<OwnedSemaphorePermit>,
14    request_limiter_data: Arc<Mutex<RequestLimiterData>>,
15}
16
17impl RequestLimiterPermit {
18    pub fn decoding_finished(&self, num_events: usize) {
19        let mut request_limiter_data = self.request_limiter_data.lock().unwrap();
20        request_limiter_data.update_average(num_events);
21    }
22}
23
24impl Drop for RequestLimiterPermit {
25    fn drop(&mut self) {
26        if let Ok(mut request_limiter_data) = self.request_limiter_data.lock() {
27            let target = request_limiter_data.target_requests_in_flight();
28            let current = request_limiter_data.total_permits;
29
30            match target.cmp(&current) {
31                Ordering::Greater => request_limiter_data.increase_permits(),
32                Ordering::Equal => {
33                    // only release the current permit (when the inner permit is dropped automatically)
34                }
35                Ordering::Less => {
36                    let permit = self.semaphore_permit.take().unwrap();
37                    request_limiter_data.decrease_permits(permit);
38                }
39            }
40        }
41    }
42}
43
44struct RequestLimiterData {
45    event_limit_target: usize,
46    total_permits: usize,
47    average_request_size: EwmaDefault,
48    semaphore: Arc<Semaphore>,
49    max_requests: usize,
50}
51
52impl RequestLimiterData {
53    pub fn update_average(&mut self, num_events: usize) {
54        if num_events > 0 {
55            self.average_request_size.update(num_events as f64);
56        }
57    }
58
59    pub fn target_requests_in_flight(&self) -> usize {
60        let target = (self.event_limit_target as f64) / self.average_request_size.average();
61        #[allow(clippy::manual_clamp)]
62        (target as usize)
63            .max(MINIMUM_PERMITS)
64            .min(self.max_requests)
65    }
66
67    pub fn increase_permits(&mut self) {
68        self.total_permits += 1;
69        self.semaphore.add_permits(1);
70    }
71
72    pub fn decrease_permits(&mut self, permit: OwnedSemaphorePermit) {
73        if self.total_permits > MINIMUM_PERMITS {
74            permit.forget();
75            self.total_permits -= 1;
76        }
77    }
78}
79
80#[derive(Clone)]
81pub struct RequestLimiter {
82    semaphore: Arc<Semaphore>,
83    data: Arc<Mutex<RequestLimiterData>>,
84}
85
86impl RequestLimiter {
87    /// event_limit_target: The limit to the number of events that will be in-flight at one time.
88    /// max_requests: The most number of requests that can be processed concurrently
89    /// The numbers of events in a request is not known until after it has been decoded, so this is not a hard limit.
90    pub fn new(event_limit_target: usize, max_requests: usize) -> RequestLimiter {
91        assert!(event_limit_target > 0);
92
93        let semaphore = Arc::new(Semaphore::new(MINIMUM_PERMITS));
94        RequestLimiter {
95            semaphore: Arc::clone(&semaphore),
96            data: Arc::new(Mutex::new(RequestLimiterData {
97                event_limit_target,
98                total_permits: MINIMUM_PERMITS,
99                average_request_size: EwmaDefault::new(EWMA_WEIGHT, event_limit_target as f64),
100                semaphore,
101                max_requests,
102            })),
103        }
104    }
105
106    pub async fn acquire(&self) -> RequestLimiterPermit {
107        let permit = Arc::clone(&self.semaphore).acquire_owned().await;
108        RequestLimiterPermit {
109            semaphore_permit: permit.ok(),
110            request_limiter_data: Arc::clone(&self.data),
111        }
112    }
113}
114
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Runs `rounds` acquire / report / release cycles against `limiter`, reporting
    /// `events` decoded events each time, and returns the resulting target.
    async fn settle(limiter: &RequestLimiter, events: usize, rounds: usize) -> usize {
        for _ in 0..rounds {
            // The temporary permit is dropped at the end of this statement,
            // right after `decoding_finished` returns.
            limiter.acquire().await.decoding_finished(events);
        }
        let data = limiter.data.lock().unwrap();
        data.target_requests_in_flight()
    }

    #[tokio::test]
    async fn test_average_convergence() {
        let limiter = RequestLimiter::new(100, 100);
        let target = settle(&limiter, 5, 100).await;
        assert_abs_diff_eq!(target, 100 / 5, epsilon = 1);
    }

    #[tokio::test]
    async fn test_minimum_permits() {
        let limiter = RequestLimiter::new(100, 100);
        let target = settle(&limiter, 500, 100).await;
        assert_eq!(target, MINIMUM_PERMITS);
    }

    #[tokio::test]
    async fn test_maximum_permits() {
        let request_limit = 50;
        let limiter = RequestLimiter::new(1000, request_limit);
        let target = settle(&limiter, 1, 100).await;
        assert_eq!(target, request_limit);
    }
}