vector/
http.rs

1#![allow(missing_docs)]
2use futures::future::BoxFuture;
3use headers::{Authorization, HeaderMapExt};
4use http::{
5    header::HeaderValue, request::Builder, uri::InvalidUri, HeaderMap, Request, Response, Uri,
6    Version,
7};
8use hyper::{
9    body::{Body, HttpBody},
10    client,
11    client::{Client, HttpConnector},
12};
13use hyper_openssl::HttpsConnector;
14use hyper_proxy::ProxyConnector;
15use rand::Rng;
16use serde_with::serde_as;
17use snafu::{ResultExt, Snafu};
18use std::{
19    collections::HashMap,
20    fmt,
21    net::SocketAddr,
22    task::{Context, Poll},
23    time::Duration,
24};
25use tokio::time::Instant;
26use tower::{Layer, Service};
27use tower_http::{
28    classify::{ServerErrorsAsFailures, SharedClassifier},
29    trace::TraceLayer,
30};
31use tracing::{Instrument, Span};
32use vector_lib::configurable::configurable_component;
33use vector_lib::sensitive_string::SensitiveString;
34
35#[cfg(feature = "aws-core")]
36use crate::aws::AwsAuthentication;
37
38use crate::{
39    config::ProxyConfig,
40    internal_events::{http_client, HttpServerRequestReceived, HttpServerResponseSent},
41    tls::{tls_connector_builder, MaybeTlsSettings, TlsError},
42};
43
/// HTTP status codes used by this module's callers, as plain `u16` values.
pub mod status {
    /// `403 Forbidden`.
    pub const FORBIDDEN: u16 = 403;
    /// `404 Not Found`.
    pub const NOT_FOUND: u16 = 404;
    /// `429 Too Many Requests`.
    pub const TOO_MANY_REQUESTS: u16 = 429;
}
49
/// Errors produced while building an [`HttpClient`] or sending a request with it.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum HttpError {
    /// The TLS connector builder could not be created from the TLS settings.
    #[snafu(display("Failed to build TLS connector: {}", source))]
    BuildTlsConnector { source: TlsError },
    /// OpenSSL failed to construct the HTTPS connector.
    #[snafu(display("Failed to build HTTPS connector: {}", source))]
    MakeHttpsConnector { source: openssl::error::ErrorStack },
    /// The proxy configuration contained an invalid URI.
    #[snafu(display("Failed to build Proxy connector: {}", source))]
    MakeProxyConnector { source: InvalidUri },
    /// `hyper` failed while executing the request.
    #[snafu(display("Failed to make HTTP(S) request: {}", source))]
    CallRequest { source: hyper::Error },
    /// The `http` request could not be built.
    #[snafu(display("Failed to build HTTP request: {}", source))]
    BuildRequest { source: http::Error },
}
64
65impl HttpError {
66    pub const fn is_retriable(&self) -> bool {
67        match self {
68            HttpError::BuildRequest { .. } | HttpError::MakeProxyConnector { .. } => false,
69            HttpError::CallRequest { .. }
70            | HttpError::BuildTlsConnector { .. }
71            | HttpError::MakeHttpsConnector { .. } => true,
72        }
73    }
74}
75
/// Future returned by [`HttpClient`] (with the default `Body` type) when used as a [`Service`].
pub type HttpClientFuture = <HttpClient as Service<http::Request<Body>>>::Future;
/// Connector stack used by [`HttpClient`]: proxy support layered over an OpenSSL HTTPS
/// connector, which in turn wraps `hyper`'s `HttpConnector`.
type HttpProxyConnector = ProxyConnector<HttpsConnector<HttpConnector>>;

