vector/sinks/util/adaptive_concurrency/service.rs

1use std::{
2    fmt,
3    future::Future,
4    mem,
5    sync::Arc,
6    task::{ready, Context, Poll},
7};
8
9use futures::future::BoxFuture;
10use tokio::sync::OwnedSemaphorePermit;
11use tower::{load::Load, Service};
12
13use super::{controller::Controller, future::ResponseFuture, AdaptiveConcurrencySettings};
14use crate::sinks::util::retries::RetryLogic;
15
16/// Enforces a limit on the concurrent number of requests the underlying
17/// service can handle. Automatically expands and contracts the actual
18/// concurrency limit depending on observed request response behavior.
19pub struct AdaptiveConcurrencyLimit<S, L> {
20    inner: S,
21    pub(super) controller: Arc<Controller<L>>,
22    state: State,
23}
24
25enum State {
26    Waiting(BoxFuture<'static, OwnedSemaphorePermit>),
27    Ready(OwnedSemaphorePermit),
28    Empty,
29}
30
31impl<S, L> AdaptiveConcurrencyLimit<S, L> {
32    /// Create a new automated concurrency limiter.
33    pub(crate) fn new(
34        inner: S,
35        logic: L,
36        concurrency: Option<usize>,
37        options: AdaptiveConcurrencySettings,
38    ) -> Self {
39        AdaptiveConcurrencyLimit {
40            inner,
41            controller: Arc::new(Controller::new(concurrency, options, logic)),
42            state: State::Empty,
43        }
44    }
45}
46
47impl<S, L, Request> Service<Request> for AdaptiveConcurrencyLimit<S, L>
48where
49    S: Service<Request>,
50    S::Error: Into<crate::Error>,
51    L: RetryLogic<Response = S::Response>,
52{
53    type Response = S::Response;
54    type Error = crate::Error;
55    type Future = ResponseFuture<S::Future, L>;
56
57    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58        loop {
59            self.state = match self.state {
60                State::Ready(_) => return self.inner.poll_ready(cx).map_err(Into::into),
61                State::Waiting(ref mut fut) => {
62                    tokio::pin!(fut);
63                    let permit = ready!(fut.poll(cx));
64                    State::Ready(permit)
65                }
66                State::Empty => State::Waiting(Box::pin(Arc::clone(&self.controller).acquire())),
67            };
68        }
69    }
70
71    fn call(&mut self, request: Request) -> Self::Future {
72        // Make sure a permit has been acquired
73        let permit = match mem::replace(&mut self.state, State::Empty) {
74            // Take the permit.
75            State::Ready(permit) => permit,
76            // whoopsie!
77            _ => panic!("Maximum requests in-flight; poll_ready must be called first"),
78        };
79
80        self.controller.start_request();
81
82        // Call the inner service
83        let future = self.inner.call(request);
84
85        ResponseFuture::new(future, permit, Arc::clone(&self.controller))
86    }
87}
88
89impl<S, L> Load for AdaptiveConcurrencyLimit<S, L> {
90    type Metric = f64;
91
92    fn load(&self) -> Self::Metric {
93        self.controller.load()
94    }
95}
96
97impl<S, L> Clone for AdaptiveConcurrencyLimit<S, L>
98where
99    S: Clone,
100    L: Clone,
101{
102    fn clone(&self) -> Self {
103        Self {
104            inner: self.inner.clone(),
105            controller: Arc::clone(&self.controller),
106            state: State::Empty,
107        }
108    }
109}
110
111impl fmt::Debug for State {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        match self {
114            State::Waiting(_) => f
115                .debug_tuple("State::Waiting")
116                .field(&format_args!("..."))
117                .finish(),
118            State::Ready(r) => f.debug_tuple("State::Ready").field(&r).finish(),
119            State::Empty => f.debug_tuple("State::Empty").finish(),
120        }
121    }
122}
123
#[cfg(test)]
mod tests {
    use std::{
        sync::{Mutex, MutexGuard},
        time::Duration,
    };

    use snafu::Snafu;
    use tokio::time::{advance, pause};
    use tokio_test::{assert_pending, assert_ready_ok};
    use tower_test::{
        assert_request_eq,
        mock::{
            self, future::ResponseFuture as MockResponseFuture, Handle, Mock, SendResponse, Spawn,
        },
    };

    use super::{
        super::{
            controller::{ControllerStatistics, Inner},
            AdaptiveConcurrencyLimitLayer,
        },
        *,
    };
    use crate::assert_downcast_matches;

    /// The only error the mock service is made to return in these tests.
    #[derive(Clone, Copy, Debug, Snafu)]
    enum TestError {
        Deferral,
    }

    /// Retry logic that treats every `TestError` as retriable.
    #[derive(Clone, Copy, Debug)]
    struct TestRetryLogic;
    impl RetryLogic for TestRetryLogic {
        type Error = TestError;
        type Request = ();
        type Response = String;
        fn is_retriable_error(&self, _error: &Self::Error) -> bool {
            true
        }
    }

    type TestInner = AdaptiveConcurrencyLimit<Mock<String, String>, TestRetryLogic>;

    /// A limiter wrapping a mock service, plus handles onto the controller's
    /// internal state and statistics so tests can inspect them.
    struct TestService {
        service: Spawn<TestInner>,
        handle: Handle<String, String>,
        inner: Arc<Mutex<Inner>>,
        stats: Arc<Mutex<ControllerStatistics>>,
        sequence: usize,
    }

    /// A request the mock service has received but not yet answered.
    ///
    /// Named so that it does not shadow the `std::marker::Send` trait within
    /// this module.
    struct InFlightRequest {
        request: ResponseFuture<MockResponseFuture<String>, TestRetryLogic>,
        response: SendResponse<String>,
        sequence: usize,
    }

