vector/
http.rs

1#![allow(missing_docs)]
2use std::{
3    collections::HashMap,
4    fmt,
5    net::SocketAddr,
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use futures::future::BoxFuture;
11use headers::{Authorization, HeaderMapExt};
12use http::{
13    HeaderMap, Request, Response, Uri, Version, header::HeaderValue, request::Builder,
14    uri::InvalidUri,
15};
16use hyper::{
17    body::{Body, HttpBody},
18    client,
19    client::{Client, HttpConnector},
20};
21use hyper_openssl::HttpsConnector;
22use hyper_proxy::ProxyConnector;
23use rand::Rng;
24use serde_with::serde_as;
25use snafu::{ResultExt, Snafu};
26use tokio::time::Instant;
27use tower::{Layer, Service};
28use tower_http::{
29    classify::{ServerErrorsAsFailures, SharedClassifier},
30    trace::TraceLayer,
31};
32use tracing::{Instrument, Span};
33use vector_lib::{configurable::configurable_component, sensitive_string::SensitiveString};
34
35#[cfg(feature = "aws-core")]
36use crate::aws::AwsAuthentication;
37use crate::{
38    config::ProxyConfig,
39    internal_events::{HttpServerRequestReceived, HttpServerResponseSent, http_client},
40    tls::{MaybeTlsSettings, TlsError, tls_connector_builder},
41};
42
/// Raw `u16` values for a few HTTP status codes.
pub mod status {
    /// `403 Forbidden`.
    pub const FORBIDDEN: u16 = 403;

    /// `404 Not Found`.
    pub const NOT_FOUND: u16 = 404;

    /// `429 Too Many Requests`.
    pub const TOO_MANY_REQUESTS: u16 = 429;
}
48
49#[derive(Debug, Snafu)]
50#[snafu(visibility(pub(crate)))]
51pub enum HttpError {
52    #[snafu(display("Failed to build TLS connector: {}", source))]
53    BuildTlsConnector { source: TlsError },
54    #[snafu(display("Failed to build HTTPS connector: {}", source))]
55    MakeHttpsConnector { source: openssl::error::ErrorStack },
56    #[snafu(display("Failed to build Proxy connector: {}", source))]
57    MakeProxyConnector { source: InvalidUri },
58    #[snafu(display("Failed to make HTTP(S) request: {}", source))]
59    CallRequest { source: hyper::Error },
60    #[snafu(display("Failed to build HTTP request: {}", source))]
61    BuildRequest { source: http::Error },
62}
63
64impl HttpError {
65    pub const fn is_retriable(&self) -> bool {
66        match self {
67            HttpError::BuildRequest { .. } | HttpError::MakeProxyConnector { .. } => false,
68            HttpError::CallRequest { .. }
69            | HttpError::BuildTlsConnector { .. }
70            | HttpError::MakeHttpsConnector { .. } => true,
71        }
72    }
73}
74
75pub type HttpClientFuture = <HttpClient as Service<http::Request<Body>>>::Future;
76type HttpProxyConnector = ProxyConnector<HttpsConnector<HttpConnector>>;
77
78pub struct HttpClient<B = Body> {
79    client: Client<HttpProxyConnector, B>,
80    user_agent: HeaderValue,
81    proxy_connector: HttpProxyConnector,
82}
83
84impl<B> HttpClient<B>
85where
86    B: fmt::Debug + HttpBody + Send + 'static,
87    B::Data: Send,
88    B::Error: Into<crate::Error>,
89{
90    pub fn new(
91        tls_settings: impl Into<MaybeTlsSettings>,
92        proxy_config: &ProxyConfig,
93    ) -> Result<HttpClient<B>, HttpError> {
94        HttpClient::new_with_custom_client(tls_settings, proxy_config, &mut Client::builder())
95    }
96
97    pub fn new_with_custom_client(
98        tls_settings: impl Into<MaybeTlsSettings>,
99        proxy_config: &ProxyConfig,
100        client_builder: &mut client::Builder,
101    ) -> Result<HttpClient<B>, HttpError> {
102        let proxy_connector = build_proxy_connector(tls_settings.into(), proxy_config)?;
103        let client = client_builder.build(proxy_connector.clone());
104
105        let app_name = crate::get_app_name();
106        let version = crate::get_version();
107        let user_agent = HeaderValue::from_str(&format!("{app_name}/{version}"))
108            .expect("Invalid header value for user-agent!");
109
110        Ok(HttpClient {
111            client,
112            user_agent,
113            proxy_connector,
114        })
115    }
116
117    pub fn send(
118        &self,
119        mut request: Request<B>,
120    ) -> BoxFuture<'static, Result<http::Response<Body>, HttpError>> {
121        let span = tracing::info_span!("http");
122        let _enter = span.enter();
123
124        default_request_headers(&mut request, &self.user_agent);
125        self.maybe_add_proxy_headers(&mut request);
126
127        emit!(http_client::AboutToSendHttpRequest { request: &request });
128
129        let response = self.client.request(request);
130
131        let fut = async move {
132            // Capture the time right before we issue the request.
133            // Request doesn't start the processing until we start polling it.
134            let before = std::time::Instant::now();
135
136            // Send request and wait for the result.
137            let response_result = response.await;
138
139            // Compute the roundtrip time it took to send the request and get
140            // the response or error.
141            let roundtrip = before.elapsed();
142
143            // Handle the errors and extract the response.
144            let response = response_result
145                .inspect_err(|error| {
146                    // Emit the error into the internal events system.
147                    emit!(http_client::GotHttpWarning { error, roundtrip });
148                })
149                .context(CallRequestSnafu)?;
150
151            // Emit the response into the internal events system.
152            emit!(http_client::GotHttpResponse {
153                response: &response,
154                roundtrip
155            });
156            Ok(response)
157        }
158        .instrument(span.clone().or_current());
159
160        Box::pin(fut)
161    }
162
163    fn maybe_add_proxy_headers(&self, request: &mut Request<B>) {
164        if let Some(proxy_headers) = self.proxy_connector.http_headers(request.uri()) {
165            for (k, v) in proxy_headers {
166                let request_headers = request.headers_mut();
167                if !request_headers.contains_key(k) {
168                    request_headers.insert(k, v.into());
169                }
170            }
171        }
172    }
173}
174
175pub fn build_proxy_connector(
176    tls_settings: MaybeTlsSettings,
177    proxy_config: &ProxyConfig,
178) -> Result<ProxyConnector<HttpsConnector<HttpConnector>>, HttpError> {
179    // Create dedicated TLS connector for the proxied connection with user TLS settings.
180    let tls = tls_connector_builder(&tls_settings)
181        .context(BuildTlsConnectorSnafu)?
182        .build();
183    let https = build_tls_connector(tls_settings)?;
184    let mut proxy = ProxyConnector::new(https).unwrap();
185    // Make proxy connector aware of user TLS settings by setting the TLS connector:
186    // https://github.com/vectordotdev/vector/issues/13683
187    proxy.set_tls(Some(tls));
188    proxy_config
189        .configure(&mut proxy)
190        .context(MakeProxyConnectorSnafu)?;
191    Ok(proxy)
192}
193
194pub fn build_tls_connector(
195    tls_settings: MaybeTlsSettings,
196) -> Result<HttpsConnector<HttpConnector>, HttpError> {
197    let mut http = HttpConnector::new();
198    http.enforce_http(false);
199
200    let tls = tls_connector_builder(&tls_settings).context(BuildTlsConnectorSnafu)?;
201    let mut https = HttpsConnector::with_connector(http, tls).context(MakeHttpsConnectorSnafu)?;
202
203    let settings = tls_settings.tls().cloned();
204    https.set_callback(move |c, _uri| {
205        if let Some(settings) = &settings {
206            settings.apply_connect_configuration(c)
207        } else {
208            Ok(())
209        }
210    });
211    Ok(https)
212}
213
214fn default_request_headers<B>(request: &mut Request<B>, user_agent: &HeaderValue) {
215    if !request.headers().contains_key("User-Agent") {
216        request
217            .headers_mut()
218            .insert("User-Agent", user_agent.clone());
219    }
220
221    if !request.headers().contains_key("Accept-Encoding") {
222        // hardcoding until we support compressed responses:
223        // https://github.com/vectordotdev/vector/issues/5440
224        request
225            .headers_mut()
226            .insert("Accept-Encoding", HeaderValue::from_static("identity"));
227    }
228}
229
230impl<B> Service<Request<B>> for HttpClient<B>
231where
232    B: fmt::Debug + HttpBody + Send + 'static,
233    B::Data: Send,
234    B::Error: Into<crate::Error> + Send,
235{
236    type Response = http::Response<Body>;
237    type Error = HttpError;
238    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
239
240    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
241        Poll::Ready(Ok(()))
242    }
243
244    fn call(&mut self, request: Request<B>) -> Self::Future {
245        self.send(request)
246    }
247}
248
249impl<B> Clone for HttpClient<B> {
250    fn clone(&self) -> Self {
251        Self {
252            client: self.client.clone(),
253            user_agent: self.user_agent.clone(),
254            proxy_connector: self.proxy_connector.clone(),
255        }
256    }
257}
258
259impl<B> fmt::Debug for HttpClient<B> {
260    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261        f.debug_struct("HttpClient")
262            .field("client", &self.client)
263            .field("user_agent", &self.user_agent)
264            .finish()
265    }
266}
267
268/// Configuration of the authentication strategy for HTTP requests.
269///
270/// HTTP authentication should be used with HTTPS only, as the authentication credentials are passed as an
271/// HTTP header without any additional encryption beyond what is provided by the transport itself.
272#[configurable_component]
273#[derive(Clone, Debug, Eq, PartialEq)]
274#[serde(deny_unknown_fields, rename_all = "snake_case", tag = "strategy")]
275#[configurable(metadata(docs::enum_tag_description = "The authentication strategy to use."))]
276pub enum Auth {
277    /// Basic authentication.
278    ///
279    /// The username and password are concatenated and encoded using [base64][base64].
280    ///
281    /// [base64]: https://en.wikipedia.org/wiki/Base64
282    Basic {
283        /// The basic authentication username.
284        #[configurable(metadata(docs::examples = "${USERNAME}"))]
285        #[configurable(metadata(docs::examples = "username"))]
286        user: String,
287
288        /// The basic authentication password.
289        #[configurable(metadata(docs::examples = "${PASSWORD}"))]
290        #[configurable(metadata(docs::examples = "password"))]
291        password: SensitiveString,
292    },
293
294    /// Bearer authentication.
295    ///
296    /// The bearer token value (OAuth2, JWT, etc.) is passed as-is.
297    Bearer {
298        /// The bearer authentication token.
299        token: SensitiveString,
300    },
301
302    #[cfg(feature = "aws-core")]
303    /// AWS authentication.
304    Aws {
305        /// The AWS authentication configuration.
306        auth: AwsAuthentication,
307
308        /// The AWS service name to use for signing.
309        service: String,
310    },
311
312    /// Custom Authorization Header Value, will be inserted into the headers as `Authorization: < value >`
313    Custom {
314        /// Custom string value of the Authorization header
315        #[configurable(metadata(docs::examples = "${AUTH_HEADER_VALUE}"))]
316        #[configurable(metadata(docs::examples = "CUSTOM_PREFIX ${TOKEN}"))]
317        value: String,
318    },
319}
320
321pub trait MaybeAuth: Sized {
322    fn choose_one(&self, other: &Self) -> crate::Result<Self>;
323}
324
325impl MaybeAuth for Option<Auth> {
326    fn choose_one(&self, other: &Self) -> crate::Result<Self> {
327        if self.is_some() && other.is_some() {
328            Err("Two authorization credentials was provided.".into())
329        } else {
330            Ok(self.clone().or_else(|| other.clone()))
331        }
332    }
333}
334
335impl Auth {
336    pub fn apply<B>(&self, req: &mut Request<B>) {
337        self.apply_headers_map(req.headers_mut())
338    }
339
340    pub fn apply_builder(&self, mut builder: Builder) -> Builder {
341        if let Some(map) = builder.headers_mut() {
342            self.apply_headers_map(map)
343        }
344        builder
345    }
346
347    pub fn apply_headers_map(&self, map: &mut HeaderMap) {
348        match &self {
349            Auth::Basic { user, password } => {
350                let auth = Authorization::basic(user.as_str(), password.inner());
351                map.typed_insert(auth);
352            }
353            Auth::Bearer { token } => match Authorization::bearer(token.inner()) {
354                Ok(auth) => map.typed_insert(auth),
355                Err(error) => error!(message = "Invalid bearer token.", token = %token, %error),
356            },
357            Auth::Custom { value } => {
358                // The value contains just the value for the Authorization header
359                // Expected format: "SSWS token123" or "Bearer token123", etc.
360                match HeaderValue::from_str(value) {
361                    Ok(header_val) => {
362                        map.insert(http::header::AUTHORIZATION, header_val);
363                    }
364                    Err(error) => {
365                        error!(message = "Invalid custom auth header value.", value = %value, %error)
366                    }
367                }
368            }
369            #[cfg(feature = "aws-core")]
370            _ => {}
371        }
372    }
373}
374
375pub fn get_http_scheme_from_uri(uri: &Uri) -> &'static str {
376    // If there's no scheme, we just use "http" since it provides the most semantic relevance without inadvertently
377    // implying things it can't know i.e. returning "https" when we're not actually sure HTTPS was used.
378    uri.scheme_str().map_or("http", |scheme| match scheme {
379        "http" => "http",
380        "https" => "https",
381        // `http::Uri` ensures that we always get "http" or "https" if the URI is created with a well-formed scheme, but
382        // it also supports arbitrary schemes, which is where we bomb out down here, since we can't generate a static
383        // string for an arbitrary input string... and anything other than "http" and "https" makes no sense for an HTTP
384        // client anyways.
385        s => panic!("invalid URI scheme for HTTP client: {s}"),
386    })
387}
388
389/// Builds a [TraceLayer] configured for a HTTP server.
390///
391/// This layer emits HTTP specific telemetry for requests received, responses sent, and handler duration.
392pub fn build_http_trace_layer<T, U>(
393    span: Span,
394) -> TraceLayer<
395    SharedClassifier<ServerErrorsAsFailures>,
396    impl Fn(&Request<T>) -> Span + Clone,
397    impl Fn(&Request<T>, &Span) + Clone,
398    impl Fn(&Response<U>, Duration, &Span) + Clone,
399    (),
400    (),
401    (),
402> {
403    TraceLayer::new_for_http()
404        .make_span_with(move |request: &Request<T>| {
405            // This is an error span so that the labels are always present for metrics.
406            error_span!(
407               parent: &span,
408               "http-request",
409               method = %request.method(),
410               path = %request.uri().path(),
411            )
412        })
413        .on_request(Box::new(|_request: &Request<T>, _span: &Span| {
414            emit!(HttpServerRequestReceived);
415        }))
416        .on_response(|response: &Response<U>, latency: Duration, _span: &Span| {
417            emit!(HttpServerResponseSent { response, latency });
418        })
419        .on_failure(())
420        .on_body_chunk(())
421        .on_eos(())
422}
423
424/// Configuration of HTTP server keepalive parameters.
425#[serde_as]
426#[configurable_component]
427#[derive(Clone, Debug, PartialEq)]
428#[serde(deny_unknown_fields)]
429pub struct KeepaliveConfig {
430    /// The maximum amount of time a connection may exist before it is closed by sending
431    /// a `Connection: close` header on the HTTP response. Set this to a large value like
432    /// `100000000` to "disable" this feature
433    ///
434    ///
435    /// Only applies to HTTP/0.9, HTTP/1.0, and HTTP/1.1 requests.
436    ///
437    /// A random jitter configured by `max_connection_age_jitter_factor` is added
438    /// to the specified duration to spread out connection storms.
439    #[serde(default = "default_max_connection_age")]
440    #[configurable(metadata(docs::examples = 600))]
441    #[configurable(metadata(docs::type_unit = "seconds"))]
442    #[configurable(metadata(docs::human_name = "Maximum Connection Age"))]
443    pub max_connection_age_secs: Option<u64>,
444
445    /// The factor by which to jitter the `max_connection_age_secs` value.
446    ///
447    /// A value of 0.1 means that the actual duration will be between 90% and 110% of the
448    /// specified maximum duration.
449    #[serde(default = "default_max_connection_age_jitter_factor")]
450    #[configurable(validation(range(min = 0.0, max = 1.0)))]
451    pub max_connection_age_jitter_factor: f64,
452}
453
/// Serde default for `KeepaliveConfig::max_connection_age_secs`: five minutes.
const fn default_max_connection_age() -> Option<u64> {
    let five_minutes_in_secs = 5 * 60;
    Some(five_minutes_in_secs)
}