/// HTTP(S) client with proxy support, default request headers and internal telemetry.
pub struct HttpClient<B = Body> {
    // The `hyper` client; it is built from its own clone of `proxy_connector`.
    client: Client<HttpProxyConnector, B>,
    // `User-Agent` applied to requests that don't set one themselves.
    user_agent: HeaderValue,
    // Kept alongside `client` so proxy headers can be looked up for each request URI.
    proxy_connector: HttpProxyConnector,
}
84
85impl<B> HttpClient<B>
86where
87    B: fmt::Debug + HttpBody + Send + 'static,
88    B::Data: Send,
89    B::Error: Into<crate::Error>,
90{
91    pub fn new(
92        tls_settings: impl Into<MaybeTlsSettings>,
93        proxy_config: &ProxyConfig,
94    ) -> Result<HttpClient<B>, HttpError> {
95        HttpClient::new_with_custom_client(tls_settings, proxy_config, &mut Client::builder())
96    }
97
98    pub fn new_with_custom_client(
99        tls_settings: impl Into<MaybeTlsSettings>,
100        proxy_config: &ProxyConfig,
101        client_builder: &mut client::Builder,
102    ) -> Result<HttpClient<B>, HttpError> {
103        let proxy_connector = build_proxy_connector(tls_settings.into(), proxy_config)?;
104        let client = client_builder.build(proxy_connector.clone());
105
106        let app_name = crate::get_app_name();
107        let version = crate::get_version();
108        let user_agent = HeaderValue::from_str(&format!("{app_name}/{version}"))
109            .expect("Invalid header value for user-agent!");
110
111        Ok(HttpClient {
112            client,
113            user_agent,
114            proxy_connector,
115        })
116    }
117
118    pub fn send(
119        &self,
120        mut request: Request<B>,
121    ) -> BoxFuture<'static, Result<http::Response<Body>, HttpError>> {
122        let span = tracing::info_span!("http");
123        let _enter = span.enter();
124
125        default_request_headers(&mut request, &self.user_agent);
126        self.maybe_add_proxy_headers(&mut request);
127
128        emit!(http_client::AboutToSendHttpRequest { request: &request });
129
130        let response = self.client.request(request);
131
132        let fut = async move {
133            // Capture the time right before we issue the request.
134            // Request doesn't start the processing until we start polling it.
135            let before = std::time::Instant::now();
136
137            // Send request and wait for the result.
138            let response_result = response.await;
139
140            // Compute the roundtrip time it took to send the request and get
141            // the response or error.
142            let roundtrip = before.elapsed();
143
144            // Handle the errors and extract the response.
145            let response = response_result
146                .inspect_err(|error| {
147                    // Emit the error into the internal events system.
148                    emit!(http_client::GotHttpWarning { error, roundtrip });
149                })
150                .context(CallRequestSnafu)?;
151
152            // Emit the response into the internal events system.
153            emit!(http_client::GotHttpResponse {
154                response: &response,
155                roundtrip
156            });
157            Ok(response)
158        }
159        .instrument(span.clone().or_current());
160
161        Box::pin(fut)
162    }
163
164    fn maybe_add_proxy_headers(&self, request: &mut Request<B>) {
165        if let Some(proxy_headers) = self.proxy_connector.http_headers(request.uri()) {
166            for (k, v) in proxy_headers {
167                let request_headers = request.headers_mut();
168                if !request_headers.contains_key(k) {
169                    request_headers.insert(k, v.into());
170                }
171            }
172        }
173    }
174}
175
176pub fn build_proxy_connector(
177    tls_settings: MaybeTlsSettings,
178    proxy_config: &ProxyConfig,
179) -> Result<ProxyConnector<HttpsConnector<HttpConnector>>, HttpError> {
180    // Create dedicated TLS connector for the proxied connection with user TLS settings.
181    let tls = tls_connector_builder(&tls_settings)
182        .context(BuildTlsConnectorSnafu)?
183        .build();
184    let https = build_tls_connector(tls_settings)?;
185    let mut proxy = ProxyConnector::new(https).unwrap();
186    // Make proxy connector aware of user TLS settings by setting the TLS connector:
187    // https://github.com/vectordotdev/vector/issues/13683
188    proxy.set_tls(Some(tls));
189    proxy_config
190        .configure(&mut proxy)
191        .context(MakeProxyConnectorSnafu)?;
192    Ok(proxy)
193}
194
195pub fn build_tls_connector(
196    tls_settings: MaybeTlsSettings,
197) -> Result<HttpsConnector<HttpConnector>, HttpError> {
198    let mut http = HttpConnector::new();
199    http.enforce_http(false);
200
201    let tls = tls_connector_builder(&tls_settings).context(BuildTlsConnectorSnafu)?;
202    let mut https = HttpsConnector::with_connector(http, tls).context(MakeHttpsConnectorSnafu)?;
203
204    let settings = tls_settings.tls().cloned();
205    https.set_callback(move |c, _uri| {
206        if let Some(settings) = &settings {
207            settings.apply_connect_configuration(c)
208        } else {
209            Ok(())
210        }
211    });
212    Ok(https)
213}
214
215fn default_request_headers<B>(request: &mut Request<B>, user_agent: &HeaderValue) {
216    if !request.headers().contains_key("User-Agent") {
217        request
218            .headers_mut()
219            .insert("User-Agent", user_agent.clone());
220    }
221
222    if !request.headers().contains_key("Accept-Encoding") {
223        // hardcoding until we support compressed responses:
224        // https://github.com/vectordotdev/vector/issues/5440
225        request
226            .headers_mut()
227            .insert("Accept-Encoding", HeaderValue::from_static("identity"));
228    }
229}
230
231impl<B> Service<Request<B>> for HttpClient<B>
232where
233    B: fmt::Debug + HttpBody + Send + 'static,
234    B::Data: Send,
235    B::Error: Into<crate::Error> + Send,
236{
237    type Response = http::Response<Body>;
238    type Error = HttpError;
239    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
240
241    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
242        Poll::Ready(Ok(()))
243    }
244
245    fn call(&mut self, request: Request<B>) -> Self::Future {
246        self.send(request)
247    }
248}
249
250impl<B> Clone for HttpClient<B> {
251    fn clone(&self) -> Self {
252        Self {
253            client: self.client.clone(),
254            user_agent: self.user_agent.clone(),
255            proxy_connector: self.proxy_connector.clone(),
256        }
257    }
258}
259
260impl<B> fmt::Debug for HttpClient<B> {
261    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
262        f.debug_struct("HttpClient")
263            .field("client", &self.client)
264            .field("user_agent", &self.user_agent)
265            .finish()
266    }
267}
268
/// Configuration of the authentication strategy for HTTP requests.
///
/// HTTP authentication should be used with HTTPS only, as the authentication credentials are passed as an
/// HTTP header without any additional encryption beyond what is provided by the transport itself.
// Internally tagged: the variant is chosen by a `strategy` field holding the snake_case
// variant name (`basic`, `bearer`, `aws`); unknown fields are rejected.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(deny_unknown_fields, rename_all = "snake_case", tag = "strategy")]
#[configurable(metadata(docs::enum_tag_description = "The authentication strategy to use."))]
pub enum Auth {
    /// Basic authentication.
    ///
    /// The username and password are concatenated and encoded using [base64][base64].
    ///
    /// [base64]: https://en.wikipedia.org/wiki/Base64
    Basic {
        /// The basic authentication username.
        #[configurable(metadata(docs::examples = "${USERNAME}"))]
        #[configurable(metadata(docs::examples = "username"))]
        user: String,

        /// The basic authentication password.
        // Stored as `SensitiveString`; the raw value is read via `.inner()` when the header is built.
        #[configurable(metadata(docs::examples = "${PASSWORD}"))]
        #[configurable(metadata(docs::examples = "password"))]
        password: SensitiveString,
    },

