vector/sinks/splunk_hec/common/
service.rs

1use std::{
2    fmt,
3    sync::Arc,
4    task::{Context, Poll, ready},
5};
6
7use bytes::Bytes;
8use futures_util::future::BoxFuture;
9use http::Request;
10use serde::{Deserialize, Serialize};
11use snafu::ResultExt;
12use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot};
13use tokio_util::sync::PollSemaphore;
14use tower::Service;
15use uuid::Uuid;
16use vector_lib::{event::EventStatus, request_metadata::MetaDescriptive};
17
18use super::{
19    EndpointTarget,
20    acknowledgements::{HecClientAcknowledgementsConfig, run_acknowledgements},
21};
22use crate::{
23    http::HttpClient,
24    internal_events::{SplunkIndexerAcknowledgementUnavailableError, SplunkResponseParseError},
25    sinks::{
26        UriParseSnafu,
27        splunk_hec::common::{build_uri, request::HecRequest, response::HecResponse},
28        util::{Compression, sink::Response},
29    },
30};
31
/// Tower service wrapper that layers Splunk HEC indexer-acknowledgement
/// tracking on top of an inner request service.
pub struct HecService<S> {
    /// The wrapped service that actually dispatches each `HecRequest`.
    pub inner: S,
    /// Channel used to hand `(ack_id, status_sender)` pairs to the background
    /// acknowledgements task; `None` when no ack client was supplied to `new`.
    ack_finalizer_tx: Option<mpsc::Sender<(u64, oneshot::Sender<EventStatus>)>>,
    /// Semaphore bounding how many requests may hold an ack slot at once
    /// (sized from `max_pending_acks` in `new`).
    ack_slots: PollSemaphore,
    /// Permit reserved by `poll_ready` and handed to the next `call`.
    current_ack_slot: Option<OwnedSemaphorePermit>,
}
38
/// The part of the HEC event-endpoint response body that this service reads.
#[derive(Deserialize, Serialize, Debug)]
struct HecAckResponseBody {
    /// Acknowledgement id; accepted under either the `ack_id` or the `ackId`
    /// key when deserializing. `None` when the response carries no id.
    #[serde(alias = "ackId")]
    ack_id: Option<u64>,
}
44
45impl<S> HecService<S>
46where
47    S: Service<HecRequest> + Send + 'static,
48    S::Future: Send + 'static,
49    S::Response: Response + ResponseExt + Send + 'static,
50    S::Error: fmt::Debug + Into<crate::Error> + Send,
51{
52    pub fn new(
53        inner: S,
54        ack_client: Option<HttpClient>,
55        http_request_builder: Arc<HttpRequestBuilder>,
56        indexer_acknowledgements: HecClientAcknowledgementsConfig,
57    ) -> Self {
58        let max_pending_acks = indexer_acknowledgements.max_pending_acks.get();
59        let tx = if let Some(ack_client) = ack_client {
60            let (tx, rx) = mpsc::channel(128);
61            tokio::spawn(run_acknowledgements(
62                rx,
63                ack_client,
64                Arc::clone(&http_request_builder),
65                indexer_acknowledgements,
66            ));
67            Some(tx)
68        } else {
69            None
70        };
71
72        let ack_slots = PollSemaphore::new(Arc::new(Semaphore::new(max_pending_acks as usize)));
73        Self {
74            inner,
75            ack_finalizer_tx: tx,
76            ack_slots,
77            current_ack_slot: None,
78        }
79    }
80}
81
82impl<S> Service<HecRequest> for HecService<S>
83where
84    S: Service<HecRequest> + Send + 'static,
85    S::Future: Send + 'static,
86    S::Response: Response + ResponseExt + Send + 'static,
87    S::Error: fmt::Debug + Into<crate::Error> + Send,
88{
89    type Response = HecResponse;
90    type Error = crate::Error;
91    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
92
93    fn poll_ready(&mut self, cx: &mut Context) -> std::task::Poll<Result<(), Self::Error>> {
94        // Ready if indexer acknowledgements is disabled or there is room for
95        // additional pending acks. Otherwise, wait until there is room.
96        if self.ack_finalizer_tx.is_none() || self.current_ack_slot.is_some() {
97            self.inner.poll_ready(cx).map_err(Into::into)
98        } else {
99            match ready!(self.ack_slots.poll_acquire(cx)) {
100                Some(permit) => {
101                    self.current_ack_slot.replace(permit);
102                    self.inner.poll_ready(cx).map_err(Into::into)
103                }
104                None => Poll::Ready(Err(
105                    "Indexer acknowledgements semaphore unexpectedly closed".into(),
106                )),
107            }
108        }
109    }
110
111    fn call(&mut self, mut req: HecRequest) -> Self::Future {
112        let ack_finalizer_tx = self.ack_finalizer_tx.clone();
113        let ack_slot = self.current_ack_slot.take();
114
115        let metadata = std::mem::take(req.metadata_mut());
116        let events_count = metadata.event_count();
117        let events_byte_size = metadata.into_events_estimated_json_encoded_byte_size();
118        let response = self.inner.call(req);
119
120        Box::pin(async move {
121            let response = response.await.map_err(Into::into)?;
122            let event_status = if response.is_successful() {
123                if let Some(ack_finalizer_tx) = ack_finalizer_tx {
124                    let _ack_slot = ack_slot.expect("poll_ready not called before invoking call");
125                    let body = serde_json::from_slice::<HecAckResponseBody>(response.body());
126                    match body {
127                        Ok(body) => {
128                            if let Some(ack_id) = body.ack_id {
129                                let (tx, rx) = oneshot::channel();
130                                match ack_finalizer_tx.send((ack_id, tx)).await {
131                                    Ok(_) => rx.await.unwrap_or(EventStatus::Rejected),
132                                    // If we cannot send ack ids to the ack client, fall back to default behavior
133                                    Err(error) => {
134                                        emit!(SplunkIndexerAcknowledgementUnavailableError {
135                                            error
136                                        });
137                                        EventStatus::Delivered
138                                    }
139                                }
140                            } else {
141                                // Default behavior if indexer acknowledgements is disabled on the Splunk server
142                                EventStatus::Delivered
143                            }
144                        }
145                        Err(error) => {
146                            // This may occur if Splunk changes the response format in future versions.
147                            emit!(SplunkResponseParseError { error });
148                            EventStatus::Delivered
149                        }
150                    }
151                } else {
152                    // Default behavior if indexer acknowledgements is disabled by configuration
153                    EventStatus::Delivered
154                }
155            } else if response.is_transient() {
156                EventStatus::Errored
157            } else {
158                EventStatus::Rejected
159            };
160
161            Ok(HecResponse {
162                event_status,
163                events_count,
164                events_byte_size,
165            })
166        })
167    }
168}
169
/// Gives `HecService` access to the raw body bytes of an inner-service
/// response.
pub trait ResponseExt {
    /// Returns a reference to the response body bytes.
    fn body(&self) -> &Bytes;
}
173
174impl ResponseExt for http::Response<Bytes> {
175    fn body(&self) -> &Bytes {
176        self.body()
177    }
178}
179
/// Holds the per-sink settings needed to build HTTP requests for a HEC
/// endpoint.
#[derive(Clone)]
pub struct HttpRequestBuilder {
    /// Which HEC endpoint flavour (`Raw` or `Event`) requests are built for.
    pub endpoint_target: EndpointTarget,
    /// Base endpoint string passed to `build_uri`.
    pub endpoint: String,
    /// Token placed in the `Authorization` header when a request carries no
    /// passthrough token.
    pub default_token: String,
    /// Compression setting; supplies the `Content-Encoding` header, if any.
    pub compression: Compression,
    // A Splunk channel must be a GUID/UUID formatted value
    // https://docs.splunk.com/Documentation/Splunk/8.2.3/Data/AboutHECIDXAck#About_channels_and_sending_data
    /// Value sent in the `X-Splunk-Request-Channel` header.
    pub channel: String,
}
190
/// Optional per-request metadata values; `build_request` turns the ones that
/// are set into URI query parameters for the raw endpoint.
#[derive(Default)]
pub(super) struct MetadataFields {
    /// Value for the `SOURCE_FIELD` query parameter.
    pub(super) source: Option<String>,
    /// Value for the `SOURCETYPE_FIELD` query parameter.
    pub(super) sourcetype: Option<String>,
    /// Value for the `INDEX_FIELD` query parameter.
    pub(super) index: Option<String>,
    /// Value for the `HOST_FIELD` query parameter.
    pub(super) host: Option<String>,
}
198
199impl HttpRequestBuilder {
200    pub fn new(
201        endpoint: String,
202        endpoint_target: EndpointTarget,
203        default_token: String,
204        compression: Compression,
205    ) -> Self {
206        let channel = Uuid::new_v4().hyphenated().to_string();
207        Self {
208            endpoint,
209            endpoint_target,
210            default_token,
211            compression,
212            channel,
213        }
214    }
215
216    pub(super) fn build_request(
217        &self,
218        body: Bytes,
219        path: &str,
220        passthrough_token: Option<Arc<str>>,
221        metadata_fields: MetadataFields,
222        auto_extract_timestamp: bool,
223    ) -> Result<Request<Bytes>, crate::Error> {
224        let uri = match self.endpoint_target {
225            EndpointTarget::Raw => {
226                // `auto_extract_timestamp` doesn't apply to the raw endpoint since the raw endpoint
227                // always does this anyway.
228                let metadata = [
229                    (super::SOURCE_FIELD, metadata_fields.source),
230                    (super::SOURCETYPE_FIELD, metadata_fields.sourcetype),
231                    (super::INDEX_FIELD, metadata_fields.index),
232                    (super::HOST_FIELD, metadata_fields.host),
233                ]
234                .into_iter()
235                .filter_map(|(key, value)| value.map(|value| (key, value)));
236                build_uri(self.endpoint.as_str(), path, metadata).context(UriParseSnafu)?
237            }
238            EndpointTarget::Event => build_uri(
239                self.endpoint.as_str(),
240                path,
241                if auto_extract_timestamp {
242                    Some((super::AUTO_EXTRACT_TIMESTAMP_FIELD, "true".to_string()))
243                } else {
244                    None
245                },
246            )
247            .context(UriParseSnafu)?,
248        };
249
250        let mut builder = Request::post(uri)
251            .header("Content-Type", "application/json")
252            .header(
253                "Authorization",
254                format!(
255                    "Splunk {}",
256                    passthrough_token.unwrap_or_else(|| self.default_token.as_str().into())
257                ),
258            )
259            .header("X-Splunk-Request-Channel", self.channel.as_str());
260
261        if let Some(ce) = self.compression.content_encoding() {
262            builder = builder.header("Content-Encoding", ce);
263        }
264
265        builder.body(body).map_err(Into::into)
266    }
267}
268
269#[cfg(test)]
270mod tests {
271    use std::{
272        collections::HashMap,
273        future::poll_fn,
274        num::{NonZeroU8, NonZeroU64, NonZeroUsize},
275        sync::{
276            Arc,
277            atomic::{AtomicU64, Ordering},
278        },
279        task::Poll,
280    };
281
282    use bytes::Bytes;
283    use futures_util::{StreamExt, future::BoxFuture, poll, stream::FuturesUnordered};
284    use tower::{Service, ServiceExt};
285    use vector_lib::{
286        config::proxy::ProxyConfig,
287        event::{EventFinalizers, EventStatus},
288        internal_event::CountByteSize,
289    };
290    use wiremock::{
291        Mock, MockServer, Request, Respond, ResponseTemplate,
292        matchers::{header, header_exists, method, path},
293    };
294
295    use crate::{
296        http::HttpClient,
297        sinks::{
298            splunk_hec::common::{
299                EndpointTarget,
300                acknowledgements::{
301                    HecAckStatusRequest, HecAckStatusResponse, HecClientAcknowledgementsConfig,
302                },
303                build_http_batch_service,
304                request::HecRequest,
305                service::{HecAckResponseBody, HecService, HttpRequestBuilder},
306            },
307            util::{Compression, http::HttpBatchService, metadata::RequestMetadataBuilder},
308        },
309    };
310
    /// Token the test service sends and the mock server requires.
    const TOKEN: &str = "token";
    /// Counter the mock event endpoint draws ack ids from; it is a single
    /// static, so it is shared by every test in this module.
    static ACK_ID: AtomicU64 = AtomicU64::new(0);
313
314    fn get_hec_service(
315        endpoint: String,
316        acknowledgements_config: HecClientAcknowledgementsConfig,
317    ) -> HecService<
318        HttpBatchService<
319            BoxFuture<'static, Result<http::Request<Bytes>, crate::Error>>,
320            HecRequest,
321        >,
322    > {
323        let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();
324        let http_request_builder = Arc::new(HttpRequestBuilder::new(
325            endpoint,
326            EndpointTarget::default(),
327            String::from(TOKEN),
328            Compression::default(),
329        ));
330        let http_service = build_http_batch_service(
331            client.clone(),
332            Arc::clone(&http_request_builder),
333            EndpointTarget::Event,
334            false,
335        );
336        HecService::new(
337            http_service,
338            Some(client),
339            http_request_builder,
340            acknowledgements_config,
341        )
342    }
343
344    fn get_hec_request() -> HecRequest {
345        let body = Bytes::from("test-message");
346        let events_byte_size = body.len();
347
348        let builder = RequestMetadataBuilder::new(
349            1,
350            events_byte_size,
351            CountByteSize(1, events_byte_size.into()).into(),
352        );
353        let bytes_len =
354            NonZeroUsize::new(events_byte_size).expect("payload should never be zero length");
355        let metadata = builder.with_request_size(bytes_len);
356
357        HecRequest {
358            body,
359            metadata,
360            finalizers: EventFinalizers::default(),
361            passthrough_token: None,
362            index: None,
363            source: None,
364            sourcetype: None,
365            host: None,
366        }
367    }
368
369    async fn get_hec_mock_server<R>(acknowledgements_enabled: bool, ack_response: R) -> MockServer
370    where
371        R: Respond + 'static,
372    {
373        // Authorization tokens and channels are required
374        let mock_server = MockServer::start().await;
375
376        Mock::given(method("POST"))
377            .and(path("/services/collector/event"))
378            .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
379            .and(header_exists("X-Splunk-Request-Channel"))
380            .respond_with(move |_: &Request| {
381                let ack_id =
382                    acknowledgements_enabled.then(|| ACK_ID.fetch_add(1, Ordering::Relaxed));
383                ResponseTemplate::new(200).set_body_json(HecAckResponseBody { ack_id })
384            })
385            .mount(&mock_server)
386            .await;
387
388        Mock::given(method("POST"))
389            .and(path("/services/collector/ack"))
390            .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
391            .and(header_exists("X-Splunk-Request-Channel"))
392            .respond_with(ack_response)
393            .mount(&mock_server)
394            .await;
395
396        mock_server
397    }
398
399    fn ack_response_always_succeed(req: &Request) -> ResponseTemplate {
400        let req = serde_json::from_slice::<HecAckStatusRequest>(req.body.as_slice()).unwrap();
401        ResponseTemplate::new(200).set_body_json(HecAckStatusResponse {
402            acks: req
403                .acks
404                .into_iter()
405                .map(|ack_id| (ack_id, true))
406                .collect::<HashMap<_, _>>(),
407        })
408    }
409
410    fn ack_response_always_fail(req: &Request) -> ResponseTemplate {
411        let req = serde_json::from_slice::<HecAckStatusRequest>(req.body.as_slice()).unwrap();
412        ResponseTemplate::new(200).set_body_json(HecAckStatusResponse {
413            acks: req
414                .acks
415                .into_iter()
416                .map(|ack_id| (ack_id, false))
417                .collect::<HashMap<_, _>>(),
418        })
419    }
420
421    #[tokio::test]
422    async fn acknowledgements_disabled_in_config() {
423        let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;
424
425        let acknowledgements_config = HecClientAcknowledgementsConfig {
426            indexer_acknowledgements_enabled: false,
427            ..Default::default()
428        };
429        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
430
431        let request = get_hec_request();
432        let response = service.ready().await.unwrap().call(request).await.unwrap();
433        assert_eq!(EventStatus::Delivered, response.event_status)
434    }
435
436    #[tokio::test]
437    async fn acknowledgements_enabled_on_server() {
438        let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;
439
440        let acknowledgements_config = HecClientAcknowledgementsConfig {
441            query_interval: NonZeroU8::new(1).unwrap(),
442            ..Default::default()
443        };
444        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
445
446        let mut responses = FuturesUnordered::new();
447        responses.push(service.ready().await.unwrap().call(get_hec_request()));
448        responses.push(service.ready().await.unwrap().call(get_hec_request()));
449        responses.push(service.ready().await.unwrap().call(get_hec_request()));
450        while let Some(response) = responses.next().await {
451            assert_eq!(EventStatus::Delivered, response.unwrap().event_status)
452        }
453    }
454
455    #[tokio::test]
456    async fn acknowledgements_disabled_on_server() {
457        let ack_response = |_: &Request| ResponseTemplate::new(400);
458        let mock_server = get_hec_mock_server(false, ack_response).await;
459
460        let acknowledgements_config = HecClientAcknowledgementsConfig {
461            query_interval: NonZeroU8::new(1).unwrap(),
462            ..Default::default()
463        };
464        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
465
466        let request = get_hec_request();
467        let response = service.ready().await.unwrap().call(request).await.unwrap();
468        assert_eq!(EventStatus::Delivered, response.event_status)
469    }
470
471    #[tokio::test]
472    async fn acknowledgements_enabled_on_server_retry_limit_exceeded() {
473        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
474
475        let acknowledgements_config = HecClientAcknowledgementsConfig {
476            query_interval: NonZeroU8::new(1).unwrap(),
477            retry_limit: NonZeroU8::new(1).unwrap(),
478            ..Default::default()
479        };
480        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
481
482        let request = get_hec_request();
483        let response = service.ready().await.unwrap().call(request).await.unwrap();
484        assert_eq!(EventStatus::Rejected, response.event_status)
485    }
486
487    #[tokio::test]
488    async fn acknowledgements_server_changed_ack_response_format() {
489        let ack_response = |_: &Request| {
490            ResponseTemplate::new(200)
491                .set_body_json(serde_json::json!(r#"{ "new": "a new response body" }"#))
492        };
493        let mock_server = get_hec_mock_server(true, ack_response).await;
494
495        let acknowledgements_config = HecClientAcknowledgementsConfig {
496            query_interval: NonZeroU8::new(1).unwrap(),
497            retry_limit: NonZeroU8::new(3).unwrap(),
498            ..Default::default()
499        };
500        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
501
502        let request = get_hec_request();
503        let response = service.ready().await.unwrap().call(request).await.unwrap();
504        assert_eq!(EventStatus::Delivered, response.event_status)
505    }
506
507    #[tokio::test]
508    async fn acknowledgements_enabled_on_server_ack_endpoint_failing() {
509        let ack_response = |_: &Request| ResponseTemplate::new(503);
510        let mock_server = get_hec_mock_server(true, ack_response).await;
511
512        let acknowledgements_config = HecClientAcknowledgementsConfig {
513            query_interval: NonZeroU8::new(1).unwrap(),
514            retry_limit: NonZeroU8::new(3).unwrap(),
515            ..Default::default()
516        };
517        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
518
519        let request = get_hec_request();
520        let response = service.ready().await.unwrap().call(request).await.unwrap();
521        assert_eq!(EventStatus::Errored, response.event_status)
522    }
523
524    #[tokio::test]
525    async fn acknowledgements_server_changed_event_response_format() {
526        let mock_server = get_hec_mock_server(true, ack_response_always_succeed).await;
527        // Override the usual event endpoint
528        Mock::given(method("POST"))
529            .and(path("/services/collector/event"))
530            .and(header("Authorization", format!("Splunk {TOKEN}").as_str()))
531            .and(header_exists("X-Splunk-Request-Channel"))
532            .respond_with(move |_: &Request| {
533                ResponseTemplate::new(200).set_body_json(r#"{ "new": "a new response body" }"#)
534            })
535            .mount(&mock_server)
536            .await;
537
538        let acknowledgements_config = HecClientAcknowledgementsConfig {
539            query_interval: NonZeroU8::new(1).unwrap(),
540            retry_limit: NonZeroU8::new(1).unwrap(),
541            ..Default::default()
542        };
543        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
544
545        let request = get_hec_request();
546        let response = service.ready().await.unwrap().call(request).await.unwrap();
547        assert_eq!(EventStatus::Delivered, response.event_status)
548    }
549
550    #[tokio::test]
551    async fn service_poll_ready_multiple_times() {
552        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
553        let mut service = get_hec_service(mock_server.uri(), Default::default());
554
555        assert!(service.ready().await.is_ok());
556        // Consecutive poll_ready returns OK since an ack slot has been granted
557        // but has not been used (call has not been invoked)
558        assert!(service.ready().await.is_ok());
559    }
560
561    #[tokio::test]
562    #[should_panic]
563    async fn service_call_without_poll_ready() {
564        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
565        let mut service = get_hec_service(mock_server.uri(), Default::default());
566
567        _ = service.call(get_hec_request()).await;
568    }
569
570    #[tokio::test]
571    async fn acknowledgements_max_pending_acks_reached() {
572        let mock_server = get_hec_mock_server(true, ack_response_always_fail).await;
573
574        let acknowledgements_config = HecClientAcknowledgementsConfig {
575            query_interval: NonZeroU8::new(1).unwrap(),
576            retry_limit: NonZeroU8::new(5).unwrap(),
577            // Allow a single pending ack
578            max_pending_acks: NonZeroU64::new(1).unwrap(),
579            ..Default::default()
580        };
581        let mut service = get_hec_service(mock_server.uri(), acknowledgements_config);
582
583        // Grab the one available ack slot
584        let pending_call = service.ready().await.unwrap().call(get_hec_request());
585        // The service should return pending for additional requests
586        assert!(matches!(
587            poll!(poll_fn(|cx| service.poll_ready(cx))),
588            Poll::Pending
589        ));
590        // Complete the call to free up the slot
591        let response = pending_call.await.unwrap();
592        assert_eq!(EventStatus::Rejected, response.event_status);
593        // The service should now be ready for additional requests
594        assert!(matches!(
595            poll!(poll_fn(|cx| service.poll_ready(cx))),
596            Poll::Ready(Ok(_))
597        ));
598    }
599}