vector/sources/util/net/tcp/
request_limiter.rs

1use std::{
2    cmp::Ordering,
3    sync::{Arc, Mutex},
4};
5
6use tokio::sync::{OwnedSemaphorePermit, Semaphore};
7
8use crate::stats::EwmaDefault;
9
/// Weight handed to the `EwmaDefault` that tracks the average request size.
// NOTE(review): presumably the smoothing factor applied to each new sample —
// confirm against `crate::stats::EwmaDefault`.
const EWMA_WEIGHT: f64 = 0.1;

/// Number of permits the semaphore starts with, and the size below which the
/// permit pool is never shrunk.
const MINIMUM_PERMITS: usize = 2;
12
13pub struct RequestLimiterPermit {
14    semaphore_permit: Option<OwnedSemaphorePermit>,
15    request_limiter_data: Arc<Mutex<RequestLimiterData>>,
16}
17
18impl RequestLimiterPermit {
19    pub fn decoding_finished(&self, num_events: usize) {
20        let mut request_limiter_data = self.request_limiter_data.lock().unwrap();
21        request_limiter_data.update_average(num_events);
22    }
23}
24
25impl Drop for RequestLimiterPermit {
26    fn drop(&mut self) {
27        if let Ok(mut request_limiter_data) = self.request_limiter_data.lock() {
28            let target = request_limiter_data.target_requests_in_flight();
29            let current = request_limiter_data.total_permits;
30
31            match target.cmp(&current) {
32                Ordering::Greater => request_limiter_data.increase_permits(),
33                Ordering::Equal => {
34                    // only release the current permit (when the inner permit is dropped automatically)
35                }
36                Ordering::Less => {
37                    let permit = self.semaphore_permit.take().unwrap();
38                    request_limiter_data.decrease_permits(permit);
39                }
40            }
41        }
42    }
43}
44
45struct RequestLimiterData {
46    event_limit_target: usize,
47    total_permits: usize,
48    average_request_size: EwmaDefault,
49    semaphore: Arc<Semaphore>,
50    max_requests: usize,
51}
52
53impl RequestLimiterData {
54    pub fn update_average(&mut self, num_events: usize) {
55        if num_events > 0 {
56            self.average_request_size.update(num_events as f64);
57        }
58    }
59
60    pub fn target_requests_in_flight(&self) -> usize {
61        let target = (self.event_limit_target as f64) / self.average_request_size.average();
62        #[allow(clippy::manual_clamp)]
63        (target as usize)
64            .max(MINIMUM_PERMITS)
65            .min(self.max_requests)
66    }
67
68    pub fn increase_permits(&mut self) {
69        self.total_permits += 1;
70        self.semaphore.add_permits(1);
71    }
72
73    pub fn decrease_permits(&mut self, permit: OwnedSemaphorePermit) {
74        if self.total_permits > MINIMUM_PERMITS {
75            permit.forget();
76            self.total_permits -= 1;
77        }
78    }
79}
80
81#[derive(Clone)]
82pub struct RequestLimiter {
83    semaphore: Arc<Semaphore>,
84    data: Arc<Mutex<RequestLimiterData>>,
85}
86
87impl RequestLimiter {
88    /// event_limit_target: The limit to the number of events that will be in-flight at one time.
89    /// max_requests: The most number of requests that can be processed concurrently
90    /// The numbers of events in a request is not known until after it has been decoded, so this is not a hard limit.
91    pub fn new(event_limit_target: usize, max_requests: usize) -> RequestLimiter {
92        assert!(event_limit_target > 0);
93
94        let semaphore = Arc::new(Semaphore::new(MINIMUM_PERMITS));
95        RequestLimiter {
96            semaphore: Arc::clone(&semaphore),
97            data: Arc::new(Mutex::new(RequestLimiterData {
98                event_limit_target,
99                total_permits: MINIMUM_PERMITS,
100                average_request_size: EwmaDefault::new(EWMA_WEIGHT, event_limit_target as f64),
101                semaphore,
102                max_requests,
103            })),
104        }
105    }
106
107    pub async fn acquire(&self) -> RequestLimiterPermit {
108        let permit = Arc::clone(&self.semaphore).acquire_owned().await;
109        RequestLimiterPermit {
110            semaphore_permit: permit.ok(),
111            request_limiter_data: Arc::clone(&self.data),
112        }
113    }
114}
115
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Runs 100 sequential requests through `limiter`, each reporting
    /// `events_per_request` decoded events, then returns the limiter's target
    /// number of requests in flight.
    async fn target_after_requests(limiter: &RequestLimiter, events_per_request: usize) -> usize {
        for _ in 0..100 {
            let permit = limiter.acquire().await;
            permit.decoding_finished(events_per_request);
            drop(permit);
        }
        let data = limiter.data.lock().unwrap();
        data.target_requests_in_flight()
    }

    /// With 5 events per request and a target of 100 events, the limiter
    /// should settle near 100 / 5 concurrent requests.
    #[tokio::test]
    async fn test_average_convergence() {
        let limiter = RequestLimiter::new(100, 100);

        let target = target_after_requests(&limiter, 5).await;
        assert_abs_diff_eq!(target, 100 / 5, epsilon = 1);
    }

    /// Requests far larger than the event target must not push the target
    /// below `MINIMUM_PERMITS`.
    #[tokio::test]
    async fn test_minimum_permits() {
        let limiter = RequestLimiter::new(100, 100);

        let target = target_after_requests(&limiter, 500).await;
        assert_eq!(target, MINIMUM_PERMITS);
    }

    /// Tiny requests must not push the target above `max_requests`.
    #[tokio::test]
    async fn test_maximum_permits() {
        let request_limit = 50;
        let limiter = RequestLimiter::new(1000, request_limit);

        let target = target_after_requests(&limiter, 1).await;
        assert_eq!(target, request_limit);
    }
}