1use std::{
2 fmt,
3 sync::Arc,
4 task::{Context, Poll, ready},
5};
6
7use bytes::Bytes;
8use futures_util::future::BoxFuture;
9use http::Request;
10use serde::{Deserialize, Serialize};
11use snafu::ResultExt;
12use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot};
13use tokio_util::sync::PollSemaphore;
14use tower::Service;
15use uuid::Uuid;
16use vector_lib::{event::EventStatus, request_metadata::MetaDescriptive};
17
18use super::{
19 EndpointTarget,
20 acknowledgements::{HecClientAcknowledgementsConfig, run_acknowledgements},
21};
22use crate::{
23 http::HttpClient,
24 internal_events::{SplunkIndexerAcknowledgementUnavailableError, SplunkResponseParseError},
25 sinks::{
26 UriParseSnafu,
27 splunk_hec::common::{build_uri, request::HecRequest, response::HecResponse},
28 util::{Compression, sink::Response},
29 },
30};
31
32pub struct HecService<S> {
33 pub inner: S,
34 ack_finalizer_tx: Option<mpsc::Sender<(u64, oneshot::Sender<EventStatus>)>>,
35 ack_slots: PollSemaphore,
36 current_ack_slot: Option<OwnedSemaphorePermit>,
37}
38
39#[derive(Deserialize, Serialize, Debug)]
40struct HecAckResponseBody {
41 #[serde(alias = "ackId")]
42 ack_id: Option<u64>,
43}
44
45impl<S> HecService<S>
46where
47 S: Service<HecRequest> + Send + 'static,
48 S::Future: Send + 'static,
49 S::Response: Response + ResponseExt + Send + 'static,
50 S::Error: fmt::Debug + Into<crate::Error> + Send,
51{
52 pub fn new(
53 inner: S,
54 ack_client: Option<HttpClient>,
55 http_request_builder: Arc<HttpRequestBuilder>,
56 indexer_acknowledgements: HecClientAcknowledgementsConfig,
57 ) -> Self {
58 let max_pending_acks = indexer_acknowledgements.max_pending_acks.get();
59 let tx = if let Some(ack_client) = ack_client {
60 let (tx, rx) = mpsc::channel(128);
61 tokio::spawn(run_acknowledgements(
62 rx,
63 ack_client,
64 Arc::clone(&http_request_builder),
65 indexer_acknowledgements,
66 ));
67 Some(tx)
68 } else {
69 None
70 };
71
72 let ack_slots = PollSemaphore::new(Arc::new(Semaphore::new(max_pending_acks as usize)));
73 Self {
74 inner,
75 ack_finalizer_tx: tx,
76 ack_slots,
77 current_ack_slot: None,
78 }
79 }
80}
81
82impl<S> Service<HecRequest> for HecService<S>
83where
84 S: Service<HecRequest> + Send + 'static,
85 S::Future: Send + 'static,
86 S::Response: Response + ResponseExt + Send + 'static,
87 S::Error: fmt::Debug + Into<crate::Error> + Send,
88{
89 type Response = HecResponse;
90 type Error = crate::Error;
91 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
92
93 fn poll_ready(&mut self, cx: &mut Context) -> std::task::Poll<Result<(), Self::Error>> {
94 if self.ack_finalizer_tx.is_none() || self.current_ack_slot.is_some() {
97 self.inner.poll_ready(cx).map_err(Into::into)
98 } else {
99 match ready!(self.ack_slots.poll_acquire(cx)) {
100 Some(permit) => {
101 self.current_ack_slot.replace(permit);
102 self.inner.poll_ready(cx).map_err(Into::into)
103 }
104 None => Poll::Ready(Err(
105 "Indexer acknowledgements semaphore unexpectedly closed".into(),
106 )),
107 }
108 }
109 }
110
111 fn call(&mut self, mut req: HecRequest) -> Self::Future {
112 let ack_finalizer_tx = self.ack_finalizer_tx.clone();
113 let ack_slot = self.current_ack_slot.take();
114
115 let metadata = std::mem::take(req.metadata_mut());
116 let events_count = metadata.event_count();
117 let events_byte_size = metadata.into_events_estimated_json_encoded_byte_size();
118 let response = self.inner.call(req);
119
120 Box::pin(async move {
121 let response = response.await.map_err(Into::into)?;
122 let event_status = if response.is_successful() {
123 if let Some(ack_finalizer_tx) = ack_finalizer_tx {
124 let _ack_slot = ack_slot.expect("poll_ready not called before invoking call");
125 let body = serde_json::from_slice::<HecAckResponseBody>(response.body());
126 match body {
127 Ok(body) => {
128 if let Some(ack_id) = body.ack_id {
129 let (tx, rx) = oneshot::channel();
130 match ack_finalizer_tx.send((ack_id, tx)).await {
131 Ok(_) => rx.await.unwrap_or(EventStatus::Rejected),
132 Err(error) => {
134 emit!(SplunkIndexerAcknowledgementUnavailableError {
135 error
136 });
137 EventStatus::Delivered
138 }
139 }
140 } else {
141 EventStatus::Delivered
143 }
144 }
145 Err(error) => {
146 emit!(SplunkResponseParseError { error });
148 EventStatus::Delivered
149 }
150 }
151 } else {
152 EventStatus::Delivered
154 }
155 } else if response.is_transient() {
156 EventStatus::Errored
157 } else {
158 EventStatus::Rejected
159 };
160
161 Ok(HecResponse {
162 event_status,
163 events_count,
164 events_byte_size,
165 })
166 })
167 }
168}
169
170pub trait ResponseExt {
171 fn body(&self) -> &Bytes;
172}
173
174impl ResponseExt for http::Response<Bytes> {
175 fn body(&self) -> &Bytes {
176 self.body()
177 }
178}
179
180#[derive(Clone)]
181pub struct HttpRequestBuilder {
182 pub endpoint_target: EndpointTarget,
183 pub endpoint: String,
184 pub default_token: String,
185 pub compression: Compression,
186 pub channel: String,
189}
190
191#[derive(Default)]
192pub(super) struct MetadataFields {
193 pub(super) source: Option<String>,
194 pub(super) sourcetype: Option<String>,
195 pub(super) index: Option<String>,
196 pub(super) host: Option<String>,
197}
198
199impl HttpRequestBuilder {
200 pub fn new(
201 endpoint: String,
202 endpoint_target: EndpointTarget,
203 default_token: String,
204 compression: Compression,
205 ) -> Self {
206 let channel = Uuid::new_v4().hyphenated().to_string();
207 Self {
208 endpoint,
209 endpoint_target,
210 default_token,
211 compression,
212 channel,
213 }
214 }
215
216 pub(super) fn build_request(
217 &self,
218 body: Bytes,
219 path: &str,
220 passthrough_token: Option<Arc<str>>,
221 metadata_fields: MetadataFields,
222 auto_extract_timestamp: bool,
223 ) -> Result<Request<Bytes>, crate::Error> {
224 let uri = match self.endpoint_target {
225 EndpointTarget::Raw => {
226 let metadata = [
229 (super::SOURCE_FIELD, metadata_fields.source),
230 (super::SOURCETYPE_FIELD, metadata_fields.sourcetype),
231 (super::INDEX_FIELD, metadata_fields.index),
232 (super::HOST_FIELD, metadata_fields.host),
233 ]
234 .into_iter()
235 .filter_map(|(key, value)| value.map(|value| (key, value)));
236 build_uri(self.endpoint.as_str(), path, metadata).context(UriParseSnafu)?
237 }
238 EndpointTarget::Event => build_uri(
239 self.endpoint.as_str(),
240 path,
241 if auto_extract_timestamp {
242 Some((super::AUTO_EXTRACT_TIMESTAMP_FIELD, "true".to_string()))
243 } else {
244 None
245 },
246 )
247 .context(UriParseSnafu)?,
248 };
249
250 let mut builder = Request::post(uri)
251 .header("Content-Type", "application/json")
252 .header(
253 "Authorization",
254 format!(
255 "Splunk {}",
256 passthrough_token.unwrap_or_else(|| self.default_token.as_str().into())
257 ),
258 )
259 .header("X-Splunk-Request-Channel", self.channel.as_str());
260
261 if let Some(ce) = self.compression.content_encoding() {
262 builder = builder.header("Content-Encoding", ce);
263 }
264
265 builder.body(body).map_err(Into::into)
266 }
267}
268
#[cfg(test)]
mod tests {
    //! Tests for `HecService` against a wiremock server that imitates the HEC event and
    //! ack endpoints.

