vector/
http.rs

1#![allow(missing_docs)]
2use std::{
3    collections::HashMap,
4    fmt,
5    net::SocketAddr,
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use futures::future::BoxFuture;
11use headers::{Authorization, HeaderMapExt};
12use http::{
13    HeaderMap, Request, Response, Uri, Version, header::HeaderValue, request::Builder,
14    uri::InvalidUri,
15};
16use hyper::{
17    body::{Body, HttpBody},
18    client,
19    client::{Client, HttpConnector},
20};
21use hyper_openssl::HttpsConnector;
22use hyper_proxy::ProxyConnector;
23use rand::Rng;
24use serde_with::serde_as;
25use snafu::{ResultExt, Snafu};
26use tokio::time::Instant;
27use tower::{Layer, Service};
28use tower_http::{
29    classify::{ServerErrorsAsFailures, SharedClassifier},
30    trace::TraceLayer,
31};
32use tracing::{Instrument, Span};
33use vector_lib::{configurable::configurable_component, sensitive_string::SensitiveString};
34
35#[cfg(feature = "aws-core")]
36use crate::aws::AwsAuthentication;
37use crate::{
38    config::ProxyConfig,
39    internal_events::{HttpServerRequestReceived, HttpServerResponseSent, http_client},
40    tls::{MaybeTlsSettings, TlsError, tls_connector_builder},
41};
42
/// Raw `u16` values of HTTP status codes.
pub mod status {
    /// `403 Forbidden`.
    pub const FORBIDDEN: u16 = 403;

    /// `404 Not Found`.
    pub const NOT_FOUND: u16 = 404;

    /// `429 Too Many Requests`.
    pub const TOO_MANY_REQUESTS: u16 = 429;
}
48
/// Errors raised while building the HTTP client's connectors or while sending a request.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum HttpError {
    /// Building a TLS connector from the user's TLS settings failed.
    #[snafu(display("Failed to build TLS connector: {}", source))]
    BuildTlsConnector { source: TlsError },
    /// Wrapping the TCP connector in an OpenSSL HTTPS connector failed.
    #[snafu(display("Failed to build HTTPS connector: {}", source))]
    MakeHttpsConnector { source: openssl::error::ErrorStack },
    /// The proxy configuration could not be applied (it contains an invalid URI).
    #[snafu(display("Failed to build Proxy connector: {}", source))]
    MakeProxyConnector { source: InvalidUri },
    /// `hyper` failed while performing the request.
    #[snafu(display("Failed to make HTTP(S) request: {}", source))]
    CallRequest { source: hyper::Error },
    /// Building an `http::Request` failed (not raised in this file; presumably used by callers).
    #[snafu(display("Failed to build HTTP request: {}", source))]
    BuildRequest { source: http::Error },
}
63
64impl HttpError {
65    pub const fn is_retriable(&self) -> bool {
66        match self {
67            HttpError::BuildRequest { .. } | HttpError::MakeProxyConnector { .. } => false,
68            HttpError::CallRequest { .. }
69            | HttpError::BuildTlsConnector { .. }
70            | HttpError::MakeHttpsConnector { .. } => true,
71        }
72    }
73}
74
/// The future returned by [`HttpClient`]'s `Service::call` for a `hyper` `Body` request.
pub type HttpClientFuture = <HttpClient as Service<http::Request<Body>>>::Future;
/// The connector stack used by [`HttpClient`]: a proxy-aware connector over an HTTPS
/// connector over a plain `HttpConnector`.
type HttpProxyConnector = ProxyConnector<HttpsConnector<HttpConnector>>;
77
/// An HTTP(S) client that fills in default headers (`User-Agent`, `Accept-Encoding`) and any
/// proxy headers on each request and emits internal events around every call.
pub struct HttpClient<B = Body> {
    // The underlying `hyper` client, built over a clone of `proxy_connector`.
    client: Client<HttpProxyConnector, B>,
    // Sent as `User-Agent` when a request does not already carry one.
    user_agent: HeaderValue,
    // Kept alongside `client` so per-request proxy headers can be looked up by URI.
    proxy_connector: HttpProxyConnector,
}
83
84impl<B> HttpClient<B>
85where
86    B: fmt::Debug + HttpBody + Send + 'static,
87    B::Data: Send,
88    B::Error: Into<crate::Error>,
89{
90    pub fn new(
91        tls_settings: impl Into<MaybeTlsSettings>,
92        proxy_config: &ProxyConfig,
93    ) -> Result<HttpClient<B>, HttpError> {
94        HttpClient::new_with_custom_client(tls_settings, proxy_config, &mut Client::builder())
95    }
96
97    pub fn new_with_custom_client(
98        tls_settings: impl Into<MaybeTlsSettings>,
99        proxy_config: &ProxyConfig,
100        client_builder: &mut client::Builder,
101    ) -> Result<HttpClient<B>, HttpError> {
102        let proxy_connector = build_proxy_connector(tls_settings.into(), proxy_config)?;
103        let client = client_builder.build(proxy_connector.clone());
104
105        let app_name = crate::get_app_name();
106        let version = crate::get_version();
107        let user_agent = HeaderValue::from_str(&format!("{app_name}/{version}"))
108            .expect("Invalid header value for user-agent!");
109
110        Ok(HttpClient {
111            client,
112            user_agent,
113            proxy_connector,
114        })
115    }
116
117    pub fn send(
118        &self,
119        mut request: Request<B>,
120    ) -> BoxFuture<'static, Result<http::Response<Body>, HttpError>> {
121        let span = tracing::info_span!("http");
122        let _enter = span.enter();
123
124        default_request_headers(&mut request, &self.user_agent);
125        self.maybe_add_proxy_headers(&mut request);
126
127        emit!(http_client::AboutToSendHttpRequest { request: &request });
128
129        let response = self.client.request(request);
130
131        let fut = async move {
132            // Capture the time right before we issue the request.
133            // Request doesn't start the processing until we start polling it.
134            let before = std::time::Instant::now();
135
136            // Send request and wait for the result.
137            let response_result = response.await;
138
139            // Compute the roundtrip time it took to send the request and get
140            // the response or error.
141            let roundtrip = before.elapsed();
142
143            // Handle the errors and extract the response.
144            let response = response_result
145                .inspect_err(|error| {
146                    // Emit the error into the internal events system.
147                    emit!(http_client::GotHttpWarning { error, roundtrip });
148                })
149                .context(CallRequestSnafu)?;
150
151            // Emit the response into the internal events system.
152            emit!(http_client::GotHttpResponse {
153                response: &response,
154                roundtrip
155            });
156            Ok(response)
157        }
158        .instrument(span.clone().or_current());
159
160        Box::pin(fut)
161    }
162
163    fn maybe_add_proxy_headers(&self, request: &mut Request<B>) {
164        if let Some(proxy_headers) = self.proxy_connector.http_headers(request.uri()) {
165            for (k, v) in proxy_headers {
166                let request_headers = request.headers_mut();
167                if !request_headers.contains_key(k) {
168                    request_headers.insert(k, v.into());
169                }
170            }
171        }
172    }
173}
174
175pub fn build_proxy_connector(
176    tls_settings: MaybeTlsSettings,
177    proxy_config: &ProxyConfig,
178) -> Result<ProxyConnector<HttpsConnector<HttpConnector>>, HttpError> {
179    // Create dedicated TLS connector for the proxied connection with user TLS settings.
180    let tls = tls_connector_builder(&tls_settings)
181        .context(BuildTlsConnectorSnafu)?
182        .build();
183    let https = build_tls_connector(tls_settings)?;
184    let mut proxy = ProxyConnector::new(https).unwrap();
185    // Make proxy connector aware of user TLS settings by setting the TLS connector:
186    // https://github.com/vectordotdev/vector/issues/13683
187    proxy.set_tls(Some(tls));
188    proxy_config
189        .configure(&mut proxy)
190        .context(MakeProxyConnectorSnafu)?;
191    Ok(proxy)
192}
193
194pub fn build_tls_connector(
195    tls_settings: MaybeTlsSettings,
196) -> Result<HttpsConnector<HttpConnector>, HttpError> {
197    let mut http = HttpConnector::new();
198    http.enforce_http(false);
199
200    let tls = tls_connector_builder(&tls_settings).context(BuildTlsConnectorSnafu)?;
201    let mut https = HttpsConnector::with_connector(http, tls).context(MakeHttpsConnectorSnafu)?;
202
203    let settings = tls_settings.tls().cloned();
204    https.set_callback(move |c, _uri| {
205        if let Some(settings) = &settings {
206            settings.apply_connect_configuration(c)
207        } else {
208            Ok(())
209        }
210    });
211    Ok(https)
212}
213
214fn default_request_headers<B>(request: &mut Request<B>, user_agent: &HeaderValue) {
215    if !request.headers().contains_key("User-Agent") {
216        request
217            .headers_mut()
218            .insert("User-Agent", user_agent.clone());
219    }
220
221    if !request.headers().contains_key("Accept-Encoding") {
222        // hardcoding until we support compressed responses:
223        // https://github.com/vectordotdev/vector/issues/5440
224        request
225            .headers_mut()
226            .insert("Accept-Encoding", HeaderValue::from_static("identity"));
227    }
228}
229
230impl<B> Service<Request<B>> for HttpClient<B>
231where
232    B: fmt::Debug + HttpBody + Send + 'static,
233    B::Data: Send,
234    B::Error: Into<crate::Error> + Send,
235{
236    type Response = http::Response<Body>;
237    type Error = HttpError;
238    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
239
240    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
241        Poll::Ready(Ok(()))
242    }
243
244    fn call(&mut self, request: Request<B>) -> Self::Future {
245        self.send(request)
246    }
247}
248
249impl<B> Clone for HttpClient<B> {
250    fn clone(&self) -> Self {
251        Self {
252            client: self.client.clone(),
253            user_agent: self.user_agent.clone(),
254            proxy_connector: self.proxy_connector.clone(),
255        }
256    }
257}
258
259impl<B> fmt::Debug for HttpClient<B> {
260    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261        f.debug_struct("HttpClient")
262            .field("client", &self.client)
263            .field("user_agent", &self.user_agent)
264            .finish()
265    }
266}
267
/// Configuration of the authentication strategy for HTTP requests.
///
/// HTTP authentication should be used with HTTPS only, as the authentication credentials are passed as an
/// HTTP header without any additional encryption beyond what is provided by the transport itself.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(deny_unknown_fields, rename_all = "snake_case", tag = "strategy")]
#[configurable(metadata(docs::enum_tag_description = "The authentication strategy to use."))]
pub enum Auth {
    // Turned into an `Authorization: Basic ...` header by `Auth::apply_headers_map`.
    /// Basic authentication.
    ///
    /// The username and password are concatenated and encoded using [base64][base64].
    ///
    /// [base64]: https://en.wikipedia.org/wiki/Base64
    Basic {
        /// The basic authentication username.
        #[configurable(metadata(docs::examples = "${USERNAME}"))]
        #[configurable(metadata(docs::examples = "username"))]
        user: String,

        /// The basic authentication password.
        #[configurable(metadata(docs::examples = "${PASSWORD}"))]
        #[configurable(metadata(docs::examples = "password"))]
        password: SensitiveString,
    },

