vector/sources/util/net/tcp/
request_limiter.rs1use std::{
2 cmp::Ordering,
3 sync::{Arc, Mutex},
4};
5
6use tokio::sync::{OwnedSemaphorePermit, Semaphore};
7use vector_lib::stats::EwmaDefault;
8
/// Weight handed to `EwmaDefault::new` for the running average of request sizes.
// NOTE(review): presumably the EWMA smoothing factor (weight of the newest sample) — confirm
// against `vector_lib::stats::EwmaDefault`.
const EWMA_WEIGHT: f64 = 0.1;
/// Initial number of semaphore permits, and the floor below which the limiter never shrinks
/// its permit count.
const MINIMUM_PERMITS: usize = 2;
11
12pub struct RequestLimiterPermit {
13 semaphore_permit: Option<OwnedSemaphorePermit>,
14 request_limiter_data: Arc<Mutex<RequestLimiterData>>,
15}
16
17impl RequestLimiterPermit {
18 pub fn decoding_finished(&self, num_events: usize) {
19 let mut request_limiter_data = self.request_limiter_data.lock().unwrap();
20 request_limiter_data.update_average(num_events);
21 }
22}
23
24impl Drop for RequestLimiterPermit {
25 fn drop(&mut self) {
26 if let Ok(mut request_limiter_data) = self.request_limiter_data.lock() {
27 let target = request_limiter_data.target_requests_in_flight();
28 let current = request_limiter_data.total_permits;
29
30 match target.cmp(¤t) {
31 Ordering::Greater => request_limiter_data.increase_permits(),
32 Ordering::Equal => {
33 }
35 Ordering::Less => {
36 let permit = self.semaphore_permit.take().unwrap();
37 request_limiter_data.decrease_permits(permit);
38 }
39 }
40 }
41 }
42}
43
44struct RequestLimiterData {
45 event_limit_target: usize,
46 total_permits: usize,
47 average_request_size: EwmaDefault,
48 semaphore: Arc<Semaphore>,
49 max_requests: usize,
50}
51
52impl RequestLimiterData {
53 pub fn update_average(&mut self, num_events: usize) {
54 if num_events > 0 {
55 self.average_request_size.update(num_events as f64);
56 }
57 }
58
59 pub fn target_requests_in_flight(&self) -> usize {
60 let target = (self.event_limit_target as f64) / self.average_request_size.average();
61 #[allow(clippy::manual_clamp)]
62 (target as usize)
63 .max(MINIMUM_PERMITS)
64 .min(self.max_requests)
65 }
66
67 pub fn increase_permits(&mut self) {
68 self.total_permits += 1;
69 self.semaphore.add_permits(1);
70 }
71
72 pub fn decrease_permits(&mut self, permit: OwnedSemaphorePermit) {
73 if self.total_permits > MINIMUM_PERMITS {
74 permit.forget();
75 self.total_permits -= 1;
76 }
77 }
78}
79
80#[derive(Clone)]
81pub struct RequestLimiter {
82 semaphore: Arc<Semaphore>,
83 data: Arc<Mutex<RequestLimiterData>>,
84}
85
86impl RequestLimiter {
87 pub fn new(event_limit_target: usize, max_requests: usize) -> RequestLimiter {
91 assert!(event_limit_target > 0);
92
93 let semaphore = Arc::new(Semaphore::new(MINIMUM_PERMITS));
94 RequestLimiter {
95 semaphore: Arc::clone(&semaphore),
96 data: Arc::new(Mutex::new(RequestLimiterData {
97 event_limit_target,
98 total_permits: MINIMUM_PERMITS,
99 average_request_size: EwmaDefault::new(EWMA_WEIGHT, event_limit_target as f64),
100 semaphore,
101 max_requests,
102 })),
103 }
104 }
105
106 pub async fn acquire(&self) -> RequestLimiterPermit {
107 let permit = Arc::clone(&self.semaphore).acquire_owned().await;
108 RequestLimiterPermit {
109 semaphore_permit: permit.ok(),
110 request_limiter_data: Arc::clone(&self.data),
111 }
112 }
113}
114
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Runs 100 sequential acquire / report / release cycles, each reporting
    /// `events_per_request` decoded events.
    async fn run_requests(limiter: &RequestLimiter, events_per_request: usize) {
        for _ in 0..100 {
            let permit = limiter.acquire().await;
            permit.decoding_finished(events_per_request);
            drop(permit);
        }
    }

    /// Reads the limiter's current target number of in-flight requests.
    fn current_target(limiter: &RequestLimiter) -> usize {
        limiter.data.lock().unwrap().target_requests_in_flight()
    }

    /// With 5 events per request and a target of 100 events, the limiter should settle
    /// near 100 / 5 requests in flight.
    #[tokio::test]
    async fn test_average_convergence() {
        let limiter = RequestLimiter::new(100, 100);

        run_requests(&limiter, 5).await;

        assert_abs_diff_eq!(current_target(&limiter), 100 / 5, epsilon = 1);
    }

    /// Requests far larger than the event target must not push the target below the floor.
    #[tokio::test]
    async fn test_minimum_permits() {
        let limiter = RequestLimiter::new(100, 100);

        run_requests(&limiter, 500).await;

        assert_eq!(current_target(&limiter), MINIMUM_PERMITS);
    }

    /// Tiny requests must not push the target above the configured request cap.
    #[tokio::test]
    async fn test_maximum_permits() {
        let request_limit = 50;
        let limiter = RequestLimiter::new(1000, request_limit);

        run_requests(&limiter, 1).await;

        assert_eq!(current_target(&limiter), request_limit);
    }
}