    use std::{
        collections::HashMap,
        future::poll_fn,
        num::{NonZeroU8, NonZeroU64, NonZeroUsize},
        sync::{
            Arc,
            atomic::{AtomicU64, Ordering},
        },
        task::Poll,
    };

    use bytes::Bytes;
    use futures_util::{StreamExt, poll, stream::FuturesUnordered};
    use tower::{Service, ServiceExt, util::BoxService};
    use vector_lib::{
        config::proxy::ProxyConfig,
        event::{EventFinalizers, EventStatus},
        internal_event::CountByteSize,
    };
    use wiremock::{
        Mock, MockServer, Request, Respond, ResponseTemplate,
        matchers::{header, header_exists, method, path},
    };

    use crate::{
        http::HttpClient,
        sinks::{
            splunk_hec::common::{
                EndpointTarget,
                acknowledgements::{
                    HecAckStatusRequest, HecAckStatusResponse, HecClientAcknowledgementsConfig,
                },
                build_http_batch_service,
                request::HecRequest,
                service::{HecAckResponseBody, HecService, HttpRequestBuilder},
            },
            util::{Compression, metadata::RequestMetadataBuilder},
        },
    };

    /// Token the mock server requires in the `Authorization` header.
    const TOKEN: &str = "token";
    /// Counter the mock event endpoint draws ack ids from; it is shared by every test in
    /// this module, so each issued id is distinct.
    static ACK_ID: AtomicU64 = AtomicU64::new(0);