    /// Bearer authentication.
    ///
    /// The bearer token value (OAuth2, JWT, etc.) is passed as-is.
    Bearer {
        /// The bearer authentication token.
        token: SensitiveString,
    },

    // Only compiled with the `aws-core` feature. No header is added for this variant by
    // `Auth::apply_headers_map`.
    #[cfg(feature = "aws-core")]
    /// AWS authentication.
    Aws {
        /// The AWS authentication configuration.
        auth: AwsAuthentication,

        /// The AWS service name to use for signing.
        service: String,
    },
}
313
314pub trait MaybeAuth: Sized {
315    fn choose_one(&self, other: &Self) -> crate::Result<Self>;
316}
317
318impl MaybeAuth for Option<Auth> {
319    fn choose_one(&self, other: &Self) -> crate::Result<Self> {
320        if self.is_some() && other.is_some() {
321            Err("Two authorization credentials was provided.".into())
322        } else {
323            Ok(self.clone().or_else(|| other.clone()))
324        }
325    }
326}
327
328impl Auth {
329    pub fn apply<B>(&self, req: &mut Request<B>) {
330        self.apply_headers_map(req.headers_mut())
331    }
332
333    pub fn apply_builder(&self, mut builder: Builder) -> Builder {
334        if let Some(map) = builder.headers_mut() {
335            self.apply_headers_map(map)
336        }
337        builder
338    }
339
340    pub fn apply_headers_map(&self, map: &mut HeaderMap) {
341        match &self {
342            Auth::Basic { user, password } => {
343                let auth = Authorization::basic(user.as_str(), password.inner());
344                map.typed_insert(auth);
345            }
346            Auth::Bearer { token } => match Authorization::bearer(token.inner()) {
347                Ok(auth) => map.typed_insert(auth),
348                Err(error) => error!(message = "Invalid bearer token.", token = %token, %error),
349            },
350            #[cfg(feature = "aws-core")]
351            _ => {}
352        }
353    }
354}
355
356pub fn get_http_scheme_from_uri(uri: &Uri) -> &'static str {
357    // If there's no scheme, we just use "http" since it provides the most semantic relevance without inadvertently
358    // implying things it can't know i.e. returning "https" when we're not actually sure HTTPS was used.
359    uri.scheme_str().map_or("http", |scheme| match scheme {
360        "http" => "http",
361        "https" => "https",
362        // `http::Uri` ensures that we always get "http" or "https" if the URI is created with a well-formed scheme, but
363        // it also supports arbitrary schemes, which is where we bomb out down here, since we can't generate a static
364        // string for an arbitrary input string... and anything other than "http" and "https" makes no sense for an HTTP
365        // client anyways.
366        s => panic!("invalid URI scheme for HTTP client: {s}"),
367    })
368}
369
370/// Builds a [TraceLayer] configured for a HTTP server.
371///
372/// This layer emits HTTP specific telemetry for requests received, responses sent, and handler duration.
373pub fn build_http_trace_layer<T, U>(
374    span: Span,
375) -> TraceLayer<
376    SharedClassifier<ServerErrorsAsFailures>,
377    impl Fn(&Request<T>) -> Span + Clone,
378    impl Fn(&Request<T>, &Span) + Clone,
379    impl Fn(&Response<U>, Duration, &Span) + Clone,
380    (),
381    (),
382    (),
383> {
384    TraceLayer::new_for_http()
385        .make_span_with(move |request: &Request<T>| {
386            // This is an error span so that the labels are always present for metrics.
387            error_span!(
388               parent: &span,
389               "http-request",
390               method = %request.method(),
391               path = %request.uri().path(),
392            )
393        })
394        .on_request(Box::new(|_request: &Request<T>, _span: &Span| {
395            emit!(HttpServerRequestReceived);
396        }))
397        .on_response(|response: &Response<U>, latency: Duration, _span: &Span| {
398            emit!(HttpServerResponseSent { response, latency });
399        })
400        .on_failure(())
401        .on_body_chunk(())
402        .on_eos(())
403}
404
/// Configuration of HTTP server keepalive parameters.
#[serde_as]
#[configurable_component]
#[derive(Clone, Debug, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct KeepaliveConfig {
    /// The maximum amount of time a connection may exist before it is closed by sending
    /// a `Connection: close` header on the HTTP response. Set this to a large value like
    /// `100000000` to "disable" this feature.
    ///
    /// Only applies to HTTP/0.9, HTTP/1.0, and HTTP/1.1 requests.
    ///
    /// A random jitter configured by `max_connection_age_jitter_factor` is added
    /// to the specified duration to spread out connection storms.
    // Defaults to `Some(300)` when omitted (`default_max_connection_age`).
    #[serde(default = "default_max_connection_age")]
    #[configurable(metadata(docs::examples = 600))]
    #[configurable(metadata(docs::type_unit = "seconds"))]
    #[configurable(metadata(docs::human_name = "Maximum Connection Age"))]
    pub max_connection_age_secs: Option<u64>,

    /// The factor by which to jitter the `max_connection_age_secs` value.
    ///
    /// A value of 0.1 means that the actual duration will be between 90% and 110% of the
    /// specified maximum duration.
    // Defaults to 0.1 when omitted. `MaxConnectionAgeLayer::jittered_duration` also clamps the
    // value to [0.0, 1.0] when it is applied.
    #[serde(default = "default_max_connection_age_jitter_factor")]
    #[configurable(validation(range(min = 0.0, max = 1.0)))]
    pub max_connection_age_jitter_factor: f64,
}
434
/// Serde default for `max_connection_age_secs`: five minutes.
const fn default_max_connection_age() -> Option<u64> {
    const FIVE_MINUTES_IN_SECS: u64 = 5 * 60;
    Some(FIVE_MINUTES_IN_SECS)
}

