vector/sources/util/net/tcp/
request_limiter.rs1use std::{
2 cmp::Ordering,
3 sync::{Arc, Mutex},
4};
5
6use tokio::sync::{OwnedSemaphorePermit, Semaphore};
7
8use crate::stats::EwmaDefault;
9
/// Weight passed to `EwmaDefault::new` for the moving average of events per request.
const EWMA_WEIGHT: f64 = 0.1;

/// Floor for the permit budget: the semaphore starts with this many permits and
/// `RequestLimiterData::decrease_permits` never shrinks the total below it.
const MINIMUM_PERMITS: usize = 2;
12
13pub struct RequestLimiterPermit {
14 semaphore_permit: Option<OwnedSemaphorePermit>,
15 request_limiter_data: Arc<Mutex<RequestLimiterData>>,
16}
17
18impl RequestLimiterPermit {
19 pub fn decoding_finished(&self, num_events: usize) {
20 let mut request_limiter_data = self.request_limiter_data.lock().unwrap();
21 request_limiter_data.update_average(num_events);
22 }
23}
24
25impl Drop for RequestLimiterPermit {
26 fn drop(&mut self) {
27 if let Ok(mut request_limiter_data) = self.request_limiter_data.lock() {
28 let target = request_limiter_data.target_requests_in_flight();
29 let current = request_limiter_data.total_permits;
30
31 match target.cmp(¤t) {
32 Ordering::Greater => request_limiter_data.increase_permits(),
33 Ordering::Equal => {
34 }
36 Ordering::Less => {
37 let permit = self.semaphore_permit.take().unwrap();
38 request_limiter_data.decrease_permits(permit);
39 }
40 }
41 }
42 }
43}
44
45struct RequestLimiterData {
46 event_limit_target: usize,
47 total_permits: usize,
48 average_request_size: EwmaDefault,
49 semaphore: Arc<Semaphore>,
50 max_requests: usize,
51}
52
53impl RequestLimiterData {
54 pub fn update_average(&mut self, num_events: usize) {
55 if num_events > 0 {
56 self.average_request_size.update(num_events as f64);
57 }
58 }
59
60 pub fn target_requests_in_flight(&self) -> usize {
61 let target = (self.event_limit_target as f64) / self.average_request_size.average();
62 #[allow(clippy::manual_clamp)]
63 (target as usize)
64 .max(MINIMUM_PERMITS)
65 .min(self.max_requests)
66 }
67
68 pub fn increase_permits(&mut self) {
69 self.total_permits += 1;
70 self.semaphore.add_permits(1);
71 }
72
73 pub fn decrease_permits(&mut self, permit: OwnedSemaphorePermit) {
74 if self.total_permits > MINIMUM_PERMITS {
75 permit.forget();
76 self.total_permits -= 1;
77 }
78 }
79}
80
81#[derive(Clone)]
82pub struct RequestLimiter {
83 semaphore: Arc<Semaphore>,
84 data: Arc<Mutex<RequestLimiterData>>,
85}
86
87impl RequestLimiter {
88 pub fn new(event_limit_target: usize, max_requests: usize) -> RequestLimiter {
92 assert!(event_limit_target > 0);
93
94 let semaphore = Arc::new(Semaphore::new(MINIMUM_PERMITS));
95 RequestLimiter {
96 semaphore: Arc::clone(&semaphore),
97 data: Arc::new(Mutex::new(RequestLimiterData {
98 event_limit_target,
99 total_permits: MINIMUM_PERMITS,
100 average_request_size: EwmaDefault::new(EWMA_WEIGHT, event_limit_target as f64),
101 semaphore,
102 max_requests,
103 })),
104 }
105 }
106
107 pub async fn acquire(&self) -> RequestLimiterPermit {
108 let permit = Arc::clone(&self.semaphore).acquire_owned().await;
109 RequestLimiterPermit {
110 semaphore_permit: permit.ok(),
111 request_limiter_data: Arc::clone(&self.data),
112 }
113 }
114}
115
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Sends `iterations` sequential requests of `events_per_request` events each through
    /// `limiter`, releasing every permit before acquiring the next one.
    async fn run_requests(limiter: &RequestLimiter, iterations: usize, events_per_request: usize) {
        for _ in 0..iterations {
            let permit = limiter.acquire().await;
            permit.decoding_finished(events_per_request);
            drop(permit);
        }
    }

    #[tokio::test]
    async fn test_average_convergence() {
        let limiter = RequestLimiter::new(100, 100);
        run_requests(&limiter, 100, 5).await;

        let data = limiter.data.lock().unwrap();
        assert_abs_diff_eq!(data.target_requests_in_flight(), 100 / 5, epsilon = 1);
    }

    #[tokio::test]
    async fn test_minimum_permits() {
        let limiter = RequestLimiter::new(100, 100);
        run_requests(&limiter, 100, 500).await;

        let data = limiter.data.lock().unwrap();
        assert_eq!(data.target_requests_in_flight(), MINIMUM_PERMITS);
    }

    #[tokio::test]
    async fn test_maximum_permits() {
        let request_limit = 50;
        let limiter = RequestLimiter::new(1000, request_limit);
        run_requests(&limiter, 100, 1).await;

        let data = limiter.data.lock().unwrap();
        assert_eq!(data.target_requests_in_flight(), request_limit);
    }
}