    /// Builds a `HecService` that targets `endpoint` with the given ack configuration.
    ///
    /// NOTE(review): this always passes `Some(client)` as the ack client, and
    /// `HecService::new` decides whether to track acks solely from that argument. The
    /// `indexer_acknowledgements_enabled` flag is only forwarded to `run_acknowledgements`.
    /// Confirm whether that task honors the flag.
    fn get_hec_service(
        endpoint: String,
        acknowledgements_config: HecClientAcknowledgementsConfig,
    ) -> HecService<BoxService<HecRequest, http::Response<Bytes>, crate::Error>> {
        let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();
        let http_request_builder = Arc::new(HttpRequestBuilder::new(
            endpoint,
            EndpointTarget::default(),
            String::from(TOKEN),
            Compression::default(),
        ));
        let http_service = build_http_batch_service(
            client.clone(),
            Arc::clone(&http_request_builder),
            EndpointTarget::Event,
            false,
        );
        HecService::new(
            BoxService::new(http_service),
            Some(client),
            http_request_builder,
            acknowledgements_config,
        )
    }

    /// Builds a single-event request with a fixed body and metadata sized to that body.
    fn get_hec_request() -> HecRequest {
        let body = Bytes::from("test-message");
        let events_byte_size = body.len();

        let builder = RequestMetadataBuilder::new(
            1,
            events_byte_size,
            CountByteSize(1, events_byte_size.into()).into(),
        );
        let bytes_len =
            NonZeroUsize::new(events_byte_size).expect("payload should never be zero length");
        let metadata = builder.with_request_size(bytes_len);

        HecRequest {
            body,
            metadata,
            finalizers: EventFinalizers::default(),
            passthrough_token: None,
            index: None,
            source: None,
            sourcetype: None,
            host: None,
        }
    }

    /// Starts a mock server with an event endpoint and an ack endpoint. Both require the
    /// test token and a request-channel header.
    ///
    /// The event endpoint returns a fresh ack id when `acknowledgements_enabled` is true,
    /// and `ack_id: None` otherwise. The ack endpoint is answered by `ack_response`.
    async fn get_hec_mock_server<R>(acknowledgements_enabled: bool, ack_response: R) -> MockServer
    where
        R: Respond + 'static,
    {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/services/collector/event"))
            .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
            .and(header_exists("X-Splunk-Request-Channel"))
            .respond_with(move |_: &Request| {
                let ack_id =
                    acknowledgements_enabled.then(|| ACK_ID.fetch_add(1, Ordering::Relaxed));
                ResponseTemplate::new(200).set_body_json(HecAckResponseBody { ack_id })
            })
            .mount(&mock_server)
            .await;

        Mock::given(method("POST"))
            .and(path("/services/collector/ack"))
            .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
            .and(header_exists("X-Splunk-Request-Channel"))
            .respond_with(ack_response)
            .mount(&mock_server)
            .await;

        mock_server
    }

    /// Ack-endpoint responder that reports every queried ack id as acknowledged.
    fn ack_response_always_succeed(req: &Request) -> ResponseTemplate {
        let req = serde_json::from_slice::<HecAckStatusRequest>(req.body.as_slice()).unwrap();
        ResponseTemplate::new(200).set_body_json(HecAckStatusResponse {
            acks: req
                .acks
                .into_iter()
                .map(|ack_id| (ack_id, true))
                .collect::<HashMap<_, _>>(),
        })
    }

    /// Ack-endpoint responder that reports every queried ack id as not yet acknowledged.
    fn ack_response_always_fail(req: &Request) -> ResponseTemplate {
        let req = serde_json::from_slice::<HecAckStatusRequest>(req.body.as_slice()).unwrap();
        ResponseTemplate::new(200).set_body_json(HecAckStatusResponse {
            acks: req
                .acks
                .into_iter()
                .map(|ack_id| (ack_id, false))
                .collect::<HashMap<_, _>>(),
        })
    }

