use std::{
collections::HashMap,
num::{NonZeroU64, NonZeroU8},
sync::Arc,
time::Duration,
};
use hyper::Body;
use serde::{Deserialize, Serialize};
use tokio::sync::{mpsc::Receiver, oneshot::Sender};
use vector_lib::configurable::configurable_component;
use vector_lib::event::EventStatus;
use super::service::{HttpRequestBuilder, MetadataFields};
use crate::{
config::AcknowledgementsConfig,
http::HttpClient,
internal_events::{
SplunkIndexerAcknowledgementAPIError, SplunkIndexerAcknowledgementAckAdded,
SplunkIndexerAcknowledgementAcksRemoved,
},
};
#[configurable_component]
#[derive(Clone, Debug)]
#[serde(default)]
#[configurable(metadata(docs::advanced))]
pub struct HecClientAcknowledgementsConfig {
pub indexer_acknowledgements_enabled: bool,
#[configurable(metadata(docs::type_unit = "seconds"))]
pub query_interval: NonZeroU8,
pub retry_limit: NonZeroU8,
pub max_pending_acks: NonZeroU64,
#[serde(
default,
deserialize_with = "crate::serde::bool_or_struct",
flatten,
skip_serializing_if = "crate::serde::is_default"
)]
pub inner: AcknowledgementsConfig,
}
impl Default for HecClientAcknowledgementsConfig {
fn default() -> Self {
Self {
indexer_acknowledgements_enabled: true,
query_interval: NonZeroU8::new(10).unwrap(),
retry_limit: NonZeroU8::new(30).unwrap(),
max_pending_acks: NonZeroU64::new(1_000_000).unwrap(),
inner: Default::default(),
}
}
}
#[derive(Deserialize, Serialize, Eq, PartialEq, Debug)]
pub struct HecAckStatusRequest {
pub acks: Vec<u64>,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct HecAckStatusResponse {
pub acks: HashMap<u64, bool>,
}
#[derive(Debug)]
pub enum HecAckApiError {
ClientBuildRequest,
ClientParseResponse,
ClientSendQuery,
ServerSendQuery,
}
struct HecAckClient {
acks: HashMap<u64, (u8, Sender<EventStatus>)>,
retry_limit: u8,
client: HttpClient,
http_request_builder: Arc<HttpRequestBuilder>,
}
impl HecAckClient {
fn new(
retry_limit: u8,
client: HttpClient,
http_request_builder: Arc<HttpRequestBuilder>,
) -> Self {
Self {
acks: HashMap::new(),
retry_limit,
client,
http_request_builder,
}
}
fn add(&mut self, ack_id: u64, ack_event_status_sender: Sender<EventStatus>) {
self.acks
.insert(ack_id, (self.retry_limit, ack_event_status_sender));
emit!(SplunkIndexerAcknowledgementAckAdded);
}
async fn run(&mut self) {
let ack_query_body = self.get_ack_query_body();
if !ack_query_body.acks.is_empty() {
let ack_query_response = self.send_ack_query_request(&ack_query_body).await;
match ack_query_response {
Ok(ack_query_response) => {
debug!(message = "Received ack statuses.", ?ack_query_response);
let acked_ack_ids = ack_query_response
.acks
.iter()
.filter(|&(_ack_id, ack_status)| *ack_status)
.map(|(ack_id, _ack_status)| *ack_id)
.collect::<Vec<u64>>();
self.finalize_delivered_ack_ids(acked_ack_ids.as_slice());
self.expire_ack_ids_with_status(EventStatus::Rejected);
}
Err(error) => {
match error {
HecAckApiError::ClientParseResponse | HecAckApiError::ClientSendQuery => {
emit!(SplunkIndexerAcknowledgementAPIError {
message: "Unable to use indexer acknowledgements. Acknowledging based on initial 200 OK.",
error,
});
self.finalize_delivered_ack_ids(
self.acks.keys().copied().collect::<Vec<_>>().as_slice(),
);
}
_ => {
emit!(SplunkIndexerAcknowledgementAPIError {
message:
"Unable to send acknowledgement query request. Will retry.",
error,
});
self.expire_ack_ids_with_status(EventStatus::Errored);
}
}
}
};
}
}
fn finalize_delivered_ack_ids(&mut self, ack_ids: &[u64]) {
let mut removed_count = 0.0;
for ack_id in ack_ids {
if let Some((_, ack_event_status_sender)) = self.acks.remove(ack_id) {
_ = ack_event_status_sender.send(EventStatus::Delivered);
removed_count += 1.0;
debug!(message = "Finalized ack id.", ?ack_id);
}
}
emit!(SplunkIndexerAcknowledgementAcksRemoved {
count: removed_count
});
}
fn get_ack_query_body(&mut self) -> HecAckStatusRequest {
HecAckStatusRequest {
acks: self.acks.keys().copied().collect::<Vec<u64>>(),
}
}
fn decrement_retries(&mut self) {
for (retries, _) in self.acks.values_mut() {
*retries = retries.checked_sub(1).unwrap_or(0);
}
}
fn expire_ack_ids_with_status(&mut self, status: EventStatus) {
let expired_ack_ids = self
.acks
.iter()
.filter_map(|(ack_id, (retries, _))| (*retries == 0).then_some(*ack_id))
.collect::<Vec<_>>();
let mut removed_count = 0.0;
for ack_id in expired_ack_ids {
if let Some((_, ack_event_status_sender)) = self.acks.remove(&ack_id) {
_ = ack_event_status_sender.send(status);
removed_count += 1.0;
}
}
emit!(SplunkIndexerAcknowledgementAcksRemoved {
count: removed_count
});
}
async fn send_ack_query_request(
&mut self,
request_body: &HecAckStatusRequest,
) -> Result<HecAckStatusResponse, HecAckApiError> {
self.decrement_retries();
let request_body_bytes = crate::serde::json::to_bytes(request_body)
.map_err(|_| HecAckApiError::ClientBuildRequest)?
.freeze();
let request = self
.http_request_builder
.build_request(
request_body_bytes,
"/services/collector/ack",
None,
MetadataFields::default(),
false,
)
.map_err(|_| HecAckApiError::ClientBuildRequest)?;
let response = self
.client
.send(request.map(Body::from))
.await
.map_err(|_| HecAckApiError::ServerSendQuery)?;
let status = response.status();
if status.is_success() {
let response_body = hyper::body::to_bytes(response.into_body())
.await
.map_err(|_| HecAckApiError::ClientParseResponse)?;
serde_json::from_slice::<HecAckStatusResponse>(&response_body)
.map_err(|_| HecAckApiError::ClientParseResponse)
} else if status.is_client_error() {
Err(HecAckApiError::ClientSendQuery)
} else {
Err(HecAckApiError::ServerSendQuery)
}
}
}
pub async fn run_acknowledgements(
mut receiver: Receiver<(u64, Sender<EventStatus>)>,
client: HttpClient,
http_request_builder: Arc<HttpRequestBuilder>,
indexer_acknowledgements: HecClientAcknowledgementsConfig,
) {
let mut interval = tokio::time::interval(Duration::from_secs(
indexer_acknowledgements.query_interval.get() as u64,
));
let mut ack_client = HecAckClient::new(
indexer_acknowledgements.retry_limit.get(),
client,
http_request_builder,
);
loop {
tokio::select! {
_ = interval.tick() => {
ack_client.run().await;
},
ack_info = receiver.recv() => {
match ack_info {
Some((ack_id, tx)) => {
ack_client.add(ack_id, tx);
debug!(message = "Stored ack id.", ?ack_id);
},
None => break,
}
}
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use futures_util::{stream::FuturesUnordered, StreamExt};
use tokio::sync::oneshot::{self, Receiver};
use vector_lib::{config::proxy::ProxyConfig, event::EventStatus};
use super::HecAckClient;
use crate::{
http::HttpClient,
sinks::{
splunk_hec::common::{
acknowledgements::HecAckStatusRequest, service::HttpRequestBuilder, EndpointTarget,
},
util::Compression,
},
};
fn get_ack_client(retry_limit: u8) -> HecAckClient {
let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();
let http_request_builder = HttpRequestBuilder::new(
String::from(""),
EndpointTarget::default(),
String::from(""),
Compression::default(),
);
HecAckClient::new(retry_limit, client, Arc::new(http_request_builder))
}
fn populate_ack_client(
ack_client: &mut HecAckClient,
ack_ids: &[u64],
) -> Vec<Receiver<EventStatus>> {
let mut ack_status_rxs = Vec::new();
for ack_id in ack_ids {
let (tx, rx) = oneshot::channel();
ack_client.add(*ack_id, tx);
ack_status_rxs.push(rx);
}
ack_status_rxs
}
#[test]
fn test_get_ack_query_body() {
let mut ack_client = get_ack_client(1);
let ack_ids = (0..100).collect::<Vec<u64>>();
_ = populate_ack_client(&mut ack_client, &ack_ids);
let expected_ack_body = HecAckStatusRequest { acks: ack_ids };
let mut ack_request_body = ack_client.get_ack_query_body();
ack_request_body.acks.sort_unstable();
assert_eq!(expected_ack_body, ack_request_body);
}
#[test]
fn test_decrement_retries() {
let mut ack_client = get_ack_client(1);
let ack_ids = (0..100).collect::<Vec<u64>>();
_ = populate_ack_client(&mut ack_client, &ack_ids);
let mut ack_request_body = ack_client.get_ack_query_body();
ack_request_body.acks.sort_unstable();
assert_eq!(ack_ids, ack_request_body.acks);
ack_client.decrement_retries();
ack_client.expire_ack_ids_with_status(EventStatus::Rejected);
let ack_request_body = ack_client.get_ack_query_body();
assert!(ack_request_body.acks.is_empty())
}
#[tokio::test]
async fn test_finalize_delivered_ack_ids() {
let mut ack_client = get_ack_client(1);
let ack_ids = (0..100).collect::<Vec<u64>>();
let ack_status_rxs = populate_ack_client(&mut ack_client, &ack_ids);
ack_client.finalize_delivered_ack_ids(ack_ids.as_slice());
let mut statuses = ack_status_rxs.into_iter().collect::<FuturesUnordered<_>>();
while let Some(status) = statuses.next().await {
assert_eq!(EventStatus::Delivered, status.unwrap());
}
}
#[tokio::test]
async fn test_expire_ack_ids_with_status() {
let mut ack_client = get_ack_client(1);
let ack_ids = (0..100).collect::<Vec<u64>>();
let ack_status_rxs = populate_ack_client(&mut ack_client, &ack_ids);
ack_client.decrement_retries();
ack_client.expire_ack_ids_with_status(EventStatus::Rejected);
let mut statuses = ack_status_rxs.into_iter().collect::<FuturesUnordered<_>>();
while let Some(status) = statuses.next().await {
assert_eq!(EventStatus::Rejected, status.unwrap());
}
}
}