    impl TestService {
        /// Builds a limiter (with `decrease_ratio` 0.5) around a mock service.
        fn start() -> Self {
            let layer = AdaptiveConcurrencyLimitLayer::new(
                None,
                AdaptiveConcurrencySettings {
                    decrease_ratio: 0.5,
                    ..Default::default()
                },
                TestRetryLogic,
            );
            let (service, handle) = mock::spawn_layer(layer);
            let controller = Arc::clone(&service.get_ref().controller);
            let inner = Arc::clone(&controller.inner);
            let stats = Arc::clone(&controller.stats);
            Self {
                service,
                handle,
                inner,
                stats,
                sequence: 0,
            }
        }

        /// Runs `doit` against a fresh service with tokio time paused, then
        /// returns the controller statistics collected along the way.
        ///
        /// `doit` takes ownership of the service, so once it finishes the
        /// statistics `Arc` is uniquely owned here and can be unwrapped.
        async fn run<F, Ret>(doit: F) -> ControllerStatistics
        where
            F: FnOnce(Self) -> Ret,
            Ret: Future<Output = ()>,
        {
            let svc = Self::start();
            let stats = Arc::clone(&svc.stats);
            pause();
            doit(svc).await;
            Arc::try_unwrap(stats).unwrap().into_inner().unwrap()
        }

        /// Sends the next numbered request and asserts the mock receives it.
        ///
        /// `is_ready` states whether the limiter is expected to still have
        /// capacity for another request once this one is in flight.
        async fn send(&mut self, is_ready: bool) -> InFlightRequest {
            assert_ready_ok!(self.service.poll_ready());
            self.sequence += 1;
            let data = format!("REQUEST #{}", self.sequence);
            let request = self.service.call(data.clone());
            let response = assert_request_eq!(self.handle, data);
            if is_ready {
                assert_ready_ok!(self.service.poll_ready());
            } else {
                assert_pending!(self.service.poll_ready());
            }
            InFlightRequest {
                request,
                response,
                sequence: self.sequence,
            }
        }

        /// Locks and returns the controller's internal state.
        fn inner(&self) -> MutexGuard<'_, Inner> {
            self.inner.lock().unwrap()
        }
    }

    impl InFlightRequest {
        /// Answers successfully and checks the caller sees the same response.
        async fn respond(self) {
            let data = format!("RESPONSE #{}", self.sequence);
            self.response.send_response(data.clone());
            assert_eq!(self.request.await.unwrap(), data);
        }

        /// Fails with `TestError::Deferral` and checks the caller sees that error.
        async fn defer(self) {
            self.response.send_error(TestError::Deferral);
            assert_downcast_matches!(
                self.request.await.unwrap_err(),
                TestError,
                TestError::Deferral
            );
        }
    }

    #[tokio::test]
    async fn startup_conditions() {
        TestService::run(|mut svc| async move {
            // Concurrency starts at 1
            assert_eq!(svc.inner().current_limit, 1);
            svc.send(false).await;
        })
        .await;
    }

    #[tokio::test]
    async fn increases_limit() {
        let stats = TestService::run(|mut svc| async move {
            // Concurrency starts at 1
            assert_eq!(svc.inner().current_limit, 1);
            let req = svc.send(false).await;
            advance(Duration::from_secs(1)).await;
            req.respond().await;

            // Concurrency stays at 1 until a measurement
            assert_eq!(svc.inner().current_limit, 1);
            let req = svc.send(false).await;
            advance(Duration::from_secs(1)).await;
            req.respond().await;

            // After a constant speed measurement, concurrency is increased
            assert_eq!(svc.inner().current_limit, 2);
        })
        .await;

        let in_flight = stats.in_flight.stats().unwrap();
        assert_eq!(in_flight.max, 1);
        assert_eq!(in_flight.mean, 1.0);

        let observed_rtt = stats.observed_rtt.stats().unwrap();
        assert_eq!(observed_rtt.mean, 1.0);
    }

    #[tokio::test]
    async fn handles_deferral() {
        TestService::run(|mut svc| async move {
            assert_eq!(svc.inner().current_limit, 1);
            let req = svc.send(false).await;
            advance(Duration::from_secs(1)).await;
            req.respond().await;

            assert_eq!(svc.inner().current_limit, 1);
            let req = svc.send(false).await;
            advance(Duration::from_secs(1)).await;
            req.respond().await;

            assert_eq!(svc.inner().current_limit, 2);

            let req = svc.send(true).await;
            advance(Duration::from_secs(1)).await;
            req.defer().await;
            assert_eq!(svc.inner().current_limit, 1);
        })
        .await;
    }

    #[tokio::test]
    async fn rapid_decrease() {
        TestService::run(|mut svc| async move {
            let mut reqs = [None, None, None];
            for &concurrent in &[1, 1, 2, 3] {
                assert_eq!(svc.inner().current_limit, concurrent);
                // This would ideally be done with something like:
                // let reqs = futures::future::join_all((0..concurrent).map(svc.send)).await
                // but that runs afoul of the borrow checker since `svc`
                // must be borrowed mutably with a non-static
                // lifetime. Resolving it is more work than it's worth
                // for this test.
                for (i, req) in reqs.iter_mut().take(concurrent).enumerate() {
                    *req = Some(svc.send(i < concurrent - 1).await);
                }
                advance(Duration::from_secs(1)).await;
                for req in reqs.iter_mut().take(concurrent) {
                    req.take().unwrap().respond().await;
                }
            }

            assert_eq!(svc.inner().current_limit, 4);

            let req = svc.send(true).await;
            advance(Duration::from_secs(1)).await;
            req.defer().await;

            assert_eq!(svc.inner().current_limit, 2);
        })
        .await;
    }
}