1use std::{
2 fmt,
3 sync::Arc,
4 task::{ready, Context, Poll},
5};
6
7use bytes::Bytes;
8use futures_util::future::BoxFuture;
9use http::Request;
10use serde::{Deserialize, Serialize};
11use snafu::ResultExt;
12use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
13use tokio_util::sync::PollSemaphore;
14use tower::Service;
15use uuid::Uuid;
16use vector_lib::event::EventStatus;
17use vector_lib::request_metadata::MetaDescriptive;
18
19use super::{
20 acknowledgements::{run_acknowledgements, HecClientAcknowledgementsConfig},
21 EndpointTarget,
22};
23use crate::{
24 http::HttpClient,
25 internal_events::{SplunkIndexerAcknowledgementUnavailableError, SplunkResponseParseError},
26 sinks::{
27 splunk_hec::common::{build_uri, request::HecRequest, response::HecResponse},
28 util::{sink::Response, Compression},
29 UriParseSnafu,
30 },
31};
32
/// Service wrapper that forwards [`HecRequest`]s to `inner` and turns each
/// response into an [`EventStatus`], optionally waiting for a per-request
/// acknowledgement status before resolving.
pub struct HecService<S> {
    /// Wrapped service that performs the actual request.
    pub inner: S,
    /// Channel on which an ack id is registered together with a oneshot sender
    /// that later reports the final [`EventStatus`] for that id. `None` when
    /// the service was constructed without an ack client.
    ack_finalizer_tx: Option<mpsc::Sender<(u64, oneshot::Sender<EventStatus>)>>,
    /// Semaphore created with `max_pending_acks` permits; `poll_ready`
    /// acquires one permit per request while the ack channel is present.
    ack_slots: PollSemaphore,
    /// Permit acquired by `poll_ready` and taken by the next `call`.
    current_ack_slot: Option<OwnedSemaphorePermit>,
}
39
/// JSON body of an event response as far as this service reads it: only the
/// optional acknowledgement id is extracted.
#[derive(Deserialize, Serialize, Debug)]
struct HecAckResponseBody {
    /// Acknowledgement id, if the response carries one. When deserializing,
    /// the key `ackId` is accepted as an alias for `ack_id`.
    #[serde(alias = "ackId")]
    ack_id: Option<u64>,
}
45
46impl<S> HecService<S>
47where
48 S: Service<HecRequest> + Send + 'static,
49 S::Future: Send + 'static,
50 S::Response: Response + ResponseExt + Send + 'static,
51 S::Error: fmt::Debug + Into<crate::Error> + Send,
52{
53 pub fn new(
54 inner: S,
55 ack_client: Option<HttpClient>,
56 http_request_builder: Arc<HttpRequestBuilder>,
57 indexer_acknowledgements: HecClientAcknowledgementsConfig,
58 ) -> Self {
59 let max_pending_acks = indexer_acknowledgements.max_pending_acks.get();
60 let tx = if let Some(ack_client) = ack_client {
61 let (tx, rx) = mpsc::channel(128);
62 tokio::spawn(run_acknowledgements(
63 rx,
64 ack_client,
65 Arc::clone(&http_request_builder),
66 indexer_acknowledgements,
67 ));
68 Some(tx)
69 } else {
70 None
71 };
72
73 let ack_slots = PollSemaphore::new(Arc::new(Semaphore::new(max_pending_acks as usize)));
74 Self {
75 inner,
76 ack_finalizer_tx: tx,
77 ack_slots,
78 current_ack_slot: None,
79 }
80 }
81}
82
83impl<S> Service<HecRequest> for HecService<S>
84where
85 S: Service<HecRequest> + Send + 'static,
86 S::Future: Send + 'static,
87 S::Response: Response + ResponseExt + Send + 'static,
88 S::Error: fmt::Debug + Into<crate::Error> + Send,
89{
90 type Response = HecResponse;
91 type Error = crate::Error;
92 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
93
94 fn poll_ready(&mut self, cx: &mut Context) -> std::task::Poll<Result<(), Self::Error>> {
95 if self.ack_finalizer_tx.is_none() || self.current_ack_slot.is_some() {
98 self.inner.poll_ready(cx).map_err(Into::into)
99 } else {
100 match ready!(self.ack_slots.poll_acquire(cx)) {
101 Some(permit) => {
102 self.current_ack_slot.replace(permit);
103 self.inner.poll_ready(cx).map_err(Into::into)
104 }
105 None => Poll::Ready(Err(
106 "Indexer acknowledgements semaphore unexpectedly closed".into(),
107 )),
108 }
109 }
110 }
111
112 fn call(&mut self, mut req: HecRequest) -> Self::Future {
113 let ack_finalizer_tx = self.ack_finalizer_tx.clone();
114 let ack_slot = self.current_ack_slot.take();
115
116 let metadata = std::mem::take(req.metadata_mut());
117 let events_count = metadata.event_count();
118 let events_byte_size = metadata.into_events_estimated_json_encoded_byte_size();
119 let response = self.inner.call(req);
120
121 Box::pin(async move {
122 let response = response.await.map_err(Into::into)?;
123 let event_status = if response.is_successful() {
124 if let Some(ack_finalizer_tx) = ack_finalizer_tx {
125 let _ack_slot = ack_slot.expect("poll_ready not called before invoking call");
126 let body = serde_json::from_slice::<HecAckResponseBody>(response.body());
127 match body {
128 Ok(body) => {
129 if let Some(ack_id) = body.ack_id {
130 let (tx, rx) = oneshot::channel();
131 match ack_finalizer_tx.send((ack_id, tx)).await {
132 Ok(_) => rx.await.unwrap_or(EventStatus::Rejected),
133 Err(error) => {
135 emit!(SplunkIndexerAcknowledgementUnavailableError {
136 error
137 });
138 EventStatus::Delivered
139 }
140 }
141 } else {
142 EventStatus::Delivered
144 }
145 }
146 Err(error) => {
147 emit!(SplunkResponseParseError { error });
149 EventStatus::Delivered
150 }
151 }
152 } else {
153 EventStatus::Delivered
155 }
156 } else if response.is_transient() {
157 EventStatus::Errored
158 } else {
159 EventStatus::Rejected
160 };
161
162 Ok(HecResponse {
163 event_status,
164 events_count,
165 events_byte_size,
166 })
167 })
168 }
169}
170
/// Gives [`HecService`] access to the raw body bytes of a response.
pub trait ResponseExt {
    /// Returns a reference to the response body.
    fn body(&self) -> &Bytes;
}
174
175impl ResponseExt for http::Response<Bytes> {
176 fn body(&self) -> &Bytes {
177 self.body()
178 }
179}
180
/// Settings shared by every HTTP request built for one endpoint.
pub struct HttpRequestBuilder {
    /// Selects how `build_request` assembles the URI query parameters.
    pub endpoint_target: EndpointTarget,
    /// Endpoint string handed to `build_uri` together with the request path.
    pub endpoint: String,
    /// Token placed in the `Authorization` header when a request carries no
    /// passthrough token.
    pub default_token: String,
    /// Compression whose content encoding, if any, is sent as
    /// `Content-Encoding`.
    pub compression: Compression,
    /// Value of the `X-Splunk-Request-Channel` header; `new` sets it to a
    /// hyphenated UUID v4.
    pub channel: String,
}
190
/// Optional per-request metadata that `build_request` appends as URI query
/// parameters when the endpoint target is `EndpointTarget::Raw`; fields left
/// as `None` are omitted.
#[derive(Default)]
pub(super) struct MetadataFields {
    /// Value for the `SOURCE_FIELD` query parameter.
    pub(super) source: Option<String>,
    /// Value for the `SOURCETYPE_FIELD` query parameter.
    pub(super) sourcetype: Option<String>,
    /// Value for the `INDEX_FIELD` query parameter.
    pub(super) index: Option<String>,
    /// Value for the `HOST_FIELD` query parameter.
    pub(super) host: Option<String>,
}
198
199impl HttpRequestBuilder {
200 pub fn new(
201 endpoint: String,
202 endpoint_target: EndpointTarget,
203 default_token: String,
204 compression: Compression,
205 ) -> Self {
206 let channel = Uuid::new_v4().hyphenated().to_string();
207 Self {
208 endpoint,
209 endpoint_target,
210 default_token,
211 compression,
212 channel,
213 }
214 }
215
216 pub(super) fn build_request(
217 &self,
218 body: Bytes,
219 path: &str,
220 passthrough_token: Option<Arc<str>>,
221 metadata_fields: MetadataFields,
222 auto_extract_timestamp: bool,
223 ) -> Result<Request<Bytes>, crate::Error> {
224 let uri = match self.endpoint_target {
225 EndpointTarget::Raw => {
226 let metadata = [
229 (super::SOURCE_FIELD, metadata_fields.source),
230 (super::SOURCETYPE_FIELD, metadata_fields.sourcetype),
231 (super::INDEX_FIELD, metadata_fields.index),
232 (super::HOST_FIELD, metadata_fields.host),
233 ]
234 .into_iter()
235 .filter_map(|(key, value)| value.map(|value| (key, value)));
236 build_uri(self.endpoint.as_str(), path, metadata).context(UriParseSnafu)?
237 }
238 EndpointTarget::Event => build_uri(
239 self.endpoint.as_str(),
240 path,
241 if auto_extract_timestamp {
242 Some((super::AUTO_EXTRACT_TIMESTAMP_FIELD, "true".to_string()))
243 } else {
244 None
245 },
246 )
247 .context(UriParseSnafu)?,
248 };
249
250 let mut builder = Request::post(uri)
251 .header("Content-Type", "application/json")
252 .header(
253 "Authorization",
254 format!(
255 "Splunk {}",
256 passthrough_token.unwrap_or_else(|| self.default_token.as_str().into())
257 ),
258 )
259 .header("X-Splunk-Request-Channel", self.channel.as_str());
260
261 if let Some(ce) = self.compression.content_encoding() {
262 builder = builder.header("Content-Encoding", ce);
263 }
264
265 builder.body(body).map_err(Into::into)
266 }
267}
268
269#[cfg(test)]
270mod tests {
271 use std::{
272 collections::HashMap,
273 future::poll_fn,
274 num::{NonZeroU64, NonZeroU8, NonZeroUsize},
275 sync::{
276 atomic::{AtomicU64, Ordering},
277 Arc,
278 },
279 task::Poll,
280 };
281
282 use bytes::Bytes;
283 use futures_util::{poll, stream::FuturesUnordered, StreamExt};
284 use tower::{util::BoxService, Service, ServiceExt};
285 use vector_lib::internal_event::CountByteSize;
286 use vector_lib::{
287 config::proxy::ProxyConfig,
288 event::{EventFinalizers, EventStatus},
289 };
290 use wiremock::{
291 matchers::{header, header_exists, method, path},
292 Mock, MockServer, Request, Respond, ResponseTemplate,
293 };
294
295 use crate::{
296 http::HttpClient,
297 sinks::{
298 splunk_hec::common::{
299 acknowledgements::{
300 HecAckStatusRequest, HecAckStatusResponse, HecClientAcknowledgementsConfig,
301 },
302 build_http_batch_service,
303 request::HecRequest,
304 service::{HecAckResponseBody, HecService, HttpRequestBuilder},
305 EndpointTarget,
306 },
307 util::{metadata::RequestMetadataBuilder, Compression},
308 },
309 };
310
    /// Token the test service is configured with; the mock server only matches
    /// requests whose `Authorization` header is `Splunk <TOKEN>`.
    const TOKEN: &str = "token";
    /// Source of the ack ids handed out by the mock event endpoint; shared by
    /// all tests in this module.
    static ACK_ID: AtomicU64 = AtomicU64::new(0);
313
314 fn get_hec_service(
315 endpoint: String,
316 acknowledgements_config: HecClientAcknowledgementsConfig,
317 ) -> HecService<BoxService<HecRequest, http::Response<Bytes>, crate::Error>> {
318 let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();
319 let http_request_builder = Arc::new(HttpRequestBuilder::new(
320 endpoint,
321 EndpointTarget::default(),
322 String::from(TOKEN),
323 Compression::default(),
324 ));
325 let http_service = build_http_batch_service(
326 client.clone(),
327 Arc::clone(&http_request_builder),
328 EndpointTarget::Event,
329 false,
330 );
331 HecService::new(
332 BoxService::new(http_service),
333 Some(client),
334 http_request_builder,
335 acknowledgements_config,
336 )
337 }
338
339 fn get_hec_request() -> HecRequest {
340 let body = Bytes::from("test-message");
341 let events_byte_size = body.len();
342
343 let builder = RequestMetadataBuilder::new(
344 1,
345 events_byte_size,
346 CountByteSize(1, events_byte_size.into()).into(),
347 );
348 let bytes_len =
349 NonZeroUsize::new(events_byte_size).expect("payload should never be zero length");
350 let metadata = builder.with_request_size(bytes_len);
351
352 HecRequest {
353 body,
354 metadata,
355 finalizers: EventFinalizers::default(),
356 passthrough_token: None,
357 index: None,
358 source: None,
359 sourcetype: None,
360 host: None,
361 }
362 }
363
364 async fn get_hec_mock_server<R>(acknowledgements_enabled: bool, ack_response: R) -> MockServer
365 where
366 R: Respond + 'static,
367 {
368 let mock_server = MockServer::start().await;
370
371 Mock::given(method("POST"))
372 .and(path("/services/collector/event"))
373 .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
374 .and(header_exists("X-Splunk-Request-Channel"))
375 .respond_with(move |_: &Request| {
376 let ack_id =
377 acknowledgements_enabled.then(|| ACK_ID.fetch_add(1, Ordering::Relaxed));
378 ResponseTemplate::new(200).set_body_json(HecAckResponseBody { ack_id })
379 })
380 .mount(&mock_server)
381 .await;
382
383 Mock::given(method("POST"))
384 .and(path("/services/collector/ack"))
385 .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
386 .and(header_exists("X-Splunk-Request-Channel"))
387 .respond_with(ack_response)
388 .mount(&mock_server)
389 .await;
390
391 mock_server
392 }
393
394 fn ack_response_always_succeed(req: &Request) -> ResponseTemplate {
395 let req = serde_json::from_slice::<HecAckStatusRequest>(req.body.as_slice()).unwrap();
396 ResponseTemplate::new(200).set_body_json(HecAckStatusResponse {
397 acks: req
398 .acks
399 .into_iter()
400 .map(|ack_id| (ack_id, true))
401 .collect::<HashMap<_, _>>(),
402 })
403 }
404
405 fn ack_response_always_fail(req: &Request) -> ResponseTemplate {
406 let req = serde_json::from_slice::<HecAckStatusRequest>(req.body.as_slice()).unwrap();
407 ResponseTemplate::new(200).set_body_json(HecAckStatusResponse {
408 acks: req
409 .acks
410 .into_iter()
411 .map(|ack_id| (ack_id, false))
412 .collect::<HashMap<_, _>>(),
413 })
414 }
415
416 #[tokio::test]
417 async fn acknowledgements_disabled_in_config() {
418 let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;
419
420 let acknowledgements_config = HecClientAcknowledgementsConfig {
421 indexer_acknowledgements_enabled: false,
422 ..Default::default()
423 };
424 let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
425
426 let request = get_hec_request();
427 let response = service.ready().await.unwrap().call(request).await.unwrap();
428 assert_eq!(EventStatus::Delivered, response.event_status)
429 }
430
431 #[tokio::test]
432 async fn acknowledgements_enabled_on_server() {
433 let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;
434
435 let acknowledgements_config = HecClientAcknowledgementsConfig {
436 query_interval: NonZeroU8::new(1).unwrap(),
437 ..Default::default()
438 };
439 let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
440
441 let mut responses = FuturesUnordered::new();
442 responses.push(service.ready().await.unwrap().call(get_hec_request()));
443 responses.push(service.ready().await.unwrap().call(get_hec_request()));
444 responses.push(service.ready().await.unwrap().call(get_hec_request()));
445 while let Some(response) = responses.next().await {
446 assert_eq!(EventStatus::Delivered, response.unwrap().event_status)
447 }
448 }
449
450 #[tokio::test]
451 async fn acknowledgements_disabled_on_server() {
452 let ack_response = |_: &Request| ResponseTemplate::new(400);
453 let mock_server = get_hec_mock_server(false, ack_response).await;
454
455 let acknowledgements_config = HecClientAcknowledgementsConfig {
456 query_interval: NonZeroU8::new(1).unwrap(),
457 ..Default::default()
458 };
459 let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
460
461 let request = get_hec_request();
462 let response = service.ready().await.unwrap().call(request).await.unwrap();
463 assert_eq!(EventStatus::Delivered, response.event_status)
464 }
465
466 #[tokio::test]
467 async fn acknowledgements_enabled_on_server_retry_limit_exceeded() {
468 let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
469
470 let acknowledgements_config = HecClientAcknowledgementsConfig {
471 query_interval: NonZeroU8::new(1).unwrap(),
472 retry_limit: NonZeroU8::new(1).unwrap(),
473 ..Default::default()
474 };
475 let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
476
477 let request = get_hec_request();
478 let response = service.ready().await.unwrap().call(request).await.unwrap();
479 assert_eq!(EventStatus::Rejected, response.event_status)
480 }
481
482 #[tokio::test]
483 async fn acknowledgements_server_changed_ack_response_format() {
484 let ack_response = |_: &Request| {
485 ResponseTemplate::new(200)
486 .set_body_json(serde_json::json!(r#"{ "new": "a new response body" }"#))
487 };
488 let mock_server = get_hec_mock_server(true, ack_response).await;
489
490 let acknowledgements_config = HecClientAcknowledgementsConfig {
491 query_interval: NonZeroU8::new(1).unwrap(),
492 retry_limit: NonZeroU8::new(3).unwrap(),
493 ..Default::default()
494 };
495 let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
496
497 let request = get_hec_request();
498 let response = service.ready().await.unwrap().call(request).await.unwrap();
499 assert_eq!(EventStatus::Delivered, response.event_status)
500 }
501
502 #[tokio::test]
503 async fn acknowledgements_enabled_on_server_ack_endpoint_failing() {
504 let ack_response = |_: &Request| ResponseTemplate::new(503);
505 let mock_server = get_hec_mock_server(true, ack_response).await;
506
507 let acknowledgements_config = HecClientAcknowledgementsConfig {
508 query_interval: NonZeroU8::new(1).unwrap(),
509 retry_limit: NonZeroU8::new(3).unwrap(),
510 ..Default::default()
511 };
512 let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
513
514 let request = get_hec_request();
515 let response = service.ready().await.unwrap().call(request).await.unwrap();
516 assert_eq!(EventStatus::Errored, response.event_status)
517 }
518
519 #[tokio::test]
520 async fn acknowledgements_server_changed_event_response_format() {
521 let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;
522 Mock::given(method("POST"))
524 .and(path("/services/collector/event"))
525 .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
526 .and(header_exists("X-Splunk-Request-Channel"))
527 .respond_with(move |_: &Request| {
528 ResponseTemplate::new(200).set_body_json(r#"{ "new": "a new response body" }"#)
529 })
530 .mount(&mock_server)
531 .await;
532
533 let acknowledgements_config = HecClientAcknowledgementsConfig {
534 query_interval: NonZeroU8::new(1).unwrap(),
535 retry_limit: NonZeroU8::new(1).unwrap(),
536 ..Default::default()
537 };
538 let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
539
540 let request = get_hec_request();
541 let response = service.ready().await.unwrap().call(request).await.unwrap();
542 assert_eq!(EventStatus::Delivered, response.event_status)
543 }
544
545 #[tokio::test]
546 async fn service_poll_ready_multiple_times() {
547 let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
548 let mut service = get_hec_service(mock_server.uri(), Default::default());
549
550 assert!(service.ready().await.is_ok());
551 assert!(service.ready().await.is_ok());
554 }
555
556 #[tokio::test]
557 #[should_panic]
558 async fn service_call_without_poll_ready() {
559 let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
560 let mut service = get_hec_service(mock_server.uri(), Default::default());
561
562 _ = service.call(get_hec_request()).await;
563 }
564
565 #[tokio::test]
566 async fn acknowledgements_max_pending_acks_reached() {
567 let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
568
569 let acknowledgements_config = HecClientAcknowledgementsConfig {
570 query_interval: NonZeroU8::new(1).unwrap(),
571 retry_limit: NonZeroU8::new(5).unwrap(),
572 max_pending_acks: NonZeroU64::new(1).unwrap(),
574 ..Default::default()
575 };
576 let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
577
578 let pending_call = service.ready().await.unwrap().call(get_hec_request());
580 assert!(matches!(
582 poll!(poll_fn(|cx| service.poll_ready(cx))),
583 Poll::Pending
584 ));
585 let response = pending_call.await.unwrap();
587 assert_eq!(EventStatus::Rejected, response.event_status);
588 assert!(matches!(
590 poll!(poll_fn(|cx| service.poll_ready(cx))),
591 Poll::Ready(Ok(_))
592 ));
593 }
594}