    // Turned into an `Authorization: Bearer ...` header by `Auth::apply_headers_map`; a token
    // that is not a valid header value is logged and no header is added.
    /// Bearer authentication.
    ///
    /// The bearer token value (OAuth2, JWT, etc.) is passed as-is.
    Bearer {
        /// The bearer authentication token.
        token: SensitiveString,
    },

    // Only compiled with the `aws-core` feature. `Auth::apply_headers_map` adds no header for
    // this variant; presumably requests are signed elsewhere (confirm with callers).
    #[cfg(feature = "aws-core")]
    /// AWS authentication.
    Aws {
        /// The AWS authentication configuration.
        auth: AwsAuthentication,

        /// The AWS service name to use for signing.
        service: String,
    },
}
312
/// Combines two optional authentication configurations into at most one.
pub trait MaybeAuth: Sized {
    /// Returns the combination of `self` and `other`, or an error when they conflict.
    ///
    /// For `Option<Auth>`, supplying both is an error; otherwise whichever one is set is
    /// returned, or `None` if neither is set.
    fn choose_one(&self, other: &Self) -> crate::Result<Self>;
}
316
317impl MaybeAuth for Option<Auth> {
318    fn choose_one(&self, other: &Self) -> crate::Result<Self> {
319        if self.is_some() && other.is_some() {
320            Err("Two authorization credentials was provided.".into())
321        } else {
322            Ok(self.clone().or_else(|| other.clone()))
323        }
324    }
325}
326
327impl Auth {
328    pub fn apply<B>(&self, req: &mut Request<B>) {
329        self.apply_headers_map(req.headers_mut())
330    }
331
332    pub fn apply_builder(&self, mut builder: Builder) -> Builder {
333        if let Some(map) = builder.headers_mut() {
334            self.apply_headers_map(map)
335        }
336        builder
337    }
338
339    pub fn apply_headers_map(&self, map: &mut HeaderMap) {
340        match &self {
341            Auth::Basic { user, password } => {
342                let auth = Authorization::basic(user.as_str(), password.inner());
343                map.typed_insert(auth);
344            }
345            Auth::Bearer { token } => match Authorization::bearer(token.inner()) {
346                Ok(auth) => map.typed_insert(auth),
347                Err(error) => error!(message = "Invalid bearer token.", token = %token, %error),
348            },
349            #[cfg(feature = "aws-core")]
350            _ => {}
351        }
352    }
353}
354
355pub fn get_http_scheme_from_uri(uri: &Uri) -> &'static str {
356    // If there's no scheme, we just use "http" since it provides the most semantic relevance without inadvertently
357    // implying things it can't know i.e. returning "https" when we're not actually sure HTTPS was used.
358    uri.scheme_str().map_or("http", |scheme| match scheme {
359        "http" => "http",
360        "https" => "https",
361        // `http::Uri` ensures that we always get "http" or "https" if the URI is created with a well-formed scheme, but
362        // it also supports arbitrary schemes, which is where we bomb out down here, since we can't generate a static
363        // string for an arbitrary input string... and anything other than "http" and "https" makes no sense for an HTTP
364        // client anyways.
365        s => panic!("invalid URI scheme for HTTP client: {s}"),
366    })
367}
368
369/// Builds a [TraceLayer] configured for a HTTP server.
370///
371/// This layer emits HTTP specific telemetry for requests received, responses sent, and handler duration.
372pub fn build_http_trace_layer<T, U>(
373    span: Span,
374) -> TraceLayer<
375    SharedClassifier<ServerErrorsAsFailures>,
376    impl Fn(&Request<T>) -> Span + Clone,
377    impl Fn(&Request<T>, &Span) + Clone,
378    impl Fn(&Response<U>, Duration, &Span) + Clone,
379    (),
380    (),
381    (),
382> {
383    TraceLayer::new_for_http()
384        .make_span_with(move |request: &Request<T>| {
385            // This is an error span so that the labels are always present for metrics.
386            error_span!(
387               parent: &span,
388               "http-request",
389               method = %request.method(),
390               path = %request.uri().path(),
391            )
392        })
393        .on_request(Box::new(|_request: &Request<T>, _span: &Span| {
394            emit!(HttpServerRequestReceived);
395        }))
396        .on_response(|response: &Response<U>, latency: Duration, _span: &Span| {
397            emit!(HttpServerResponseSent { response, latency });
398        })
399        .on_failure(())
400        .on_body_chunk(())
401        .on_eos(())
402}
403
/// Configuration of HTTP server keepalive parameters.
#[serde_as]
#[configurable_component]
#[derive(Clone, Debug, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct KeepaliveConfig {
    /// The maximum amount of time a connection may exist before it is closed by sending
    /// a `Connection: close` header on the HTTP response. Set this to a large value like
    /// `100000000` to "disable" this feature.
    ///
    /// Only applies to HTTP/0.9, HTTP/1.0, and HTTP/1.1 requests.
    ///
    /// A random jitter configured by `max_connection_age_jitter_factor` is added
    /// to the specified duration to spread out connection storms.
    #[serde(default = "default_max_connection_age")]
    #[configurable(metadata(docs::examples = 600))]
    #[configurable(metadata(docs::type_unit = "seconds"))]
    #[configurable(metadata(docs::human_name = "Maximum Connection Age"))]
    pub max_connection_age_secs: Option<u64>,

    /// The factor by which to jitter the `max_connection_age_secs` value.
    ///
    /// A value of 0.1 means that the actual duration will be between 90% and 110% of the
    /// specified maximum duration.
    #[serde(default = "default_max_connection_age_jitter_factor")]
    #[configurable(validation(range(min = 0.0, max = 1.0)))]
    pub max_connection_age_jitter_factor: f64,
}
433
/// Serde default for `KeepaliveConfig::max_connection_age_secs`: 300 seconds (5 minutes).
const fn default_max_connection_age() -> Option<u64> {
    const FIVE_MINUTES_IN_SECS: u64 = 5 * 60;
    Some(FIVE_MINUTES_IN_SECS)
}

