vector/sources/util/net/tcp/
request_limiter.rs

1use std::cmp::Ordering;
2use std::sync::{Arc, Mutex};
3
4use tokio::sync::{OwnedSemaphorePermit, Semaphore};
5
6use crate::stats::EwmaDefault;
7
/// Weight handed to the EWMA that tracks the average number of events per request.
const EWMA_WEIGHT: f64 = 0.1;
/// Number of permits the limiter starts with; the permit count is never shrunk below this.
const MINIMUM_PERMITS: usize = 2;
10
11pub struct RequestLimiterPermit {
12    semaphore_permit: Option<OwnedSemaphorePermit>,
13    request_limiter_data: Arc<Mutex<RequestLimiterData>>,
14}
15
16impl RequestLimiterPermit {
17    pub fn decoding_finished(&self, num_events: usize) {
18        let mut request_limiter_data = self.request_limiter_data.lock().unwrap();
19        request_limiter_data.update_average(num_events);
20    }
21}
22
23impl Drop for RequestLimiterPermit {
24    fn drop(&mut self) {
25        if let Ok(mut request_limiter_data) = self.request_limiter_data.lock() {
26            let target = request_limiter_data.target_requests_in_flight();
27            let current = request_limiter_data.total_permits;
28
29            match target.cmp(&current) {
30                Ordering::Greater => request_limiter_data.increase_permits(),
31                Ordering::Equal => {
32                    // only release the current permit (when the inner permit is dropped automatically)
33                }
34                Ordering::Less => {
35                    let permit = self.semaphore_permit.take().unwrap();
36                    request_limiter_data.decrease_permits(permit);
37                }
38            }
39        }
40    }
41}
42
43struct RequestLimiterData {
44    event_limit_target: usize,
45    total_permits: usize,
46    average_request_size: EwmaDefault,
47    semaphore: Arc<Semaphore>,
48    max_requests: usize,
49}
50
51impl RequestLimiterData {
52    pub fn update_average(&mut self, num_events: usize) {
53        if num_events > 0 {
54            self.average_request_size.update(num_events as f64);
55        }
56    }
57
58    pub fn target_requests_in_flight(&self) -> usize {
59        let target = (self.event_limit_target as f64) / self.average_request_size.average();
60        #[allow(clippy::manual_clamp)]
61        (target as usize)
62            .max(MINIMUM_PERMITS)
63            .min(self.max_requests)
64    }
65
66    pub fn increase_permits(&mut self) {
67        self.total_permits += 1;
68        self.semaphore.add_permits(1);
69    }
70
71    pub fn decrease_permits(&mut self, permit: OwnedSemaphorePermit) {
72        if self.total_permits > MINIMUM_PERMITS {
73            permit.forget();
74            self.total_permits -= 1;
75        }
76    }
77}
78
79#[derive(Clone)]
80pub struct RequestLimiter {
81    semaphore: Arc<Semaphore>,
82    data: Arc<Mutex<RequestLimiterData>>,
83}
84
85impl RequestLimiter {
86    /// event_limit_target: The limit to the number of events that will be in-flight at one time.
87    /// max_requests: The most number of requests that can be processed concurrently
88    /// The numbers of events in a request is not known until after it has been decoded, so this is not a hard limit.
89    pub fn new(event_limit_target: usize, max_requests: usize) -> RequestLimiter {
90        assert!(event_limit_target > 0);
91
92        let semaphore = Arc::new(Semaphore::new(MINIMUM_PERMITS));
93        RequestLimiter {
94            semaphore: Arc::clone(&semaphore),
95            data: Arc::new(Mutex::new(RequestLimiterData {
96                event_limit_target,
97                total_permits: MINIMUM_PERMITS,
98                average_request_size: EwmaDefault::new(EWMA_WEIGHT, event_limit_target as f64),
99                semaphore,
100                max_requests,
101            })),
102        }
103    }
104
105    pub async fn acquire(&self) -> RequestLimiterPermit {
106        let permit = Arc::clone(&self.semaphore).acquire_owned().await;
107        RequestLimiterPermit {
108            semaphore_permit: permit.ok(),
109            request_limiter_data: Arc::clone(&self.data),
110        }
111    }
112}
113
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Runs 100 sequential requests through `limiter`, each reporting
    /// `events_per_request` decoded events before its permit is dropped.
    async fn drive_requests(limiter: &RequestLimiter, events_per_request: usize) {
        for _ in 0..100 {
            let permit = limiter.acquire().await;
            permit.decoding_finished(events_per_request);
            drop(permit);
        }
    }

    /// The target converges towards `event_limit_target / request size`.
    #[tokio::test]
    async fn test_average_convergence() {
        let limiter = RequestLimiter::new(100, 100);
        drive_requests(&limiter, 5).await;

        let data = limiter.data.lock().unwrap();
        assert_abs_diff_eq!(data.target_requests_in_flight(), 100 / 5, epsilon = 1);
    }

    /// Oversized requests cannot push the target below `MINIMUM_PERMITS`.
    #[tokio::test]
    async fn test_minimum_permits() {
        let limiter = RequestLimiter::new(100, 100);
        drive_requests(&limiter, 500).await;

        let data = limiter.data.lock().unwrap();
        assert_eq!(data.target_requests_in_flight(), MINIMUM_PERMITS);
    }

    /// Tiny requests cannot push the target above `max_requests`.
    #[tokio::test]
    async fn test_maximum_permits() {
        let request_limit = 50;
        let limiter = RequestLimiter::new(1000, request_limit);
        drive_requests(&limiter, 1).await;

        let data = limiter.data.lock().unwrap();
        assert_eq!(data.target_requests_in_flight(), request_limit);
    }
}