use std::{
    future::Future,
    sync::{Arc, Mutex, MutexGuard},
    time::{Duration, Instant},
};

use tokio::sync::OwnedSemaphorePermit;
use tower::timeout::error::Elapsed;
use vector_lib::internal_event::{InternalEventHandle as _, Registered};

use super::{instant_now, semaphore::ShrinkableSemaphore, AdaptiveConcurrencySettings};
#[cfg(test)]
use crate::test_util::stats::{TimeHistogram, TimeWeightedSum};
use crate::{
    http::HttpError,
    internal_events::{
        AdaptiveConcurrencyAveragedRtt, AdaptiveConcurrencyInFlight, AdaptiveConcurrencyLimit,
        AdaptiveConcurrencyLimitData, AdaptiveConcurrencyObservedRtt,
    },
    sinks::util::retries::{RetryAction, RetryLogic},
    stats::{EwmaVar, Mean, MeanVariance},
};
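
/// Shared state for the adaptive concurrency algorithm. The controller wraps
/// a shrinkable semaphore that gates the number of in-flight requests and
/// adjusts the concurrency limit based on observed round-trip times (RTTs)
/// and back pressure signals from responses. When a fixed `concurrency` is
/// configured, the limit is never adjusted.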
#[derive(Clone)]
pub(super) struct Controller<L> {
    semaphore: Arc<ShrinkableSemaphore>,
    concurrency: Option<usize>,
    settings: AdaptiveConcurrencySettings,
    logic: L,
    pub(super) inner: Arc<Mutex<Inner>>,
    #[cfg(test)]
    pub(super) stats: Arc<Mutex<ControllerStatistics>>,
    limit: Registered<AdaptiveConcurrencyLimit>,
    in_flight: Registered<AdaptiveConcurrencyInFlight>,
    observed_rtt: Registered<AdaptiveConcurrencyObservedRtt>,
    averaged_rtt: Registered<AdaptiveConcurrencyAveragedRtt>,
}
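
/// Mutable state shared across requests: the current concurrency limit, the
/// number of requests in flight, RTT statistics for the current and past
/// averaging periods, and flags recording whether back pressure was seen or
/// the limit was reached during the current period.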
#[derive(Debug)]
pub(super) struct Inner {
    pub(super) current_limit: usize,
    in_flight: usize,
    past_rtt: EwmaVar,
    next_update: Instant,
    current_rtt: Mean,
    had_back_pressure: bool,
    reached_limit: bool,
}
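
/// Test-only accumulators used by the adaptive concurrency tests to inspect
/// the controller's behavior over time.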
#[cfg(test)]
#[derive(Debug, Default)]
pub(super) struct ControllerStatistics {
    pub(super) in_flight: TimeHistogram,
    pub(super) concurrency_limit: TimeHistogram,
    pub(super) observed_rtt: TimeWeightedSum,
    pub(super) averaged_rtt: TimeWeightedSum,
}

impl<L> Controller<L> {
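    /// Creates a controller from the given settings and retry logic. A fixed
    /// `concurrency` value becomes a constant limit and disables adaptation.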
    pub(super) fn new(
        concurrency: Option<usize>,
        settings: AdaptiveConcurrencySettings,
        logic: L,
    ) -> Self {
        let current_limit = concurrency.unwrap_or(settings.initial_concurrency);
        Self {
            semaphore: Arc::new(ShrinkableSemaphore::new(current_limit)),
            concurrency,
            settings,
            logic,
            inner: Arc::new(Mutex::new(Inner {
                current_limit,
                in_flight: 0,
                past_rtt: EwmaVar::new(settings.ewma_alpha),
                next_update: instant_now(),
                current_rtt: Default::default(),
                had_back_pressure: false,
                reached_limit: false,
            })),
            #[cfg(test)]
            stats: Arc::new(Mutex::new(ControllerStatistics::default())),
            limit: register!(AdaptiveConcurrencyLimit),
            in_flight: register!(AdaptiveConcurrencyInFlight),
            observed_rtt: register!(AdaptiveConcurrencyObservedRtt),
            averaged_rtt: register!(AdaptiveConcurrencyAveragedRtt),
        }
    }
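
    /// Returns the current load: the fraction of the concurrency limit that
    /// is occupied by in-flight requests, or 1.0 if the limit is zero.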
    pub(super) fn load(&self) -> f64 {
        let inner = self.inner.lock().expect("Controller mutex is poisoned");
        if inner.current_limit > 0 {
            inner.in_flight as f64 / inner.current_limit as f64
        } else {
            1.0
        }
    }
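
    /// Acquires a permit from the shared semaphore. The returned future
    /// resolves once a concurrency slot is available.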
    pub(super) fn acquire(&self) -> impl Future<Output = OwnedSemaphorePermit> + Send + 'static {
        Arc::clone(&self.semaphore).acquire()
    }
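
    /// Records the start of a request: increments the in-flight count, notes
    /// whether the concurrency limit was reached during this period, and
    /// emits the updated in-flight count.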
    pub(super) fn start_request(&self) {
        let mut inner = self.inner.lock().expect("Controller mutex is poisoned");

        #[cfg(test)]
        {
            let mut stats = self.stats.lock().expect("Stats mutex is poisoned");
            stats.in_flight.add(inner.in_flight, instant_now());
        }

        inner.in_flight += 1;
        if inner.in_flight >= inner.current_limit {
            inner.reached_limit = true;
        }

        self.in_flight.emit(inner.in_flight as u64);
    }
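
    /// Core response handler. Updates the in-flight count and the RTT
    /// statistics for the current averaging period and, at most once per
    /// period, folds the period's average RTT into the long-term EWMA and
    /// re-evaluates the concurrency limit.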
    fn adjust_to_response_inner(&self, start: Instant, is_back_pressure: bool, use_rtt: bool) {
        let now = instant_now();
        let mut inner = self.inner.lock().expect("Controller mutex is poisoned");

        let rtt = now.saturating_duration_since(start);
        if use_rtt {
            self.observed_rtt.emit(rtt);
        }
        let rtt = rtt.as_secs_f64();

        if is_back_pressure {
            inner.had_back_pressure = true;
        }

        #[cfg(test)]
        let mut stats = self.stats.lock().expect("Stats mutex is poisoned");
        #[cfg(test)]
        {
            if use_rtt {
                stats.observed_rtt.add(rtt, now);
            }
            stats.in_flight.add(inner.in_flight, now);
        }

        inner.in_flight -= 1;
        self.in_flight.emit(inner.in_flight as u64);

        if use_rtt {
            inner.current_rtt.update(rtt);
        }
        let current_rtt = inner.current_rtt.average();

        // Under test, round the averaged RTT to microsecond precision to
        // avoid floating point noise in the test statistics.
        #[cfg(test)]
        let current_rtt = current_rtt.map(|c| (c * 1_000_000.0).round() / 1_000_000.0);

        match inner.past_rtt.state() {
            None => {
                // No past measurements yet: seed the long-term RTT average
                // and schedule the first adjustment one RTT from now.
                if let Some(current_rtt) = current_rtt {
                    inner.past_rtt.update(current_rtt);
                    inner.next_update = now + Duration::from_secs_f64(current_rtt);
                }
            }
            Some(mut past_rtt) => {
                // Adjust at most once per averaging period.
                if now >= inner.next_update {
                    #[cfg(test)]
                    {
                        if let Some(current_rtt) = current_rtt {
                            stats.averaged_rtt.add(current_rtt, now);
                        }
                        stats.concurrency_limit.add(inner.current_limit, now);
                        drop(stats);
                    }

                    if let Some(current_rtt) = current_rtt {
                        self.averaged_rtt.emit(Duration::from_secs_f64(current_rtt));
                    }

                    // Only adjust the limit when no fixed concurrency was configured.
                    if self.concurrency.is_none() {
                        self.manage_limit(&mut inner, past_rtt, current_rtt);
                    }

                    // Fold this period's average into the long-term EWMA and
                    // reset the per-period state for the next interval.
                    if let Some(current_rtt) = current_rtt {
                        past_rtt = inner.past_rtt.update(current_rtt);
                    }
                    inner.next_update = now + Duration::from_secs_f64(past_rtt.mean);
                    inner.current_rtt = Default::default();
                    inner.had_back_pressure = false;
                    inner.reached_limit = false;
                }
            }
        }
    }
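
    /// Applies the additive-increase/multiplicative-decrease (AIMD) rule:
    /// the limit grows by one permit after a period that pressed against the
    /// limit with no back pressure and no RTT increase, and shrinks by
    /// `decrease_ratio` when back pressure was seen or the RTT rose well
    /// above its historical mean.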
    fn manage_limit(
        &self,
        inner: &mut MutexGuard<Inner>,
        past_rtt: MeanVariance,
        current_rtt: Option<f64>,
    ) {
        let past_rtt_deviation = past_rtt.variance.sqrt();
        let threshold = past_rtt_deviation * self.settings.rtt_deviation_scale;

        // Additive increase: a period that reached the current limit, saw no
        // back pressure, and whose average RTT did not exceed the historical
        // mean earns one more permit, up to the configured maximum.
        if inner.current_limit < self.settings.max_concurrency_limit
            && inner.reached_limit
            && !inner.had_back_pressure
            && current_rtt.is_some()
            && current_rtt.unwrap() <= past_rtt.mean
        {
            self.semaphore.add_permits(1);
            inner.current_limit += 1;
        }
        // Multiplicative decrease: back pressure, or an average RTT at or
        // above the historical mean plus `threshold`, scales the limit down
        // by `decrease_ratio` (never below 1).
        else if inner.current_limit > 1
            && (inner.had_back_pressure || current_rtt.unwrap_or(0.0) >= past_rtt.mean + threshold)
        {
            let new_limit =
                ((inner.current_limit as f64 * self.settings.decrease_ratio) as usize).max(1);
            self.semaphore
                .forget_permits(inner.current_limit - new_limit);
            inner.current_limit = new_limit;
        }

        self.limit.emit(AdaptiveConcurrencyLimitData {
            concurrency: inner.current_limit as u64,
            reached_limit: inner.reached_limit,
            had_back_pressure: inner.had_back_pressure,
            current_rtt: current_rtt.map(Duration::from_secs_f64),
            past_rtt: Duration::from_secs_f64(past_rtt.mean),
            past_rtt_deviation: Duration::from_secs_f64(past_rtt_deviation),
        });
    }
}

impl<L> Controller<L>
where
    L: RetryLogic,
{
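    /// Classifies a completed request's outcome and feeds it back into the
    /// controller: retryable responses and timeouts count as back pressure,
    /// and only responses the retry logic deems successful contribute an
    /// RTT observation.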
    pub(super) fn adjust_to_response(
        &self,
        start: Instant,
        response: &Result<L::Response, crate::Error>,
    ) {
        let response_action = response
            .as_ref()
            .map(|resp| self.logic.should_retry_response(resp));
        let is_back_pressure = match &response_action {
            Ok(action) => matches!(action, RetryAction::Retry(_)),
            Err(error) => {
                if let Some(error) = error.downcast_ref::<L::Error>() {
                    self.logic.is_retriable_error(error)
                } else if error.downcast_ref::<Elapsed>().is_some() {
                    // A tower timeout is treated as back pressure.
                    true
                } else if error.downcast_ref::<HttpError>().is_some() {
                    false
                } else {
                    warn!(
                        message = "Unhandled error response.",
                        %error,
                        internal_log_rate_limit = true
                    );
                    false
                }
            }
        };

        // Only responses classified as successful contribute an RTT observation.
        let use_rtt = matches!(response_action, Ok(RetryAction::Successful));
        self.adjust_to_response_inner(start, is_back_pressure, use_rtt)
    }
}