vector/sources/util/net/tcp/
request_limiter.rs1use std::cmp::Ordering;
2use std::sync::{Arc, Mutex};
3
4use tokio::sync::{OwnedSemaphorePermit, Semaphore};
5
6use crate::stats::EwmaDefault;
7
/// Lower bound on the number of semaphore permits the limiter keeps; the semaphore is also
/// created with exactly this many permits.
const MINIMUM_PERMITS: usize = 2;

/// Weight handed to `EwmaDefault::new` for the running average of events per request.
const EWMA_WEIGHT: f64 = 0.1;
10
11pub struct RequestLimiterPermit {
12 semaphore_permit: Option<OwnedSemaphorePermit>,
13 request_limiter_data: Arc<Mutex<RequestLimiterData>>,
14}
15
16impl RequestLimiterPermit {
17 pub fn decoding_finished(&self, num_events: usize) {
18 let mut request_limiter_data = self.request_limiter_data.lock().unwrap();
19 request_limiter_data.update_average(num_events);
20 }
21}
22
23impl Drop for RequestLimiterPermit {
24 fn drop(&mut self) {
25 if let Ok(mut request_limiter_data) = self.request_limiter_data.lock() {
26 let target = request_limiter_data.target_requests_in_flight();
27 let current = request_limiter_data.total_permits;
28
29 match target.cmp(¤t) {
30 Ordering::Greater => request_limiter_data.increase_permits(),
31 Ordering::Equal => {
32 }
34 Ordering::Less => {
35 let permit = self.semaphore_permit.take().unwrap();
36 request_limiter_data.decrease_permits(permit);
37 }
38 }
39 }
40 }
41}
42
43struct RequestLimiterData {
44 event_limit_target: usize,
45 total_permits: usize,
46 average_request_size: EwmaDefault,
47 semaphore: Arc<Semaphore>,
48 max_requests: usize,
49}
50
51impl RequestLimiterData {
52 pub fn update_average(&mut self, num_events: usize) {
53 if num_events > 0 {
54 self.average_request_size.update(num_events as f64);
55 }
56 }
57
58 pub fn target_requests_in_flight(&self) -> usize {
59 let target = (self.event_limit_target as f64) / self.average_request_size.average();
60 #[allow(clippy::manual_clamp)]
61 (target as usize)
62 .max(MINIMUM_PERMITS)
63 .min(self.max_requests)
64 }
65
66 pub fn increase_permits(&mut self) {
67 self.total_permits += 1;
68 self.semaphore.add_permits(1);
69 }
70
71 pub fn decrease_permits(&mut self, permit: OwnedSemaphorePermit) {
72 if self.total_permits > MINIMUM_PERMITS {
73 permit.forget();
74 self.total_permits -= 1;
75 }
76 }
77}
78
79#[derive(Clone)]
80pub struct RequestLimiter {
81 semaphore: Arc<Semaphore>,
82 data: Arc<Mutex<RequestLimiterData>>,
83}
84
85impl RequestLimiter {
86 pub fn new(event_limit_target: usize, max_requests: usize) -> RequestLimiter {
90 assert!(event_limit_target > 0);
91
92 let semaphore = Arc::new(Semaphore::new(MINIMUM_PERMITS));
93 RequestLimiter {
94 semaphore: Arc::clone(&semaphore),
95 data: Arc::new(Mutex::new(RequestLimiterData {
96 event_limit_target,
97 total_permits: MINIMUM_PERMITS,
98 average_request_size: EwmaDefault::new(EWMA_WEIGHT, event_limit_target as f64),
99 semaphore,
100 max_requests,
101 })),
102 }
103 }
104
105 pub async fn acquire(&self) -> RequestLimiterPermit {
106 let permit = Arc::clone(&self.semaphore).acquire_owned().await;
107 RequestLimiterPermit {
108 semaphore_permit: permit.ok(),
109 request_limiter_data: Arc::clone(&self.data),
110 }
111 }
112}
113
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Pushes `rounds` sequential requests through `limiter`. Each request reports
    /// `events_per_request` decoded events before its permit is released.
    async fn run_requests(limiter: &RequestLimiter, rounds: usize, events_per_request: usize) {
        for _ in 0..rounds {
            let permit = limiter.acquire().await;
            permit.decoding_finished(events_per_request);
            drop(permit);
        }
    }

    #[tokio::test]
    async fn test_average_convergence() {
        let limiter = RequestLimiter::new(100, 100);
        run_requests(&limiter, 100, 5).await;

        let data = limiter.data.lock().unwrap();
        assert_abs_diff_eq!(data.target_requests_in_flight(), 100 / 5, epsilon = 1);
    }

    #[tokio::test]
    async fn test_minimum_permits() {
        let limiter = RequestLimiter::new(100, 100);
        run_requests(&limiter, 100, 500).await;

        let data = limiter.data.lock().unwrap();
        assert_eq!(data.target_requests_in_flight(), MINIMUM_PERMITS);
    }

    #[tokio::test]
    async fn test_maximum_permits() {
        let request_limit = 50;
        let limiter = RequestLimiter::new(1000, request_limit);
        run_requests(&limiter, 100, 1).await;

        let data = limiter.data.lock().unwrap();
        assert_eq!(data.target_requests_in_flight(), request_limit);
    }
}