vector/sinks/util/adaptive_concurrency/
semaphore.rs

1// The `to_forget` mutex needs to be both a lock and a counter, so
2// clippy's warning that an AtomicUsize would work better is incorrect.
3#![allow(clippy::mutex_atomic)]
4
5use std::{
6    future::Future,
7    mem::{drop, replace},
8    pin::Pin,
9    sync::{Arc, Mutex},
10    task::{ready, Context, Poll},
11};
12
13use futures::future::{BoxFuture, FutureExt};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
16/// Wrapper for `tokio::sync::Semaphore` that allows for shrinking the
17/// semaphore safely.
18#[derive(Debug)]
19pub(super) struct ShrinkableSemaphore {
20    semaphore: Arc<Semaphore>,
21    to_forget: Mutex<usize>,
22}
23
24impl ShrinkableSemaphore {
25    pub(super) fn new(size: usize) -> Self {
26        Self {
27            semaphore: Arc::new(Semaphore::new(size)),
28            to_forget: Mutex::new(0),
29        }
30    }
31
32    pub(super) fn acquire(
33        self: Arc<Self>,
34    ) -> impl Future<Output = OwnedSemaphorePermit> + Send + 'static {
35        MaybeForgetFuture {
36            master: Arc::clone(&self),
37            future: Box::pin(
38                Arc::clone(&self.semaphore)
39                    .acquire_owned()
40                    .map(|r| r.expect("Semaphore has been closed")),
41            ),
42        }
43    }
44
45    pub(super) fn forget_permits(&self, count: usize) {
46        // When forgetting permits, there may not be enough immediately
47        // available. If so, just increase the count we need to forget
48        // later and finish.
49        let mut to_forget = self
50            .to_forget
51            .lock()
52            .expect("Shrinkable semaphore mutex is poisoned");
53        for _ in 0..count {
54            match self.semaphore.try_acquire() {
55                Ok(permit) => permit.forget(),
56                Err(_) => *to_forget += 1,
57            }
58        }
59    }
60
61    pub(super) fn add_permits(&self, count: usize) {
62        let mut to_forget = self
63            .to_forget
64            .lock()
65            .expect("Shrinkable semaphore mutex is poisoned");
66        if *to_forget >= count {
67            *to_forget -= count;
68        } else {
69            self.semaphore.add_permits(count);
70            *to_forget = to_forget.saturating_sub(count);
71        }
72    }
73}
74
75/// A future that accounts for the possibility of needing to forget some
76/// number of permits before yielding a valid one.
77struct MaybeForgetFuture {
78    master: Arc<ShrinkableSemaphore>,
79    future: BoxFuture<'static, OwnedSemaphorePermit>,
80}
81
82impl Future for MaybeForgetFuture {
83    type Output = OwnedSemaphorePermit;
84    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85        let master = Arc::clone(&self.master);
86        let mut to_forget = master
87            .to_forget
88            .lock()
89            .expect("Shrinkable semaphore mutex is poisoned");
90        while *to_forget > 0 {
91            let permit = ready!(self.future.as_mut().poll(cx));
92            permit.forget();
93            *to_forget -= 1;
94            let future = Arc::clone(&self.master.semaphore)
95                .acquire_owned()
96                .map(|r| r.expect("Semaphore is closed"));
97            drop(replace(&mut self.future, Box::pin(future)));
98        }
99        drop(to_forget);
100        self.future.as_mut().poll(cx)
101    }
102}