vector/sinks/util/adaptive_concurrency/
service.rs1use std::{
2 fmt,
3 future::Future,
4 mem,
5 sync::Arc,
6 task::{ready, Context, Poll},
7};
8
9use futures::future::BoxFuture;
10use tokio::sync::OwnedSemaphorePermit;
11use tower::{load::Load, Service};
12
13use super::{controller::Controller, future::ResponseFuture, AdaptiveConcurrencySettings};
14use crate::sinks::util::retries::RetryLogic;
15
16pub struct AdaptiveConcurrencyLimit<S, L> {
20 inner: S,
21 pub(super) controller: Arc<Controller<L>>,
22 state: State,
23}
24
25enum State {
26 Waiting(BoxFuture<'static, OwnedSemaphorePermit>),
27 Ready(OwnedSemaphorePermit),
28 Empty,
29}
30
31impl<S, L> AdaptiveConcurrencyLimit<S, L> {
32 pub(crate) fn new(
34 inner: S,
35 logic: L,
36 concurrency: Option<usize>,
37 options: AdaptiveConcurrencySettings,
38 ) -> Self {
39 AdaptiveConcurrencyLimit {
40 inner,
41 controller: Arc::new(Controller::new(concurrency, options, logic)),
42 state: State::Empty,
43 }
44 }
45}
46
47impl<S, L, Request> Service<Request> for AdaptiveConcurrencyLimit<S, L>
48where
49 S: Service<Request>,
50 S::Error: Into<crate::Error>,
51 L: RetryLogic<Response = S::Response>,
52{
53 type Response = S::Response;
54 type Error = crate::Error;
55 type Future = ResponseFuture<S::Future, L>;
56
57 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58 loop {
59 self.state = match self.state {
60 State::Ready(_) => return self.inner.poll_ready(cx).map_err(Into::into),
61 State::Waiting(ref mut fut) => {
62 tokio::pin!(fut);
63 let permit = ready!(fut.poll(cx));
64 State::Ready(permit)
65 }
66 State::Empty => State::Waiting(Box::pin(Arc::clone(&self.controller).acquire())),
67 };
68 }
69 }
70
71 fn call(&mut self, request: Request) -> Self::Future {
72 let permit = match mem::replace(&mut self.state, State::Empty) {
74 State::Ready(permit) => permit,
76 _ => panic!("Maximum requests in-flight; poll_ready must be called first"),
78 };
79
80 self.controller.start_request();
81
82 let future = self.inner.call(request);
84
85 ResponseFuture::new(future, permit, Arc::clone(&self.controller))
86 }
87}
88
89impl<S, L> Load for AdaptiveConcurrencyLimit<S, L> {
90 type Metric = f64;
91
92 fn load(&self) -> Self::Metric {
93 self.controller.load()
94 }
95}
96
97impl<S, L> Clone for AdaptiveConcurrencyLimit<S, L>
98where
99 S: Clone,
100 L: Clone,
101{
102 fn clone(&self) -> Self {
103 Self {
104 inner: self.inner.clone(),
105 controller: Arc::clone(&self.controller),
106 state: State::Empty,
107 }
108 }
109}
110
111impl fmt::Debug for State {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 match self {
114 State::Waiting(_) => f
115 .debug_tuple("State::Waiting")
116 .field(&format_args!("..."))
117 .finish(),
118 State::Ready(r) => f.debug_tuple("State::Ready").field(&r).finish(),
119 State::Empty => f.debug_tuple("State::Empty").finish(),
120 }
121 }
122}
123
#[cfg(test)]
mod tests {
    use std::{
        sync::{Mutex, MutexGuard},
        time::Duration,
    };

    use snafu::Snafu;
    use tokio::time::{advance, pause};
    use tokio_test::{assert_pending, assert_ready_ok};
    use tower_test::{
        assert_request_eq,
        mock::{
            self, future::ResponseFuture as MockResponseFuture, Handle, Mock, SendResponse, Spawn,
        },
    };

    use super::{
        super::{
            controller::{ControllerStatistics, Inner},
            AdaptiveConcurrencyLimitLayer,
        },
        *,
    };
    use crate::assert_downcast_matches;

    /// Error the mock service sends back to simulate a deferral.
    #[derive(Clone, Copy, Debug, Snafu)]
    enum TestError {
        Deferral,
    }

    /// Retry logic that classifies every error as retriable.
    #[derive(Clone, Copy, Debug)]
    struct TestRetryLogic;

    impl RetryLogic for TestRetryLogic {
        type Error = TestError;
        type Request = ();
        type Response = String;

        fn is_retriable_error(&self, _error: &Self::Error) -> bool {
            true
        }
    }

    type TestInner = AdaptiveConcurrencyLimit<Mock<String, String>, TestRetryLogic>;

    /// Limiter under test plus the mock handle and the controller's shared state.
    struct TestService {
        service: Spawn<TestInner>,
        handle: Handle<String, String>,
        inner: Arc<Mutex<Inner>>,
        stats: Arc<Mutex<ControllerStatistics>>,
        sequence: usize,
    }

    /// A request issued through the limiter, waiting for the test to answer it.
    struct InFlight {
        request: ResponseFuture<MockResponseFuture<String>, TestRetryLogic>,
        response: SendResponse<String>,
        sequence: usize,
    }