/// Serde default for `KeepaliveConfig::max_connection_age_jitter_factor`: ±10%.
const fn default_max_connection_age_jitter_factor() -> f64 {
    0.1
}
441
442impl Default for KeepaliveConfig {
443    fn default() -> Self {
444        Self {
445            max_connection_age_secs: default_max_connection_age(),
446            max_connection_age_jitter_factor: default_max_connection_age_jitter_factor(),
447        }
448    }
449}
450
/// A layer that limits the maximum duration of a client connection. It does so by adding a
/// `Connection: close` header to the response if `max_connection_age` time has elapsed
/// since `start_reference`.
///
/// **Notes:**
/// - This is intended to be used in a Hyper server (or similar) that will automatically close
///   the connection after a response with a `Connection: close` header is sent.
/// - This layer assumes that it is instantiated once per connection, which is true within the
///   Hyper framework.
pub struct MaxConnectionAgeLayer {
    // Set to `Instant::now()` in `new`; the connection's age is measured from here.
    start_reference: Instant,
    // Already jittered in `new`; copied unchanged into each wrapped service.
    max_connection_age: Duration,
    // The peer's address, passed through to the service for its debug log.
    peer_addr: SocketAddr,
}
465
466impl MaxConnectionAgeLayer {
467    pub fn new(max_connection_age: Duration, jitter_factor: f64, peer_addr: SocketAddr) -> Self {
468        Self {
469            start_reference: Instant::now(),
470            max_connection_age: Self::jittered_duration(max_connection_age, jitter_factor),
471            peer_addr,
472        }
473    }
474
475    fn jittered_duration(duration: Duration, jitter_factor: f64) -> Duration {
476        // Ensure the jitter_factor is between 0.0 and 1.0
477        let jitter_factor = jitter_factor.clamp(0.0, 1.0);
478        // Generate a random jitter factor between `1 - jitter_factor`` and `1 + jitter_factor`.
479        let mut rng = rand::rng();
480        let random_jitter_factor = rng.random_range(-jitter_factor..=jitter_factor) + 1.;
481        duration.mul_f64(random_jitter_factor)
482    }
483}
484
485impl<S> Layer<S> for MaxConnectionAgeLayer
486where
487    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
488    S::Future: Send + 'static,
489{
490    type Service = MaxConnectionAgeService<S>;
491
492    fn layer(&self, service: S) -> Self::Service {
493        MaxConnectionAgeService {
494            service,
495            start_reference: self.start_reference,
496            max_connection_age: self.max_connection_age,
497            peer_addr: self.peer_addr,
498        }
499    }
500}
501
/// A service that limits the maximum age of a client connection. It does so by adding a
/// `Connection: close` header to the response if `max_connection_age` time has elapsed
/// since `start_reference`.
///
/// **Notes:**
/// - This is intended to be used in a Hyper server (or similar) that will automatically close
///   the connection after a response with a `Connection: close` header is sent.
/// - This service assumes that it is instantiated once per connection, which is true within the
///   Hyper framework.
#[derive(Clone)]
pub struct MaxConnectionAgeService<S> {
    // The wrapped service that produces the responses.
    service: S,
    // When the connection's age clock started.
    start_reference: Instant,
    // Age at or beyond which HTTP/1.x responses get `Connection: close`
    // (already jittered when built through `MaxConnectionAgeLayer`).
    max_connection_age: Duration,
    // Peer address; only used in the debug log emitted when closing.
    peer_addr: SocketAddr,
}
518
519impl<S, E> Service<Request<Body>> for MaxConnectionAgeService<S>
520where
521    S: Service<Request<Body>, Response = Response<Body>, Error = E> + Clone + Send + 'static,
522    S::Future: Send + 'static,
523{
524    type Response = S::Response;
525    type Error = E;
526    type Future = BoxFuture<'static, Result<Self::Response, E>>;
527
528    fn poll_ready(
529        &mut self,
530        cx: &mut std::task::Context<'_>,
531    ) -> std::task::Poll<Result<(), Self::Error>> {
532        self.service.poll_ready(cx)
533    }
534
535    fn call(&mut self, req: Request<Body>) -> Self::Future {
536        let start_reference = self.start_reference;
537        let max_connection_age = self.max_connection_age;
538        let peer_addr = self.peer_addr;
539        let version = req.version();
540        let future = self.service.call(req);
541        Box::pin(async move {
542            let mut response = future.await?;
543            match version {
544                Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => {
545                    if start_reference.elapsed() >= max_connection_age {
546                        debug!(
547                            message = "Closing connection due to max connection age.",
548                            ?max_connection_age,
549                            connection_age = ?start_reference.elapsed(),
550                            ?peer_addr,
551                        );
552                        // Tell the client to close this connection.
553                        // Hyper will automatically close the connection after the response is sent.
554                        response.headers_mut().insert(
555                            hyper::header::CONNECTION,
556                            hyper::header::HeaderValue::from_static("close"),
557                        );
558                    }
559                }
560                // TODO need to send GOAWAY frame
561                Version::HTTP_2 => (),
562                // TODO need to send GOAWAY frame
563                Version::HTTP_3 => (),
564                _ => (),
565            }
566            Ok(response)
567        })
568    }
569}
570
/// The type of a query parameter's value, which determines whether it is treated as a plain string or a VRL expression.
// Serialized in snake_case (`string` / `vrl`).
#[configurable_component]
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ParamType {
    /// The parameter value is a plain string.
    // The default: `ParamType::is_default` compares against it, so `ParameterValue::Typed`
    // omits its `type` field when serializing a `String`-typed value.
    #[default]
    String,
    /// The parameter value is a VRL expression that will be evaluated before each request.
    Vrl,
}
582
583impl ParamType {
584    fn is_default(&self) -> bool {
585        *self == Self::default()
586    }
587}
588
/// Represents a query parameter value, which can be a simple string or a typed object
/// indicating whether the value is a string or a VRL expression.
// `untagged`: serde tries the variants in declaration order, so a bare string deserializes
// as `String` and an object with a `value` key as `Typed`.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum ParameterValue {
    /// A simple string value. For backwards compatibility.
    String(String),
    /// A value with an explicit type.
    Typed {
        /// The raw value of the parameter.
        value: String,
        /// The type of the parameter, indicating how the `value` should be treated.
        // Defaults to `ParamType::String` when absent and is omitted on output when it is
        // the default.
        #[serde(
            default,
            skip_serializing_if = "ParamType::is_default",
            rename = "type"
        )]
        r#type: ParamType,
    },
}
610
611impl ParameterValue {
612    /// Returns true if the parameter is a VRL expression.
613    pub const fn is_vrl(&self) -> bool {
614        match self {
615            ParameterValue::String(_) => false,
616            ParameterValue::Typed { r#type, .. } => matches!(r#type, ParamType::Vrl),
617        }
618    }
619
620    /// Returns the raw string value of the parameter.
621    #[allow(clippy::missing_const_for_fn)]
622    pub fn value(&self) -> &str {
623        match self {
624            ParameterValue::String(s) => s,
625            ParameterValue::Typed { value, .. } => value,
626        }
627    }
628
629    /// Consumes the `ParameterValue` and returns the owned raw string value.
630    pub fn into_value(self) -> String {
631        match self {
632            ParameterValue::String(s) => s,
633            ParameterValue::Typed { value, .. } => value,
634        }
635    }
636}
637
/// Configuration of the query parameter value for HTTP requests.
// `untagged`: a single value deserializes as `SingleParam`, a list as `MultiParams`.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
#[configurable(metadata(docs::enum_tag_description = "Query parameter value"))]
pub enum QueryParameterValue {
    /// Query parameter with single value
    SingleParam(ParameterValue),
    /// Query parameter with multiple values
    MultiParams(Vec<ParameterValue>),
}
649
650impl QueryParameterValue {
651    /// Returns an iterator over the contained `ParameterValue`s.
652    pub fn iter(&self) -> impl Iterator<Item = &ParameterValue> {
653        match self {
654            QueryParameterValue::SingleParam(param) => std::slice::from_ref(param).iter(),
655            QueryParameterValue::MultiParams(params) => params.iter(),
656        }
657    }
658
659    /// Convert to `Vec<ParameterValue>` for owned iteration.
660    fn into_vec(self) -> Vec<ParameterValue> {
661        match self {
662            QueryParameterValue::SingleParam(param) => vec![param],
663            QueryParameterValue::MultiParams(params) => params,
664        }
665    }
666}
667
668// Implement IntoIterator for owned QueryParameterValue
669impl IntoIterator for QueryParameterValue {
670    type Item = ParameterValue;
671    type IntoIter = std::vec::IntoIter<ParameterValue>;
672
673    fn into_iter(self) -> Self::IntoIter {
674        self.into_vec().into_iter()
675    }
676}
677
/// Query parameters for HTTP requests, presumably keyed by parameter name; each key maps to one
/// or more values.
pub type QueryParameters = HashMap<String, QueryParameterValue>;
679
680#[cfg(test)]
681mod tests {
682    use std::convert::Infallible;
683
684    use hyper::{Server, server::conn::AddrStream, service::make_service_fn};
685    use proptest::prelude::*;
686    use tower::ServiceBuilder;
687
688    use super::*;
689    use crate::test_util::next_addr;
690
691    #[test]
692    fn test_default_request_headers_defaults() {
693        let user_agent = HeaderValue::from_static("vector");
694        let mut request = Request::post("http://example.com").body(()).unwrap();
695        default_request_headers(&mut request, &user_agent);
696        assert_eq!(
697            request.headers().get("Accept-Encoding"),
698            Some(&HeaderValue::from_static("identity")),
699        );
700        assert_eq!(request.headers().get("User-Agent"), Some(&user_agent));
701    }
702
703    #[test]
704    fn test_default_request_headers_does_not_overwrite() {
705        let mut request = Request::post("http://example.com")
706            .header("Accept-Encoding", "gzip")
707            .header("User-Agent", "foo")
708            .body(())
709            .unwrap();
710        default_request_headers(&mut request, &HeaderValue::from_static("vector"));
711        assert_eq!(
712            request.headers().get("Accept-Encoding"),
713            Some(&HeaderValue::from_static("gzip")),
714        );
715        assert_eq!(
716            request.headers().get("User-Agent"),
717            Some(&HeaderValue::from_static("foo"))
718        );
719    }
720
721    proptest! {
722        #[test]
723        fn test_jittered_duration(duration_in_secs in 0u64..120, jitter_factor in 0.0..1.0) {
724            let duration = Duration::from_secs(duration_in_secs);
725            let jittered_duration = MaxConnectionAgeLayer::jittered_duration(duration, jitter_factor);
726
727            // Check properties based on the range of inputs
728            if jitter_factor == 0.0 {
729                // When jitter_factor is 0, jittered_duration should be equal to the original duration
730                prop_assert_eq!(
731                    jittered_duration,
732                    duration,
733                    "jittered_duration {:?} should be equal to duration {:?}",
734                    jittered_duration,
735                    duration,
736                );
737            } else if duration_in_secs > 0 {
738                // Check the bounds when duration is non-zero and jitter_factor is non-zero
739                let lower_bound = duration.mul_f64(1.0 - jitter_factor);
740                let upper_bound = duration.mul_f64(1.0 + jitter_factor);
741                prop_assert!(
742                    jittered_duration >= lower_bound && jittered_duration <= upper_bound,
743                    "jittered_duration {:?} should be between {:?} and {:?}",
744                    jittered_duration,
745                    lower_bound,
746                    upper_bound,
747                );
748            } else {
749                // When duration is zero, jittered_duration should also be zero
750                prop_assert_eq!(
751                    jittered_duration,
752                    Duration::from_secs(0),
753                    "jittered_duration {:?} should be equal to zero",
754                    jittered_duration,
755                );
756            }
757        }
758    }
759
760    #[tokio::test]
761    async fn test_max_connection_age_service() {
762        tokio::time::pause();
763
764        let start_reference = Instant::now();
765        let max_connection_age = Duration::from_secs(1);
766        let mut service = MaxConnectionAgeService {
767            service: tower::service_fn(|_req: Request<Body>| async {
768                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
769            }),
770            start_reference,
771            max_connection_age,
772            peer_addr: "1.2.3.4:1234".parse().unwrap(),
773        };
774
775        let req = Request::get("http://example.com")
776            .body(Body::empty())
777            .unwrap();
778        let response = service.call(req).await.unwrap();
779        assert_eq!(response.headers().get("Connection"), None);
780
781        tokio::time::advance(Duration::from_millis(500)).await;
782        let req = Request::get("http://example.com")
783            .body(Body::empty())
784            .unwrap();
785        let response = service.call(req).await.unwrap();
786        assert_eq!(response.headers().get("Connection"), None);
787
788        tokio::time::advance(Duration::from_millis(500)).await;
789        let req = Request::get("http://example.com")
790            .body(Body::empty())
791            .unwrap();
792        let response = service.call(req).await.unwrap();
793        assert_eq!(
794            response.headers().get("Connection"),
795            Some(&HeaderValue::from_static("close"))
796        );
797    }
798
799    #[tokio::test]
800    async fn test_max_connection_age_service_http2() {
801        tokio::time::pause();
802
803        let start_reference = Instant::now();
804        let max_connection_age = Duration::from_secs(0);
805        let mut service = MaxConnectionAgeService {
806            service: tower::service_fn(|_req: Request<Body>| async {
807                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
808            }),
809            start_reference,
810            max_connection_age,
811            peer_addr: "1.2.3.4:1234".parse().unwrap(),
812        };
813
814        let mut req = Request::get("http://example.com")
815            .body(Body::empty())
816            .unwrap();
817        *req.version_mut() = Version::HTTP_2;
818        let response = service.call(req).await.unwrap();
819        assert_eq!(response.headers().get("Connection"), None);
820    }
821
822    #[tokio::test]
823    async fn test_max_connection_age_service_http3() {
824        tokio::time::pause();
825
826        let start_reference = Instant::now();
827        let max_connection_age = Duration::from_secs(0);
828        let mut service = MaxConnectionAgeService {
829            service: tower::service_fn(|_req: Request<Body>| async {
830                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
831            }),
832            start_reference,
833            max_connection_age,
834            peer_addr: "1.2.3.4:1234".parse().unwrap(),
835        };
836
837        let mut req = Request::get("http://example.com")
838            .body(Body::empty())
839            .unwrap();
840        *req.version_mut() = Version::HTTP_3;
841        let response = service.call(req).await.unwrap();
842        assert_eq!(response.headers().get("Connection"), None);
843    }
844
845    #[tokio::test]
846    async fn test_max_connection_age_service_zero_duration() {
847        tokio::time::pause();
848
849        let start_reference = Instant::now();
850        let max_connection_age = Duration::from_millis(0);
851        let mut service = MaxConnectionAgeService {
852            service: tower::service_fn(|_req: Request<Body>| async {
853                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
854            }),
855            start_reference,
856            max_connection_age,
857            peer_addr: "1.2.3.4:1234".parse().unwrap(),
858        };
859
860        let req = Request::get("http://example.com")
861            .body(Body::empty())
862            .unwrap();
863        let response = service.call(req).await.unwrap();
864        assert_eq!(
865            response.headers().get("Connection"),
866            Some(&HeaderValue::from_static("close"))
867        );
868    }
869
870    // Note that we unfortunately cannot mock the time in this test because the client calls
871    // sleep internally, which advances the clock.  However, this test shouldn't be flakey given
872    // the time bounds provided.
873    #[tokio::test]
874    async fn test_max_connection_age_service_with_hyper_server() {
875        // Create a hyper server with the max connection age layer.
876        let max_connection_age = Duration::from_secs(1);
877        let addr = next_addr();
878        let make_svc = make_service_fn(move |conn: &AddrStream| {
879            let svc = ServiceBuilder::new()
880                .layer(MaxConnectionAgeLayer::new(
881                    max_connection_age,
882                    0.,
883                    conn.remote_addr(),
884                ))
885                .service(tower::service_fn(|_req: Request<Body>| async {
886                    Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
887                }));
888            futures_util::future::ok::<_, Infallible>(svc)
889        });
890
891        tokio::spawn(async move {
892            Server::bind(&addr).serve(make_svc).await.unwrap();
893        });
894
895        // Wait for the server to start.
896        tokio::time::sleep(Duration::from_millis(10)).await;
897
898        // Create a client, which has its own connection pool.
899        let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();
900
901        // Responses generated before the client's max connection age has elapsed do not
902        // include a `Connection: close` header in the response.
903        let req = Request::get(format!("http://{addr}/"))
904            .body(Body::empty())
905            .unwrap();
906        let response = client.send(req).await.unwrap();
907        assert_eq!(response.headers().get("Connection"), None);
908
909        let req = Request::get(format!("http://{addr}/"))
910            .body(Body::empty())
911            .unwrap();
912        let response = client.send(req).await.unwrap();
913        assert_eq!(response.headers().get("Connection"), None);
914
915        // The first response generated after the client's max connection age has elapsed should
916        // include the `Connection: close` header.
917        tokio::time::sleep(Duration::from_secs(1)).await;
918        let req = Request::get(format!("http://{addr}/"))
919            .body(Body::empty())
920            .unwrap();
921        let response = client.send(req).await.unwrap();
922        assert_eq!(
923            response.headers().get("Connection"),
924            Some(&HeaderValue::from_static("close")),
925        );
926
927        // The next request should establish a new connection.
928        // Importantly, this also confirms that each connection has its own independent
929        // connection age timer.
930        let req = Request::get(format!("http://{addr}/"))
931            .body(Body::empty())
932            .unwrap();
933        let response = client.send(req).await.unwrap();
934        assert_eq!(response.headers().get("Connection"), None);
935    }
936}