/// Serde default for `KeepaliveConfig::max_connection_age_jitter_factor` (±10%).
const fn default_max_connection_age_jitter_factor() -> f64 {
    0.1
}
461
462impl Default for KeepaliveConfig {
463    fn default() -> Self {
464        Self {
465            max_connection_age_secs: default_max_connection_age(),
466            max_connection_age_jitter_factor: default_max_connection_age_jitter_factor(),
467        }
468    }
469}
470
471/// A layer that limits the maximum duration of a client connection. It does so by adding a
472/// `Connection: close` header to the response if `max_connection_duration` time has elapsed
473/// since `start_reference`.
474///
475/// **Notes:**
476/// - This is intended to be used in a Hyper server (or similar) that will automatically close
477///   the connection after a response with a `Connection: close` header is sent.
478/// - This layer assumes that it is instantiated once per connection, which is true within the
479///   Hyper framework.
480pub struct MaxConnectionAgeLayer {
481    start_reference: Instant,
482    max_connection_age: Duration,
483    peer_addr: SocketAddr,
484}
485
486impl MaxConnectionAgeLayer {
487    pub fn new(max_connection_age: Duration, jitter_factor: f64, peer_addr: SocketAddr) -> Self {
488        Self {
489            start_reference: Instant::now(),
490            max_connection_age: Self::jittered_duration(max_connection_age, jitter_factor),
491            peer_addr,
492        }
493    }
494
495    fn jittered_duration(duration: Duration, jitter_factor: f64) -> Duration {
496        // Ensure the jitter_factor is between 0.0 and 1.0
497        let jitter_factor = jitter_factor.clamp(0.0, 1.0);
498        // Generate a random jitter factor between `1 - jitter_factor`` and `1 + jitter_factor`.
499        let mut rng = rand::rng();
500        let random_jitter_factor = rng.random_range(-jitter_factor..=jitter_factor) + 1.;
501        duration.mul_f64(random_jitter_factor)
502    }
503}
504
505impl<S> Layer<S> for MaxConnectionAgeLayer
506where
507    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
508    S::Future: Send + 'static,
509{
510    type Service = MaxConnectionAgeService<S>;
511
512    fn layer(&self, service: S) -> Self::Service {
513        MaxConnectionAgeService {
514            service,
515            start_reference: self.start_reference,
516            max_connection_age: self.max_connection_age,
517            peer_addr: self.peer_addr,
518        }
519    }
520}
521
522/// A service that limits the maximum age of a client connection. It does so by adding a
523/// `Connection: close` header to the response if `max_connection_age` time has elapsed
524/// since `start_reference`.
525///
526/// **Notes:**
527/// - This is intended to be used in a Hyper server (or similar) that will automatically close
528///   the connection after a response with a `Connection: close` header is sent.
529/// - This service assumes that it is instantiated once per connection, which is true within the
530///   Hyper framework.
531#[derive(Clone)]
532pub struct MaxConnectionAgeService<S> {
533    service: S,
534    start_reference: Instant,
535    max_connection_age: Duration,
536    peer_addr: SocketAddr,
537}
538
539impl<S, E> Service<Request<Body>> for MaxConnectionAgeService<S>
540where
541    S: Service<Request<Body>, Response = Response<Body>, Error = E> + Clone + Send + 'static,
542    S::Future: Send + 'static,
543{
544    type Response = S::Response;
545    type Error = E;
546    type Future = BoxFuture<'static, Result<Self::Response, E>>;
547
548    fn poll_ready(
549        &mut self,
550        cx: &mut std::task::Context<'_>,
551    ) -> std::task::Poll<Result<(), Self::Error>> {
552        self.service.poll_ready(cx)
553    }
554
555    fn call(&mut self, req: Request<Body>) -> Self::Future {
556        let start_reference = self.start_reference;
557        let max_connection_age = self.max_connection_age;
558        let peer_addr = self.peer_addr;
559        let version = req.version();
560        let future = self.service.call(req);
561        Box::pin(async move {
562            let mut response = future.await?;
563            match version {
564                Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => {
565                    if start_reference.elapsed() >= max_connection_age {
566                        debug!(
567                            message = "Closing connection due to max connection age.",
568                            ?max_connection_age,
569                            connection_age = ?start_reference.elapsed(),
570                            ?peer_addr,
571                        );
572                        // Tell the client to close this connection.
573                        // Hyper will automatically close the connection after the response is sent.
574                        response.headers_mut().insert(
575                            hyper::header::CONNECTION,
576                            hyper::header::HeaderValue::from_static("close"),
577                        );
578                    }
579                }
580                // TODO need to send GOAWAY frame
581                Version::HTTP_2 => (),
582                // TODO need to send GOAWAY frame
583                Version::HTTP_3 => (),
584                _ => (),
585            }
586            Ok(response)
587        })
588    }
589}
590
591/// The type of a query parameter's value, determines if it's treated as a plain string or a VRL expression.
592#[configurable_component]
593#[derive(Clone, Debug, Default, Eq, PartialEq)]
594#[serde(rename_all = "snake_case")]
595pub enum ParamType {
596    /// The parameter value is a plain string.
597    #[default]
598    String,
599    /// The parameter value is a VRL expression that is evaluated before each request.
600    Vrl,
601}
602
603impl ParamType {
604    fn is_default(&self) -> bool {
605        *self == Self::default()
606    }
607}
608
609/// Represents a query parameter value, which can be a simple string or a typed object
610/// indicating whether the value is a string or a VRL expression.
611#[configurable_component]
612#[derive(Clone, Debug, Eq, PartialEq)]
613#[serde(untagged)]
614pub enum ParameterValue {
615    /// A simple string value. For backwards compatibility.
616    String(String),
617    /// A value with an explicit type.
618    Typed {
619        /// The raw value of the parameter.
620        value: String,
621        /// The parameter type, indicating how the `value` should be treated.
622        #[serde(
623            default,
624            skip_serializing_if = "ParamType::is_default",
625            rename = "type"
626        )]
627        r#type: ParamType,
628    },
629}
630
631impl ParameterValue {
632    /// Returns true if the parameter is a VRL expression.
633    pub const fn is_vrl(&self) -> bool {
634        match self {
635            ParameterValue::String(_) => false,
636            ParameterValue::Typed { r#type, .. } => matches!(r#type, ParamType::Vrl),
637        }
638    }
639
640    /// Returns the raw string value of the parameter.
641    #[allow(clippy::missing_const_for_fn)]
642    pub fn value(&self) -> &str {
643        match self {
644            ParameterValue::String(s) => s,
645            ParameterValue::Typed { value, .. } => value,
646        }
647    }
648
649    /// Consumes the `ParameterValue` and returns the owned raw string value.
650    pub fn into_value(self) -> String {
651        match self {
652            ParameterValue::String(s) => s,
653            ParameterValue::Typed { value, .. } => value,
654        }
655    }
656}
657
658/// Configuration of the query parameter value for HTTP requests.
659#[configurable_component]
660#[derive(Clone, Debug, Eq, PartialEq)]
661#[serde(untagged)]
662#[configurable(metadata(docs::enum_tag_description = "Query parameter value"))]
663pub enum QueryParameterValue {
664    /// Query parameter with single value
665    SingleParam(ParameterValue),
666    /// Query parameter with multiple values
667    MultiParams(Vec<ParameterValue>),
668}
669
670impl QueryParameterValue {
671    /// Returns an iterator over the contained `ParameterValue`s.
672    pub fn iter(&self) -> impl Iterator<Item = &ParameterValue> {
673        match self {
674            QueryParameterValue::SingleParam(param) => std::slice::from_ref(param).iter(),
675            QueryParameterValue::MultiParams(params) => params.iter(),
676        }
677    }
678
679    /// Convert to `Vec<ParameterValue>` for owned iteration.
680    fn into_vec(self) -> Vec<ParameterValue> {
681        match self {
682            QueryParameterValue::SingleParam(param) => vec![param],
683            QueryParameterValue::MultiParams(params) => params,
684        }
685    }
686}
687
688// Implement IntoIterator for owned QueryParameterValue
689impl IntoIterator for QueryParameterValue {
690    type Item = ParameterValue;
691    type IntoIter = std::vec::IntoIter<ParameterValue>;
692
693    fn into_iter(self) -> Self::IntoIter {
694        self.into_vec().into_iter()
695    }
696}
697
698pub type QueryParameters = HashMap<String, QueryParameterValue>;
699
700#[cfg(test)]
701mod tests {
702    use std::convert::Infallible;
703
704    use hyper::{Server, server::conn::AddrStream, service::make_service_fn};
705    use proptest::prelude::*;
706    use tower::ServiceBuilder;
707
708    use super::*;
709    use crate::test_util::addr::next_addr;
710
711    #[test]
712    fn test_default_request_headers_defaults() {
713        let user_agent = HeaderValue::from_static("vector");
714        let mut request = Request::post("http://example.com").body(()).unwrap();
715        default_request_headers(&mut request, &user_agent);
716        assert_eq!(
717            request.headers().get("Accept-Encoding"),
718            Some(&HeaderValue::from_static("identity")),
719        );
720        assert_eq!(request.headers().get("User-Agent"), Some(&user_agent));
721    }
722
723    #[test]
724    fn test_default_request_headers_does_not_overwrite() {
725        let mut request = Request::post("http://example.com")
726            .header("Accept-Encoding", "gzip")
727            .header("User-Agent", "foo")
728            .body(())
729            .unwrap();
730        default_request_headers(&mut request, &HeaderValue::from_static("vector"));
731        assert_eq!(
732            request.headers().get("Accept-Encoding"),
733            Some(&HeaderValue::from_static("gzip")),
734        );
735        assert_eq!(
736            request.headers().get("User-Agent"),
737            Some(&HeaderValue::from_static("foo"))
738        );
739    }
740
741    proptest! {
742        #[test]
743        fn test_jittered_duration(duration_in_secs in 0u64..120, jitter_factor in 0.0..1.0) {
744            let duration = Duration::from_secs(duration_in_secs);
745            let jittered_duration = MaxConnectionAgeLayer::jittered_duration(duration, jitter_factor);
746
747            // Check properties based on the range of inputs
748            if jitter_factor == 0.0 {
749                // When jitter_factor is 0, jittered_duration should be equal to the original duration
750                prop_assert_eq!(
751                    jittered_duration,
752                    duration,
753                    "jittered_duration {:?} should be equal to duration {:?}",
754                    jittered_duration,
755                    duration,
756                );
757            } else if duration_in_secs > 0 {
758                // Check the bounds when duration is non-zero and jitter_factor is non-zero
759                let lower_bound = duration.mul_f64(1.0 - jitter_factor);
760                let upper_bound = duration.mul_f64(1.0 + jitter_factor);
761                prop_assert!(
762                    jittered_duration >= lower_bound && jittered_duration <= upper_bound,
763                    "jittered_duration {:?} should be between {:?} and {:?}",
764                    jittered_duration,
765                    lower_bound,
766                    upper_bound,
767                );
768            } else {
769                // When duration is zero, jittered_duration should also be zero
770                prop_assert_eq!(
771                    jittered_duration,
772                    Duration::from_secs(0),
773                    "jittered_duration {:?} should be equal to zero",
774                    jittered_duration,
775                );
776            }
777        }
778    }
779
780    #[tokio::test]
781    async fn test_max_connection_age_service() {
782        tokio::time::pause();
783
784        let start_reference = Instant::now();
785        let max_connection_age = Duration::from_secs(1);
786        let mut service = MaxConnectionAgeService {
787            service: tower::service_fn(|_req: Request<Body>| async {
788                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
789            }),
790            start_reference,
791            max_connection_age,
792            peer_addr: "1.2.3.4:1234".parse().unwrap(),
793        };
794
795        let req = Request::get("http://example.com")
796            .body(Body::empty())
797            .unwrap();
798        let response = service.call(req).await.unwrap();
799        assert_eq!(response.headers().get("Connection"), None);
800
801        tokio::time::advance(Duration::from_millis(500)).await;
802        let req = Request::get("http://example.com")
803            .body(Body::empty())
804            .unwrap();
805        let response = service.call(req).await.unwrap();
806        assert_eq!(response.headers().get("Connection"), None);
807
808        tokio::time::advance(Duration::from_millis(500)).await;
809        let req = Request::get("http://example.com")
810            .body(Body::empty())
811            .unwrap();
812        let response = service.call(req).await.unwrap();
813        assert_eq!(
814            response.headers().get("Connection"),
815            Some(&HeaderValue::from_static("close"))
816        );
817    }
818
819    #[tokio::test]
820    async fn test_max_connection_age_service_http2() {
821        tokio::time::pause();
822
823        let start_reference = Instant::now();
824        let max_connection_age = Duration::from_secs(0);
825        let mut service = MaxConnectionAgeService {
826            service: tower::service_fn(|_req: Request<Body>| async {
827                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
828            }),
829            start_reference,
830            max_connection_age,
831            peer_addr: "1.2.3.4:1234".parse().unwrap(),
832        };
833
834        let mut req = Request::get("http://example.com")
835            .body(Body::empty())
836            .unwrap();
837        *req.version_mut() = Version::HTTP_2;
838        let response = service.call(req).await.unwrap();
839        assert_eq!(response.headers().get("Connection"), None);
840    }
841
842    #[tokio::test]
843    async fn test_max_connection_age_service_http3() {
844        tokio::time::pause();
845
846        let start_reference = Instant::now();
847        let max_connection_age = Duration::from_secs(0);
848        let mut service = MaxConnectionAgeService {
849            service: tower::service_fn(|_req: Request<Body>| async {
850                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
851            }),
852            start_reference,
853            max_connection_age,
854            peer_addr: "1.2.3.4:1234".parse().unwrap(),
855        };
856
857        let mut req = Request::get("http://example.com")
858            .body(Body::empty())
859            .unwrap();
860        *req.version_mut() = Version::HTTP_3;
861        let response = service.call(req).await.unwrap();
862        assert_eq!(response.headers().get("Connection"), None);
863    }
864
865    #[tokio::test]
866    async fn test_max_connection_age_service_zero_duration() {
867        tokio::time::pause();
868
869        let start_reference = Instant::now();
870        let max_connection_age = Duration::from_millis(0);
871        let mut service = MaxConnectionAgeService {
872            service: tower::service_fn(|_req: Request<Body>| async {
873                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
874            }),
875            start_reference,
876            max_connection_age,
877            peer_addr: "1.2.3.4:1234".parse().unwrap(),
878        };
879
880        let req = Request::get("http://example.com")
881            .body(Body::empty())
882            .unwrap();
883        let response = service.call(req).await.unwrap();
884        assert_eq!(
885            response.headers().get("Connection"),
886            Some(&HeaderValue::from_static("close"))
887        );
888    }
889
890    // Note that we unfortunately cannot mock the time in this test because the client calls
891    // sleep internally, which advances the clock.  However, this test shouldn't be flakey given
892    // the time bounds provided.
893    #[tokio::test]
894    async fn test_max_connection_age_service_with_hyper_server() {
895        // Create a hyper server with the max connection age layer.
896        let max_connection_age = Duration::from_secs(1);
897        let (_guard, addr) = next_addr();
898        let make_svc = make_service_fn(move |conn: &AddrStream| {
899            let svc = ServiceBuilder::new()
900                .layer(MaxConnectionAgeLayer::new(
901                    max_connection_age,
902                    0.,
903                    conn.remote_addr(),
904                ))
905                .service(tower::service_fn(|_req: Request<Body>| async {
906                    Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
907                }));
908            futures_util::future::ok::<_, Infallible>(svc)
909        });
910
911        tokio::spawn(async move {
912            Server::bind(&addr).serve(make_svc).await.unwrap();
913        });
914
915        // Wait for the server to start.
916        tokio::time::sleep(Duration::from_millis(10)).await;
917
918        // Create a client, which has its own connection pool.
919        let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();
920
921        // Responses generated before the client's max connection age has elapsed do not
922        // include a `Connection: close` header in the response.
923        let req = Request::get(format!("http://{addr}/"))
924            .body(Body::empty())
925            .unwrap();
926        let response = client.send(req).await.unwrap();
927        assert_eq!(response.headers().get("Connection"), None);
928
929        let req = Request::get(format!("http://{addr}/"))
930            .body(Body::empty())
931            .unwrap();
932        let response = client.send(req).await.unwrap();
933        assert_eq!(response.headers().get("Connection"), None);
934
935        // The first response generated after the client's max connection age has elapsed should
936        // include the `Connection: close` header.
937        tokio::time::sleep(Duration::from_secs(1)).await;
938        let req = Request::get(format!("http://{addr}/"))
939            .body(Body::empty())
940            .unwrap();
941        let response = client.send(req).await.unwrap();
942        assert_eq!(
943            response.headers().get("Connection"),
944            Some(&HeaderValue::from_static("close")),
945        );
946
947        // The next request should establish a new connection.
948        // Importantly, this also confirms that each connection has its own independent
949        // connection age timer.
950        let req = Request::get(format!("http://{addr}/"))
951            .body(Body::empty())
952            .unwrap();
953        let response = client.send(req).await.unwrap();
954        assert_eq!(response.headers().get("Connection"), None);
955    }
956}