/// Serde default for `max_connection_age_jitter_factor`: ±10% of the configured age.
const fn default_max_connection_age_jitter_factor() -> f64 {
    0.1
}
442
443impl Default for KeepaliveConfig {
444    fn default() -> Self {
445        Self {
446            max_connection_age_secs: default_max_connection_age(),
447            max_connection_age_jitter_factor: default_max_connection_age_jitter_factor(),
448        }
449    }
450}
451
452/// A layer that limits the maximum duration of a client connection. It does so by adding a
453/// `Connection: close` header to the response if `max_connection_duration` time has elapsed
454/// since `start_reference`.
455///
456/// **Notes:**
457/// - This is intended to be used in a Hyper server (or similar) that will automatically close
458///   the connection after a response with a `Connection: close` header is sent.
459/// - This layer assumes that it is instantiated once per connection, which is true within the
460///   Hyper framework.
461pub struct MaxConnectionAgeLayer {
462    start_reference: Instant,
463    max_connection_age: Duration,
464    peer_addr: SocketAddr,
465}
466
467impl MaxConnectionAgeLayer {
468    pub fn new(max_connection_age: Duration, jitter_factor: f64, peer_addr: SocketAddr) -> Self {
469        Self {
470            start_reference: Instant::now(),
471            max_connection_age: Self::jittered_duration(max_connection_age, jitter_factor),
472            peer_addr,
473        }
474    }
475
476    fn jittered_duration(duration: Duration, jitter_factor: f64) -> Duration {
477        // Ensure the jitter_factor is between 0.0 and 1.0
478        let jitter_factor = jitter_factor.clamp(0.0, 1.0);
479        // Generate a random jitter factor between `1 - jitter_factor`` and `1 + jitter_factor`.
480        let mut rng = rand::rng();
481        let random_jitter_factor = rng.random_range(-jitter_factor..=jitter_factor) + 1.;
482        duration.mul_f64(random_jitter_factor)
483    }
484}
485
486impl<S> Layer<S> for MaxConnectionAgeLayer
487where
488    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
489    S::Future: Send + 'static,
490{
491    type Service = MaxConnectionAgeService<S>;
492
493    fn layer(&self, service: S) -> Self::Service {
494        MaxConnectionAgeService {
495            service,
496            start_reference: self.start_reference,
497            max_connection_age: self.max_connection_age,
498            peer_addr: self.peer_addr,
499        }
500    }
501}
502
503/// A service that limits the maximum age of a client connection. It does so by adding a
504/// `Connection: close` header to the response if `max_connection_age` time has elapsed
505/// since `start_reference`.
506///
507/// **Notes:**
508/// - This is intended to be used in a Hyper server (or similar) that will automatically close
509///   the connection after a response with a `Connection: close` header is sent.
510/// - This service assumes that it is instantiated once per connection, which is true within the
511///   Hyper framework.
512#[derive(Clone)]
513pub struct MaxConnectionAgeService<S> {
514    service: S,
515    start_reference: Instant,
516    max_connection_age: Duration,
517    peer_addr: SocketAddr,
518}
519
520impl<S, E> Service<Request<Body>> for MaxConnectionAgeService<S>
521where
522    S: Service<Request<Body>, Response = Response<Body>, Error = E> + Clone + Send + 'static,
523    S::Future: Send + 'static,
524{
525    type Response = S::Response;
526    type Error = E;
527    type Future = BoxFuture<'static, Result<Self::Response, E>>;
528
529    fn poll_ready(
530        &mut self,
531        cx: &mut std::task::Context<'_>,
532    ) -> std::task::Poll<Result<(), Self::Error>> {
533        self.service.poll_ready(cx)
534    }
535
536    fn call(&mut self, req: Request<Body>) -> Self::Future {
537        let start_reference = self.start_reference;
538        let max_connection_age = self.max_connection_age;
539        let peer_addr = self.peer_addr;
540        let version = req.version();
541        let future = self.service.call(req);
542        Box::pin(async move {
543            let mut response = future.await?;
544            match version {
545                Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => {
546                    if start_reference.elapsed() >= max_connection_age {
547                        debug!(
548                            message = "Closing connection due to max connection age.",
549                            ?max_connection_age,
550                            connection_age = ?start_reference.elapsed(),
551                            ?peer_addr,
552                        );
553                        // Tell the client to close this connection.
554                        // Hyper will automatically close the connection after the response is sent.
555                        response.headers_mut().insert(
556                            hyper::header::CONNECTION,
557                            hyper::header::HeaderValue::from_static("close"),
558                        );
559                    }
560                }
561                // TODO need to send GOAWAY frame
562                Version::HTTP_2 => (),
563                // TODO need to send GOAWAY frame
564                Version::HTTP_3 => (),
565                _ => (),
566            }
567            Ok(response)
568        })
569    }
570}
571
/// The type of a query parameter's value, which determines whether it's treated as a plain string or a VRL expression.
// Serialized in snake_case: `"string"` or `"vrl"`.
#[configurable_component]
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ParamType {
    /// The parameter value is a plain string.
    // The `Default` variant; `ParamType::is_default` uses it to omit `type` when serializing.
    #[default]
    String,
    /// The parameter value is a VRL expression that will be evaluated before each request.
    Vrl,
}
583
584impl ParamType {
585    fn is_default(&self) -> bool {
586        *self == Self::default()
587    }
588}
589
/// Represents a query parameter value, which can be a simple string or a typed object
/// indicating whether the value is a string or a VRL expression.
// `untagged`: serde tries the variants in declaration order, so a bare string becomes
// `String` and only an object with `value` (and optionally `type`) becomes `Typed`.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum ParameterValue {
    /// A simple string value. For backwards compatibility.
    String(String),
    /// A value with an explicit type.
    Typed {
        /// The raw value of the parameter.
        value: String,
        /// The type of the parameter, indicating how the `value` should be treated.
        // Read from / written to the `type` key; falls back to `ParamType::default()` when
        // missing and is omitted on output when it equals the default.
        #[serde(
            default,
            skip_serializing_if = "ParamType::is_default",
            rename = "type"
        )]
        r#type: ParamType,
    },
}
611
612impl ParameterValue {
613    /// Returns true if the parameter is a VRL expression.
614    pub const fn is_vrl(&self) -> bool {
615        match self {
616            ParameterValue::String(_) => false,
617            ParameterValue::Typed { r#type, .. } => matches!(r#type, ParamType::Vrl),
618        }
619    }
620
621    /// Returns the raw string value of the parameter.
622    #[allow(clippy::missing_const_for_fn)]
623    pub fn value(&self) -> &str {
624        match self {
625            ParameterValue::String(s) => s,
626            ParameterValue::Typed { value, .. } => value,
627        }
628    }
629
630    /// Consumes the `ParameterValue` and returns the owned raw string value.
631    pub fn into_value(self) -> String {
632        match self {
633            ParameterValue::String(s) => s,
634            ParameterValue::Typed { value, .. } => value,
635        }
636    }
637}
638
/// Configuration of the query parameter value for HTTP requests.
// `untagged`: a single string/object is tried first as `SingleParam`; a list falls through to
// `MultiParams`.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
#[configurable(metadata(docs::enum_tag_description = "Query parameter value"))]
pub enum QueryParameterValue {
    /// Query parameter with a single value.
    SingleParam(ParameterValue),
    /// Query parameter with multiple values.
    MultiParams(Vec<ParameterValue>),
}
650
651impl QueryParameterValue {
652    /// Returns an iterator over the contained `ParameterValue`s.
653    pub fn iter(&self) -> impl Iterator<Item = &ParameterValue> {
654        match self {
655            QueryParameterValue::SingleParam(param) => std::slice::from_ref(param).iter(),
656            QueryParameterValue::MultiParams(params) => params.iter(),
657        }
658    }
659
660    /// Convert to `Vec<ParameterValue>` for owned iteration.
661    fn into_vec(self) -> Vec<ParameterValue> {
662        match self {
663            QueryParameterValue::SingleParam(param) => vec![param],
664            QueryParameterValue::MultiParams(params) => params,
665        }
666    }
667}
668
669// Implement IntoIterator for owned QueryParameterValue
670impl IntoIterator for QueryParameterValue {
671    type Item = ParameterValue;
672    type IntoIter = std::vec::IntoIter<ParameterValue>;
673
674    fn into_iter(self) -> Self::IntoIter {
675        self.into_vec().into_iter()
676    }
677}
678
/// Query parameters for HTTP requests: a map from key to one or more [`ParameterValue`]s.
pub type QueryParameters = HashMap<String, QueryParameterValue>;
680
681#[cfg(test)]
682mod tests {
683    use std::convert::Infallible;
684
685    use hyper::{server::conn::AddrStream, service::make_service_fn, Server};
686    use proptest::prelude::*;
687    use tower::ServiceBuilder;
688
689    use crate::test_util::next_addr;
690
691    use super::*;
692
693    #[test]
694    fn test_default_request_headers_defaults() {
695        let user_agent = HeaderValue::from_static("vector");
696        let mut request = Request::post("http://example.com").body(()).unwrap();
697        default_request_headers(&mut request, &user_agent);
698        assert_eq!(
699            request.headers().get("Accept-Encoding"),
700            Some(&HeaderValue::from_static("identity")),
701        );
702        assert_eq!(request.headers().get("User-Agent"), Some(&user_agent));
703    }
704
705    #[test]
706    fn test_default_request_headers_does_not_overwrite() {
707        let mut request = Request::post("http://example.com")
708            .header("Accept-Encoding", "gzip")
709            .header("User-Agent", "foo")
710            .body(())
711            .unwrap();
712        default_request_headers(&mut request, &HeaderValue::from_static("vector"));
713        assert_eq!(
714            request.headers().get("Accept-Encoding"),
715            Some(&HeaderValue::from_static("gzip")),
716        );
717        assert_eq!(
718            request.headers().get("User-Agent"),
719            Some(&HeaderValue::from_static("foo"))
720        );
721    }
722
723    proptest! {
724        #[test]
725        fn test_jittered_duration(duration_in_secs in 0u64..120, jitter_factor in 0.0..1.0) {
726            let duration = Duration::from_secs(duration_in_secs);
727            let jittered_duration = MaxConnectionAgeLayer::jittered_duration(duration, jitter_factor);
728
729            // Check properties based on the range of inputs
730            if jitter_factor == 0.0 {
731                // When jitter_factor is 0, jittered_duration should be equal to the original duration
732                prop_assert_eq!(
733                    jittered_duration,
734                    duration,
735                    "jittered_duration {:?} should be equal to duration {:?}",
736                    jittered_duration,
737                    duration,
738                );
739            } else if duration_in_secs > 0 {
740                // Check the bounds when duration is non-zero and jitter_factor is non-zero
741                let lower_bound = duration.mul_f64(1.0 - jitter_factor);
742                let upper_bound = duration.mul_f64(1.0 + jitter_factor);
743                prop_assert!(
744                    jittered_duration >= lower_bound && jittered_duration <= upper_bound,
745                    "jittered_duration {:?} should be between {:?} and {:?}",
746                    jittered_duration,
747                    lower_bound,
748                    upper_bound,
749                );
750            } else {
751                // When duration is zero, jittered_duration should also be zero
752                prop_assert_eq!(
753                    jittered_duration,
754                    Duration::from_secs(0),
755                    "jittered_duration {:?} should be equal to zero",
756                    jittered_duration,
757                );
758            }
759        }
760    }
761
762    #[tokio::test]
763    async fn test_max_connection_age_service() {
764        tokio::time::pause();
765
766        let start_reference = Instant::now();
767        let max_connection_age = Duration::from_secs(1);
768        let mut service = MaxConnectionAgeService {
769            service: tower::service_fn(|_req: Request<Body>| async {
770                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
771            }),
772            start_reference,
773            max_connection_age,
774            peer_addr: "1.2.3.4:1234".parse().unwrap(),
775        };
776
777        let req = Request::get("http://example.com")
778            .body(Body::empty())
779            .unwrap();
780        let response = service.call(req).await.unwrap();
781        assert_eq!(response.headers().get("Connection"), None);
782
783        tokio::time::advance(Duration::from_millis(500)).await;
784        let req = Request::get("http://example.com")
785            .body(Body::empty())
786            .unwrap();
787        let response = service.call(req).await.unwrap();
788        assert_eq!(response.headers().get("Connection"), None);
789
790        tokio::time::advance(Duration::from_millis(500)).await;
791        let req = Request::get("http://example.com")
792            .body(Body::empty())
793            .unwrap();
794        let response = service.call(req).await.unwrap();
795        assert_eq!(
796            response.headers().get("Connection"),
797            Some(&HeaderValue::from_static("close"))
798        );
799    }
800
801    #[tokio::test]
802    async fn test_max_connection_age_service_http2() {
803        tokio::time::pause();
804
805        let start_reference = Instant::now();
806        let max_connection_age = Duration::from_secs(0);
807        let mut service = MaxConnectionAgeService {
808            service: tower::service_fn(|_req: Request<Body>| async {
809                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
810            }),
811            start_reference,
812            max_connection_age,
813            peer_addr: "1.2.3.4:1234".parse().unwrap(),
814        };
815
816        let mut req = Request::get("http://example.com")
817            .body(Body::empty())
818            .unwrap();
819        *req.version_mut() = Version::HTTP_2;
820        let response = service.call(req).await.unwrap();
821        assert_eq!(response.headers().get("Connection"), None);
822    }
823
824    #[tokio::test]
825    async fn test_max_connection_age_service_http3() {
826        tokio::time::pause();
827
828        let start_reference = Instant::now();
829        let max_connection_age = Duration::from_secs(0);
830        let mut service = MaxConnectionAgeService {
831            service: tower::service_fn(|_req: Request<Body>| async {
832                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
833            }),
834            start_reference,
835            max_connection_age,
836            peer_addr: "1.2.3.4:1234".parse().unwrap(),
837        };
838
839        let mut req = Request::get("http://example.com")
840            .body(Body::empty())
841            .unwrap();
842        *req.version_mut() = Version::HTTP_3;
843        let response = service.call(req).await.unwrap();
844        assert_eq!(response.headers().get("Connection"), None);
845    }
846
847    #[tokio::test]
848    async fn test_max_connection_age_service_zero_duration() {
849        tokio::time::pause();
850
851        let start_reference = Instant::now();
852        let max_connection_age = Duration::from_millis(0);
853        let mut service = MaxConnectionAgeService {
854            service: tower::service_fn(|_req: Request<Body>| async {
855                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
856            }),
857            start_reference,
858            max_connection_age,
859            peer_addr: "1.2.3.4:1234".parse().unwrap(),
860        };
861
862        let req = Request::get("http://example.com")
863            .body(Body::empty())
864            .unwrap();
865        let response = service.call(req).await.unwrap();
866        assert_eq!(
867            response.headers().get("Connection"),
868            Some(&HeaderValue::from_static("close"))
869        );
870    }
871
872    // Note that we unfortunately cannot mock the time in this test because the client calls
873    // sleep internally, which advances the clock.  However, this test shouldn't be flakey given
874    // the time bounds provided.
875    #[tokio::test]
876    async fn test_max_connection_age_service_with_hyper_server() {
877        // Create a hyper server with the max connection age layer.
878        let max_connection_age = Duration::from_secs(1);
879        let addr = next_addr();
880        let make_svc = make_service_fn(move |conn: &AddrStream| {
881            let svc = ServiceBuilder::new()
882                .layer(MaxConnectionAgeLayer::new(
883                    max_connection_age,
884                    0.,
885                    conn.remote_addr(),
886                ))
887                .service(tower::service_fn(|_req: Request<Body>| async {
888                    Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
889                }));
890            futures_util::future::ok::<_, Infallible>(svc)
891        });
892
893        tokio::spawn(async move {
894            Server::bind(&addr).serve(make_svc).await.unwrap();
895        });
896
897        // Wait for the server to start.
898        tokio::time::sleep(Duration::from_millis(10)).await;
899
900        // Create a client, which has its own connection pool.
901        let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();
902
903        // Responses generated before the client's max connection age has elapsed do not
904        // include a `Connection: close` header in the response.
905        let req = Request::get(format!("http://{addr}/"))
906            .body(Body::empty())
907            .unwrap();
908        let response = client.send(req).await.unwrap();
909        assert_eq!(response.headers().get("Connection"), None);
910
911        let req = Request::get(format!("http://{addr}/"))
912            .body(Body::empty())
913            .unwrap();
914        let response = client.send(req).await.unwrap();
915        assert_eq!(response.headers().get("Connection"), None);
916
917        // The first response generated after the client's max connection age has elapsed should
918        // include the `Connection: close` header.
919        tokio::time::sleep(Duration::from_secs(1)).await;
920        let req = Request::get(format!("http://{addr}/"))
921            .body(Body::empty())
922            .unwrap();
923        let response = client.send(req).await.unwrap();
924        assert_eq!(
925            response.headers().get("Connection"),
926            Some(&HeaderValue::from_static("close")),
927        );
928
929        // The next request should establish a new connection.
930        // Importantly, this also confirms that each connection has its own independent
931        // connection age timer.
932        let req = Request::get(format!("http://{addr}/"))
933            .body(Body::empty())
934            .unwrap();
935        let response = client.send(req).await.unwrap();
936        assert_eq!(response.headers().get("Connection"), None);
937    }
938}