    impl TestService {
        /// Builds the limiter (decrease ratio 0.5) around a mock service.
        ///
        /// Also keeps handles to the controller's `inner` and `stats`.
        fn start() -> Self {
            let settings = AdaptiveConcurrencySettings {
                decrease_ratio: 0.5,
                ..Default::default()
            };
            let layer = AdaptiveConcurrencyLimitLayer::new(None, settings, TestRetryLogic);
            let (service, handle) = mock::spawn_layer(layer);
            let (inner, stats) = {
                let controller = &service.get_ref().controller;
                (Arc::clone(&controller.inner), Arc::clone(&controller.stats))
            };
            Self {
                service,
                handle,
                inner,
                stats,
                sequence: 0,
            }
        }

        /// Runs `doit` on a fresh harness with tokio time paused.
        ///
        /// Returns the controller statistics once the harness has been dropped.
        async fn run<F, Ret>(doit: F) -> ControllerStatistics
        where
            F: FnOnce(Self) -> Ret,
            Ret: Future<Output = ()>,
        {
            let harness = Self::start();
            let stats = Arc::clone(&harness.stats);
            pause();
            doit(harness).await;
            Arc::try_unwrap(stats).unwrap().into_inner().unwrap()
        }

        /// Issues the next numbered request and checks the mock receives it.
        ///
        /// Afterwards asserts the limiter is ready again (`is_ready`) or pending.
        async fn send(&mut self, is_ready: bool) -> InFlight {
            assert_ready_ok!(self.service.poll_ready());
            self.sequence += 1;
            let sequence = self.sequence;
            let payload = format!("REQUEST #{}", sequence);
            let request = self.service.call(payload.clone());
            let response = assert_request_eq!(self.handle, payload);

            let readiness = self.service.poll_ready();
            if is_ready {
                assert_ready_ok!(readiness);
            } else {
                assert_pending!(readiness);
            }

            InFlight {
                request,
                response,
                sequence,
            }
        }

        fn inner(&self) -> MutexGuard<'_, Inner> {
            self.inner.lock().unwrap()
        }
    }

    impl InFlight {
        /// Answers successfully and checks the limiter passes the response through unchanged.
        async fn respond(self) {
            let Self {
                request,
                response,
                sequence,
            } = self;
            let expected = format!("RESPONSE #{}", sequence);
            response.send_response(expected.clone());
            assert_eq!(request.await.unwrap(), expected);
        }

        /// Answers with a deferral error and checks the same error comes back out.
        async fn defer(self) {
            let Self {
                request, response, ..
            } = self;
            response.send_error(TestError::Deferral);
            let error = request.await.unwrap_err();
            assert_downcast_matches!(error, TestError, TestError::Deferral);
        }
    }

    #[tokio::test]
    async fn startup_conditions() {
        TestService::run(|mut svc| async move {
            assert_eq!(svc.inner().current_limit, 1);
            svc.send(false).await;
        })
        .await;
    }

    #[tokio::test]
    async fn increases_limit() {
        let stats = TestService::run(|mut svc| async move {
            // Two serial one-second round trips at a limit of 1 ...
            for _ in 0..2 {
                assert_eq!(svc.inner().current_limit, 1);
                let req = svc.send(false).await;
                advance(Duration::from_secs(1)).await;
                req.respond().await;
            }
            // ... after which the limit has grown to 2.
            assert_eq!(svc.inner().current_limit, 2);
        })
        .await;

        let in_flight = stats.in_flight.stats().unwrap();
        assert_eq!(in_flight.max, 1);
        assert_eq!(in_flight.mean, 1.0);

        let observed_rtt = stats.observed_rtt.stats().unwrap();
        assert_eq!(observed_rtt.mean, 1.0);
    }

    #[tokio::test]
    async fn handles_deferral() {
        TestService::run(|mut svc| async move {
            for _ in 0..2 {
                assert_eq!(svc.inner().current_limit, 1);
                let req = svc.send(false).await;
                advance(Duration::from_secs(1)).await;
                req.respond().await;
            }
            assert_eq!(svc.inner().current_limit, 2);

            let req = svc.send(true).await;
            advance(Duration::from_secs(1)).await;
            req.defer().await;
            assert_eq!(svc.inner().current_limit, 1);
        })
        .await;
    }

    #[tokio::test]
    async fn rapid_decrease() {
        TestService::run(|mut svc| async move {
            let mut slots: [Option<InFlight>; 3] = [None, None, None];
            for concurrent in [1, 1, 2, 3] {
                assert_eq!(svc.inner().current_limit, concurrent);
                for (index, slot) in slots.iter_mut().take(concurrent).enumerate() {
                    // Only the last request of the batch should leave the limiter pending.
                    let still_ready = index + 1 < concurrent;
                    *slot = Some(svc.send(still_ready).await);
                }
                advance(Duration::from_secs(1)).await;
                for slot in slots.iter_mut().take(concurrent) {
                    slot.take().unwrap().respond().await;
                }
            }

            assert_eq!(svc.inner().current_limit, 4);

            let req = svc.send(true).await;
            advance(Duration::from_secs(1)).await;
            req.defer().await;

            assert_eq!(svc.inner().current_limit, 2);
        })
        .await;
    }
}