1use std::{
2 fmt,
3 sync::Arc,
4 task::{Context, Poll, ready},
5};
6
7use bytes::Bytes;
8use futures_util::future::BoxFuture;
9use http::Request;
10use serde::{Deserialize, Serialize};
11use snafu::ResultExt;
12use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot};
13use tokio_util::sync::PollSemaphore;
14use tower::Service;
15use uuid::Uuid;
16use vector_lib::{event::EventStatus, request_metadata::MetaDescriptive};
17
18use super::{
19 EndpointTarget,
20 acknowledgements::{HecClientAcknowledgementsConfig, run_acknowledgements},
21};
22use crate::{
23 http::HttpClient,
24 internal_events::{SplunkIndexerAcknowledgementUnavailableError, SplunkResponseParseError},
25 sinks::{
26 UriParseSnafu,
27 splunk_hec::common::{build_uri, request::HecRequest, response::HecResponse},
28 util::{Compression, sink::Response},
29 },
30};
31
32pub struct HecService<S> {
33 pub inner: S,
34 ack_finalizer_tx: Option<mpsc::Sender<(u64, oneshot::Sender<EventStatus>)>>,
35 ack_slots: PollSemaphore,
36 current_ack_slot: Option<OwnedSemaphorePermit>,
37}
38
39#[derive(Deserialize, Serialize, Debug)]
40struct HecAckResponseBody {
41 #[serde(alias = "ackId")]
42 ack_id: Option<u64>,
43}
44
45impl<S> HecService<S>
46where
47 S: Service<HecRequest> + Send + 'static,
48 S::Future: Send + 'static,
49 S::Response: Response + ResponseExt + Send + 'static,
50 S::Error: fmt::Debug + Into<crate::Error> + Send,
51{
52 pub fn new(
53 inner: S,
54 ack_client: Option<HttpClient>,
55 http_request_builder: Arc<HttpRequestBuilder>,
56 indexer_acknowledgements: HecClientAcknowledgementsConfig,
57 ) -> Self {
58 let max_pending_acks = indexer_acknowledgements.max_pending_acks.get();
59 let tx = if let Some(ack_client) = ack_client {
60 let (tx, rx) = mpsc::channel(128);
61 tokio::spawn(run_acknowledgements(
62 rx,
63 ack_client,
64 Arc::clone(&http_request_builder),
65 indexer_acknowledgements,
66 ));
67 Some(tx)
68 } else {
69 None
70 };
71
72 let ack_slots = PollSemaphore::new(Arc::new(Semaphore::new(max_pending_acks as usize)));
73 Self {
74 inner,
75 ack_finalizer_tx: tx,
76 ack_slots,
77 current_ack_slot: None,
78 }
79 }
80}
81
82impl<S> Service<HecRequest> for HecService<S>
83where
84 S: Service<HecRequest> + Send + 'static,
85 S::Future: Send + 'static,
86 S::Response: Response + ResponseExt + Send + 'static,
87 S::Error: fmt::Debug + Into<crate::Error> + Send,
88{
89 type Response = HecResponse;
90 type Error = crate::Error;
91 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
92
93 fn poll_ready(&mut self, cx: &mut Context) -> std::task::Poll<Result<(), Self::Error>> {
94 if self.ack_finalizer_tx.is_none() || self.current_ack_slot.is_some() {
97 self.inner.poll_ready(cx).map_err(Into::into)
98 } else {
99 match ready!(self.ack_slots.poll_acquire(cx)) {
100 Some(permit) => {
101 self.current_ack_slot.replace(permit);
102 self.inner.poll_ready(cx).map_err(Into::into)
103 }
104 None => Poll::Ready(Err(
105 "Indexer acknowledgements semaphore unexpectedly closed".into(),
106 )),
107 }
108 }
109 }
110
111 fn call(&mut self, mut req: HecRequest) -> Self::Future {
112 let ack_finalizer_tx = self.ack_finalizer_tx.clone();
113 let ack_slot = self.current_ack_slot.take();
114
115 let metadata = std::mem::take(req.metadata_mut());
116 let events_count = metadata.event_count();
117 let events_byte_size = metadata.into_events_estimated_json_encoded_byte_size();
118 let response = self.inner.call(req);
119
120 Box::pin(async move {
121 let response = response.await.map_err(Into::into)?;
122 let event_status = if response.is_successful() {
123 if let Some(ack_finalizer_tx) = ack_finalizer_tx {
124 let _ack_slot = ack_slot.expect("poll_ready not called before invoking call");
125 let body = serde_json::from_slice::<HecAckResponseBody>(response.body());
126 match body {
127 Ok(body) => {
128 if let Some(ack_id) = body.ack_id {
129 let (tx, rx) = oneshot::channel();
130 match ack_finalizer_tx.send((ack_id, tx)).await {
131 Ok(_) => rx.await.unwrap_or(EventStatus::Rejected),
132 Err(error) => {
134 emit!(SplunkIndexerAcknowledgementUnavailableError {
135 error
136 });
137 EventStatus::Delivered
138 }
139 }
140 } else {
141 EventStatus::Delivered
143 }
144 }
145 Err(error) => {
146 emit!(SplunkResponseParseError { error });
148 EventStatus::Delivered
149 }
150 }
151 } else {
152 EventStatus::Delivered
154 }
155 } else if response.is_transient() {
156 EventStatus::Errored
157 } else {
158 EventStatus::Rejected
159 };
160
161 Ok(HecResponse {
162 event_status,
163 events_count,
164 events_byte_size,
165 })
166 })
167 }
168}
169
170pub trait ResponseExt {
171 fn body(&self) -> &Bytes;
172}
173
174impl ResponseExt for http::Response<Bytes> {
175 fn body(&self) -> &Bytes {
176 self.body()
177 }
178}
179
180#[derive(Clone)]
181pub struct HttpRequestBuilder {
182 pub endpoint_target: EndpointTarget,
183 pub endpoint: String,
184 pub default_token: String,
185 pub compression: Compression,
186 pub channel: String,
189}
190
191#[derive(Default)]
192pub(super) struct MetadataFields {
193 pub(super) source: Option<String>,
194 pub(super) sourcetype: Option<String>,
195 pub(super) index: Option<String>,
196 pub(super) host: Option<String>,
197}
198
199impl HttpRequestBuilder {
200 pub fn new(
201 endpoint: String,
202 endpoint_target: EndpointTarget,
203 default_token: String,
204 compression: Compression,
205 ) -> Self {
206 let channel = Uuid::new_v4().hyphenated().to_string();
207 Self {
208 endpoint,
209 endpoint_target,
210 default_token,
211 compression,
212 channel,
213 }
214 }
215
216 pub(super) fn build_request(
217 &self,
218 body: Bytes,
219 path: &str,
220 passthrough_token: Option<Arc<str>>,
221 metadata_fields: MetadataFields,
222 auto_extract_timestamp: bool,
223 ) -> Result<Request<Bytes>, crate::Error> {
224 let uri = match self.endpoint_target {
225 EndpointTarget::Raw => {
226 let metadata = [
229 (super::SOURCE_FIELD, metadata_fields.source),
230 (super::SOURCETYPE_FIELD, metadata_fields.sourcetype),
231 (super::INDEX_FIELD, metadata_fields.index),
232 (super::HOST_FIELD, metadata_fields.host),
233 ]
234 .into_iter()
235 .filter_map(|(key, value)| value.map(|value| (key, value)));
236 build_uri(self.endpoint.as_str(), path, metadata).context(UriParseSnafu)?
237 }
238 EndpointTarget::Event => build_uri(
239 self.endpoint.as_str(),
240 path,
241 if auto_extract_timestamp {
242 Some((super::AUTO_EXTRACT_TIMESTAMP_FIELD, "true".to_string()))
243 } else {
244 None
245 },
246 )
247 .context(UriParseSnafu)?,
248 };
249
250 let mut builder = Request::post(uri)
251 .header("Content-Type", "application/json")
252 .header(
253 "Authorization",
254 format!(
255 "Splunk {}",
256 passthrough_token.unwrap_or_else(|| self.default_token.as_str().into())
257 ),
258 )
259 .header("X-Splunk-Request-Channel", self.channel.as_str());
260
261 if let Some(ce) = self.compression.content_encoding() {
262 builder = builder.header("Content-Encoding", ce);
263 }
264
265 builder.body(body).map_err(Into::into)
266 }
267}
268
//! Tests for `HecService`, driven against a wiremock server that emulates the HEC
//! event endpoint (`/services/collector/event`) and the ack endpoint
//! (`/services/collector/ack`).
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        future::poll_fn,
        num::{NonZeroU8, NonZeroU64, NonZeroUsize},
        sync::{
            Arc,
            atomic::{AtomicU64, Ordering},
        },
        task::Poll,
    };

    use bytes::Bytes;
    use futures_util::{StreamExt, future::BoxFuture, poll, stream::FuturesUnordered};
    use tower::{Service, ServiceExt};
    use vector_lib::{
        config::proxy::ProxyConfig,
        event::{EventFinalizers, EventStatus},
        internal_event::CountByteSize,
    };
    use wiremock::{
        Mock, MockServer, Request, Respond, ResponseTemplate,
        matchers::{header, header_exists, method, path},
    };

    use crate::{
        http::HttpClient,
        sinks::{
            splunk_hec::common::{
                EndpointTarget,
                acknowledgements::{
                    HecAckStatusRequest, HecAckStatusResponse, HecClientAcknowledgementsConfig,
                },
                build_http_batch_service,
                request::HecRequest,
                service::{HecAckResponseBody, HecService, HttpRequestBuilder},
            },
            util::{Compression, http::HttpBatchService, metadata::RequestMetadataBuilder},
        },
    };

    // Token the mock server requires in the `Authorization: Splunk <token>` header.
    const TOKEN: &str = "token";
    // Shared counter handing out ack ids from the mock event endpoint.
    static ACK_ID: AtomicU64 = AtomicU64::new(0);

    // Builds a `HecService` over an `HttpBatchService` that talks to `endpoint`.
    // An ack client is always passed (`Some(client)`), so whether acknowledgements run
    // depends on how `HecService::new` treats `acknowledgements_config`.
    // NOTE(review): the request builder gets `EndpointTarget::default()` while the batch
    // service gets `EndpointTarget::Event`; presumably the default is `Event` — confirm.
    fn get_hec_service(
        endpoint: String,
        acknowledgements_config: HecClientAcknowledgementsConfig,
    ) -> HecService<
        HttpBatchService<
            BoxFuture<'static, Result<http::Request<Bytes>, crate::Error>>,
            HecRequest,
        >,
    > {
        let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();
        let http_request_builder = Arc::new(HttpRequestBuilder::new(
            endpoint,
            EndpointTarget::default(),
            String::from(TOKEN),
            Compression::default(),
        ));
        let http_service = build_http_batch_service(
            client.clone(),
            Arc::clone(&http_request_builder),
            EndpointTarget::Event,
            false,
        );
        HecService::new(
            http_service,
            Some(client),
            http_request_builder,
            acknowledgements_config,
        )
    }

    // A one-event request whose body is "test-message", with no passthrough token and no
    // index/source/sourcetype/host overrides.
    fn get_hec_request() -> HecRequest {
        let body = Bytes::from("test-message");
        let events_byte_size = body.len();

        let builder = RequestMetadataBuilder::new(
            1,
            events_byte_size,
            CountByteSize(1, events_byte_size.into()).into(),
        );
        let bytes_len =
            NonZeroUsize::new(events_byte_size).expect("payload should never be zero length");
        let metadata = builder.with_request_size(bytes_len);

        HecRequest {
            body,
            metadata,
            finalizers: EventFinalizers::default(),
            passthrough_token: None,
            index: None,
            source: None,
            sourcetype: None,
            host: None,
        }
    }

    // Starts a mock HEC server.
    // * The event endpoint replies 200 with a fresh ack id when
    //   `acknowledgements_enabled`, else with `ack_id: None`.
    // * The ack endpoint replies using `ack_response`.
    // Both endpoints require the expected token and a request-channel header.
    async fn get_hec_mock_server<R>(acknowledgements_enabled: bool, ack_response: R) -> MockServer
    where
        R: Respond + 'static,
    {
        let mock_server = MockServer::start().await;

        // Event endpoint: hands out increasing ack ids (or none).
        Mock::given(method("POST"))
            .and(path("/services/collector/event"))
            .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
            .and(header_exists("X-Splunk-Request-Channel"))
            .respond_with(move |_: &Request| {
                let ack_id =
                    acknowledgements_enabled.then(|| ACK_ID.fetch_add(1, Ordering::Relaxed));
                ResponseTemplate::new(200).set_body_json(HecAckResponseBody { ack_id })
            })
            .mount(&mock_server)
            .await;

        // Ack status endpoint: behaviour supplied by the individual test.
        Mock::given(method("POST"))
            .and(path("/services/collector/ack"))
            .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
            .and(header_exists("X-Splunk-Request-Channel"))
            .respond_with(ack_response)
            .mount(&mock_server)
            .await;

        mock_server
    }

    // Ack endpoint responder that reports every queried ack id as acknowledged.
    fn ack_response_always_succeed(req: &Request) -> ResponseTemplate {
        let req = serde_json::from_slice::<HecAckStatusRequest>(req.body.as_slice()).unwrap();
        ResponseTemplate::new(200).set_body_json(HecAckStatusResponse {
            acks: req
                .acks
                .into_iter()
                .map(|ack_id| (ack_id, true))
                .collect::<HashMap<_, _>>(),
        })
    }

    // Ack endpoint responder that reports every queried ack id as not yet acknowledged.
    fn ack_response_always_fail(req: &Request) -> ResponseTemplate {
        let req = serde_json::from_slice::<HecAckStatusRequest>(req.body.as_slice()).unwrap();
        ResponseTemplate::new(200).set_body_json(HecAckStatusResponse {
            acks: req
                .acks
                .into_iter()
                .map(|ack_id| (ack_id, false))
                .collect::<HashMap<_, _>>(),
        })
    }

    // Config disables indexer acknowledgements; the request must end up Delivered.
    // NOTE(review): `get_hec_service` still passes an ack client, and the mock returns an
    // ack id that the ack endpoint always confirms, so `Delivered` is expected whether or
    // not the disabled path is actually taken.
    #[tokio::test]
    async fn acknowledgements_disabled_in_config() {
        let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            indexer_acknowledgements_enabled: false,
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Delivered, response.event_status)
    }

    // Several concurrent requests, each acknowledged by the server, all end Delivered.
    #[tokio::test]
    async fn acknowledgements_enabled_on_server() {
        let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let mut responses = FuturesUnordered::new();
        responses.push(service.ready().await.unwrap().call(get_hec_request()));
        responses.push(service.ready().await.unwrap().call(get_hec_request()));
        responses.push(service.ready().await.unwrap().call(get_hec_request()));
        while let Some(response) = responses.next().await {
            assert_eq!(EventStatus::Delivered, response.unwrap().event_status)
        }
    }

    // The event endpoint returns no ack id, so no acknowledgement is awaited and the
    // request is Delivered (the failing ack endpoint is never consulted).
    #[tokio::test]
    async fn acknowledgements_disabled_on_server() {
        let ack_response = |_: &Request| ResponseTemplate::new(400);
        let mock_server = get_hec_mock_server(false, ack_response).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Delivered, response.event_status)
    }

    // The ack endpoint never confirms the id; once the retry limit is exhausted the
    // acknowledgement task is expected to report Rejected.
    #[tokio::test]
    async fn acknowledgements_enabled_on_server_retry_limit_exceeded() {
        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            retry_limit: NonZeroU8::new(1).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Rejected, response.event_status)
    }

    // The ack endpoint answers with an unexpected body (a JSON *string*, not an object);
    // the outcome (Delivered) is decided by the acknowledgement task, not shown here.
    #[tokio::test]
    async fn acknowledgements_server_changed_ack_response_format() {
        let ack_response = |_: &Request| {
            ResponseTemplate::new(200)
                .set_body_json(serde_json::json!(r#"{ "new": "a new response body" }"#))
        };
        let mock_server = get_hec_mock_server(true, ack_response).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            retry_limit: NonZeroU8::new(3).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Delivered, response.event_status)
    }

    // The ack endpoint keeps returning 503; the acknowledgement task is expected to
    // report Errored after its retries.
    #[tokio::test]
    async fn acknowledgements_enabled_on_server_ack_endpoint_failing() {
        let ack_response = |_: &Request| ResponseTemplate::new(503);
        let mock_server = get_hec_mock_server(true, ack_response).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            retry_limit: NonZeroU8::new(3).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Errored, response.event_status)
    }

    // Intended to check that an event response without a parseable ack id is still
    // reported as Delivered.
    // NOTE(review): `get_hec_mock_server` has already mounted a matcher for the same
    // event path; if wiremock prefers the earlier-mounted mock on equal priority, this
    // override is never hit and the test passes via the normal ack flow — confirm.
    #[tokio::test]
    async fn acknowledgements_server_changed_event_response_format() {
        let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;
        Mock::given(method("POST"))
            .and(path("/services/collector/event"))
            .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
            .and(header_exists("X-Splunk-Request-Channel"))
            .respond_with(move |_: &Request| {
                ResponseTemplate::new(200).set_body_json(r#"{ "new": "a new response body" }"#)
            })
            .mount(&mock_server)
            .await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            retry_limit: NonZeroU8::new(1).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        let request = get_hec_request();
        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(EventStatus::Delivered, response.event_status)
    }

    // Polling readiness twice without calling must not block: the second poll reuses
    // the slot already held in `current_ack_slot` instead of acquiring another.
    #[tokio::test]
    async fn service_poll_ready_multiple_times() {
        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
        let mut service = get_hec_service(mock_server.uri(), Default::default());

        assert!(service.ready().await.is_ok());
        assert!(service.ready().await.is_ok());
    }

    // Calling without `poll_ready` leaves no reserved slot; when the response is
    // successful and an ack channel exists, the call future panics on the missing slot.
    #[tokio::test]
    #[should_panic]
    async fn service_call_without_poll_ready() {
        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
        let mut service = get_hec_service(mock_server.uri(), Default::default());

        _ = service.call(get_hec_request()).await;
    }

    // With `max_pending_acks = 1`, a second readiness poll must stay Pending while the
    // first request awaits its acknowledgement, and become Ready once it resolves.
    #[tokio::test]
    async fn acknowledgements_max_pending_acks_reached() {
        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;

        let acknowledgements_config = HecClientAcknowledgementsConfig {
            query_interval: NonZeroU8::new(1).unwrap(),
            retry_limit: NonZeroU8::new(5).unwrap(),
            max_pending_acks: NonZeroU64::new(1).unwrap(),
            ..Default::default()
        };
        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);

        // The first call takes the only ack slot.
        let pending_call = service.ready().await.unwrap().call(get_hec_request());
        // No slot is left, so readiness must not be granted yet.
        assert!(matches!(
            poll!(poll_fn(|cx| service.poll_ready(cx))),
            Poll::Pending
        ));
        let response = pending_call.await.unwrap();
        // The ack was never confirmed, so the request is rejected and its slot released.
        assert_eq!(EventStatus::Rejected, response.event_status);
        // With the slot released, the service is ready again.
        assert!(matches!(
            poll!(poll_fn(|cx| service.poll_ready(cx))),
            Poll::Ready(Ok(_))
        ));
    }
}