    // NOTE(review): because `get_hec_service` always supplies an ack client, this may still
    // go through ack tracking; it passes either way since acks always succeed here.
    #[tokio::test]
    async fn acknowledgements_disabled_in_config() {
        let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            indexer_acknowledgements_enabled: false,
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Delivered, response.event_status)
    }

    // Several concurrent requests, each with its own ack id, all resolve to Delivered.
    #[tokio::test]
    async fn acknowledgements_enabled_on_server() {
        let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let mut responses = FuturesUnordered::new();
        responses.push(service.ready().await.unwrap().call(get_hec_request()));
        responses.push(service.ready().await.unwrap().call(get_hec_request()));
        responses.push(service.ready().await.unwrap().call(get_hec_request()));
        while let Some(response) = responses.next().await {
            assert_eq!(EventStatus::Delivered, response.unwrap().event_status)
        }
    }

    // The event response carries no ack id, so the request is Delivered without querying acks.
    #[tokio::test]
    async fn acknowledgements_disabled_on_server() {
        let ack_response = |_: &Request| ResponseTemplate::new(400);
        let mock_server = get_hec_mock_server(false, ack_response).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Delivered, response.event_status)
    }

    // Acks never succeed and the retry limit is 1, so the request ends up Rejected.
    #[tokio::test]
    async fn acknowledgements_enabled_on_server_retry_limit_exceeded() {
        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            retry_limit: NonZeroU8::new(1).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Rejected, response.event_status)
    }

    // An unrecognized ack-endpoint body is expected to yield Delivered.
    // NOTE(review): `json!(r#"…"#)` serializes a JSON *string*, not an object; it is still an
    // unexpected format, but probably not the shape the test name suggests.
    #[tokio::test]
    async fn acknowledgements_server_changed_ack_response_format() {
        let ack_response = |_: &Request| {
            ResponseTemplate::new(200)
                .set_body_json(serde_json::json!(r#"{ "new": "a new response body" }"#))
        };
        let mock_server = get_hec_mock_server(true, ack_response).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            retry_limit: NonZeroU8::new(3).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Delivered, response.event_status)
    }

    // An ack endpoint returning 503 is expected to yield Errored.
    #[tokio::test]
    async fn acknowledgements_enabled_on_server_ack_endpoint_failing() {
        let ack_response = |_: &Request| ResponseTemplate::new(503);
        let mock_server = get_hec_mock_server(true, ack_response).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            retry_limit: NonZeroU8::new(3).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Errored, response.event_status)
    }

    // Intends to check that an event response without an ack id is treated as Delivered.
    // NOTE(review): the event mock from `get_hec_mock_server` is mounted first with the same
    // matchers. If wiremock prefers the earliest-mounted match, this override never fires
    // and the test passes only because acks succeed; consider `with_priority`. Also,
    // `set_body_json` on a raw string produces a JSON string, not an object.
    #[tokio::test]
    async fn acknowledgements_server_changed_event_response_format() {
        let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;
        Mock::given(method("POST"))
            .and(path("/services/collector/event"))
            .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
            .and(header_exists("X-Splunk-Request-Channel"))
            .respond_with(move |_: &Request| {
                ResponseTemplate::new(200).set_body_json(r#"{ "new": "a new response body" }"#)
            })
            .mount(&mock_server)
            .await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            retry_limit: NonZeroU8::new(1).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Delivered, response.event_status)
    }

    // Polling readiness twice without a `call` must not try to reserve a second ack slot.
    #[tokio::test]
    async fn service_poll_ready_multiple_times() {
        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
        let mut service = get_hec_service(mock_server.uri(), Default::default());

        assert!(service.ready().await.is_ok());
        assert!(service.ready().await.is_ok());
    }

    // With acks enabled, `call` without a prior `poll_ready` panics once the (successful)
    // response is processed, because no ack slot was reserved.
    #[tokio::test]
    #[should_panic]
    async fn service_call_without_poll_ready() {
        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
        let mut service = get_hec_service(mock_server.uri(), Default::default());

        _ = service.call(get_hec_request()).await;
    }

    // With a single ack slot, readiness stays pending while one request awaits its ack and
    // becomes ready again once that request has resolved.
    #[tokio::test]
    async fn acknowledgements_max_pending_acks_reached() {
        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            retry_limit: NonZeroU8::new(5).unwrap(),
            max_pending_acks: NonZeroU64::new(1).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let pending_call = service.ready().await.unwrap().call(get_hec_request());
        assert!(matches!(
            poll!(poll_fn(|cx| service.poll_ready(cx))),
            Poll::Pending
        ));
        let response = pending_call.await.unwrap();
        assert_eq!(EventStatus::Rejected, response.event_status);
        assert!(matches!(
            poll!(poll_fn(|cx| service.poll_ready(cx))),
            Poll::Ready(Ok(_))
        ));
    }
}