1#![allow(missing_docs)]
2use std::{
3 collections::HashMap,
4 fmt,
5 net::SocketAddr,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10use futures::future::BoxFuture;
11use headers::{Authorization, HeaderMapExt};
12use http::{
13 HeaderMap, Request, Response, Uri, Version, header::HeaderValue, request::Builder,
14 uri::InvalidUri,
15};
16use hyper::{
17 body::{Body, HttpBody},
18 client,
19 client::{Client, HttpConnector},
20};
21use hyper_openssl::HttpsConnector;
22use hyper_proxy::ProxyConnector;
23use rand::Rng;
24use serde_with::serde_as;
25use snafu::{ResultExt, Snafu};
26use tokio::time::Instant;
27use tower::{Layer, Service};
28use tower_http::{
29 classify::{ServerErrorsAsFailures, SharedClassifier},
30 trace::TraceLayer,
31};
32use tracing::{Instrument, Span};
33use vector_lib::{configurable::configurable_component, sensitive_string::SensitiveString};
34
35#[cfg(feature = "aws-core")]
36use crate::aws::AwsAuthentication;
37use crate::{
38 config::ProxyConfig,
39 internal_events::{HttpServerRequestReceived, HttpServerResponseSent, http_client},
40 tls::{MaybeTlsSettings, TlsError, tls_connector_builder},
41};
42
/// Numeric HTTP status codes referenced by name.
pub mod status {
    /// `403 Forbidden`.
    pub const FORBIDDEN: u16 = 403;
    /// `404 Not Found`.
    pub const NOT_FOUND: u16 = 404;
    /// `429 Too Many Requests`.
    pub const TOO_MANY_REQUESTS: u16 = 429;
}
48
/// Errors produced while building an [`HttpClient`] or sending a request with it.
#[derive(Debug, Snafu)]
// Context selectors (e.g. `CallRequestSnafu`) are visible crate-wide.
#[snafu(visibility(pub(crate)))]
pub enum HttpError {
    /// The TLS connector builder could not be created from the TLS settings.
    #[snafu(display("Failed to build TLS connector: {}", source))]
    BuildTlsConnector { source: TlsError },
    /// OpenSSL failed while constructing the HTTPS connector.
    #[snafu(display("Failed to build HTTPS connector: {}", source))]
    MakeHttpsConnector { source: openssl::error::ErrorStack },
    /// The proxy configuration contained an invalid URI.
    #[snafu(display("Failed to build Proxy connector: {}", source))]
    MakeProxyConnector { source: InvalidUri },
    /// `hyper` failed to send the request or receive the response.
    #[snafu(display("Failed to make HTTP(S) request: {}", source))]
    CallRequest { source: hyper::Error },
    /// The HTTP request itself could not be built.
    #[snafu(display("Failed to build HTTP request: {}", source))]
    BuildRequest { source: http::Error },
}
63
64impl HttpError {
65 pub const fn is_retriable(&self) -> bool {
66 match self {
67 HttpError::BuildRequest { .. } | HttpError::MakeProxyConnector { .. } => false,
68 HttpError::CallRequest { .. }
69 | HttpError::BuildTlsConnector { .. }
70 | HttpError::MakeHttpsConnector { .. } => true,
71 }
72 }
73}
74
/// The future returned by [`HttpClient`]'s `Service::call` (for the default `Body` request type).
pub type HttpClientFuture = <HttpClient as Service<http::Request<Body>>>::Future;
/// Connector stack used by [`HttpClient`]: an HTTPS-capable connector wrapped in a proxy connector.
type HttpProxyConnector = ProxyConnector<HttpsConnector<HttpConnector>>;
77
/// An HTTP(S) client that routes requests through a proxy-aware connector and fills in
/// default request headers before sending.
pub struct HttpClient<B = Body> {
    // The underlying `hyper` client, built over a clone of `proxy_connector`.
    client: Client<HttpProxyConnector, B>,
    // Sent as `User-Agent` unless the request already sets one.
    user_agent: HeaderValue,
    // Kept alongside `client` so per-request proxy headers can be looked up.
    proxy_connector: HttpProxyConnector,
}
83
84impl<B> HttpClient<B>
85where
86 B: fmt::Debug + HttpBody + Send + 'static,
87 B::Data: Send,
88 B::Error: Into<crate::Error>,
89{
90 pub fn new(
91 tls_settings: impl Into<MaybeTlsSettings>,
92 proxy_config: &ProxyConfig,
93 ) -> Result<HttpClient<B>, HttpError> {
94 HttpClient::new_with_custom_client(tls_settings, proxy_config, &mut Client::builder())
95 }
96
97 pub fn new_with_custom_client(
98 tls_settings: impl Into<MaybeTlsSettings>,
99 proxy_config: &ProxyConfig,
100 client_builder: &mut client::Builder,
101 ) -> Result<HttpClient<B>, HttpError> {
102 let proxy_connector = build_proxy_connector(tls_settings.into(), proxy_config)?;
103 let client = client_builder.build(proxy_connector.clone());
104
105 let app_name = crate::get_app_name();
106 let version = crate::get_version();
107 let user_agent = HeaderValue::from_str(&format!("{app_name}/{version}"))
108 .expect("Invalid header value for user-agent!");
109
110 Ok(HttpClient {
111 client,
112 user_agent,
113 proxy_connector,
114 })
115 }
116
117 pub fn send(
118 &self,
119 mut request: Request<B>,
120 ) -> BoxFuture<'static, Result<http::Response<Body>, HttpError>> {
121 let span = tracing::info_span!("http");
122 let _enter = span.enter();
123
124 default_request_headers(&mut request, &self.user_agent);
125 self.maybe_add_proxy_headers(&mut request);
126
127 emit!(http_client::AboutToSendHttpRequest { request: &request });
128
129 let response = self.client.request(request);
130
131 let fut = async move {
132 let before = std::time::Instant::now();
135
136 let response_result = response.await;
138
139 let roundtrip = before.elapsed();
142
143 let response = response_result
145 .inspect_err(|error| {
146 emit!(http_client::GotHttpWarning { error, roundtrip });
148 })
149 .context(CallRequestSnafu)?;
150
151 emit!(http_client::GotHttpResponse {
153 response: &response,
154 roundtrip
155 });
156 Ok(response)
157 }
158 .instrument(span.clone().or_current());
159
160 Box::pin(fut)
161 }
162
163 fn maybe_add_proxy_headers(&self, request: &mut Request<B>) {
164 if let Some(proxy_headers) = self.proxy_connector.http_headers(request.uri()) {
165 for (k, v) in proxy_headers {
166 let request_headers = request.headers_mut();
167 if !request_headers.contains_key(k) {
168 request_headers.insert(k, v.into());
169 }
170 }
171 }
172 }
173}
174
175pub fn build_proxy_connector(
176 tls_settings: MaybeTlsSettings,
177 proxy_config: &ProxyConfig,
178) -> Result<ProxyConnector<HttpsConnector<HttpConnector>>, HttpError> {
179 let tls = tls_connector_builder(&tls_settings)
181 .context(BuildTlsConnectorSnafu)?
182 .build();
183 let https = build_tls_connector(tls_settings)?;
184 let mut proxy = ProxyConnector::new(https).unwrap();
185 proxy.set_tls(Some(tls));
188 proxy_config
189 .configure(&mut proxy)
190 .context(MakeProxyConnectorSnafu)?;
191 Ok(proxy)
192}
193
194pub fn build_tls_connector(
195 tls_settings: MaybeTlsSettings,
196) -> Result<HttpsConnector<HttpConnector>, HttpError> {
197 let mut http = HttpConnector::new();
198 http.enforce_http(false);
199
200 let tls = tls_connector_builder(&tls_settings).context(BuildTlsConnectorSnafu)?;
201 let mut https = HttpsConnector::with_connector(http, tls).context(MakeHttpsConnectorSnafu)?;
202
203 let settings = tls_settings.tls().cloned();
204 https.set_callback(move |c, _uri| {
205 if let Some(settings) = &settings {
206 settings.apply_connect_configuration(c)
207 } else {
208 Ok(())
209 }
210 });
211 Ok(https)
212}
213
214fn default_request_headers<B>(request: &mut Request<B>, user_agent: &HeaderValue) {
215 if !request.headers().contains_key("User-Agent") {
216 request
217 .headers_mut()
218 .insert("User-Agent", user_agent.clone());
219 }
220
221 if !request.headers().contains_key("Accept-Encoding") {
222 request
225 .headers_mut()
226 .insert("Accept-Encoding", HeaderValue::from_static("identity"));
227 }
228}
229
230impl<B> Service<Request<B>> for HttpClient<B>
231where
232 B: fmt::Debug + HttpBody + Send + 'static,
233 B::Data: Send,
234 B::Error: Into<crate::Error> + Send,
235{
236 type Response = http::Response<Body>;
237 type Error = HttpError;
238 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
239
240 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
241 Poll::Ready(Ok(()))
242 }
243
244 fn call(&mut self, request: Request<B>) -> Self::Future {
245 self.send(request)
246 }
247}
248
249impl<B> Clone for HttpClient<B> {
250 fn clone(&self) -> Self {
251 Self {
252 client: self.client.clone(),
253 user_agent: self.user_agent.clone(),
254 proxy_connector: self.proxy_connector.clone(),
255 }
256 }
257}
258
259impl<B> fmt::Debug for HttpClient<B> {
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 f.debug_struct("HttpClient")
262 .field("client", &self.client)
263 .field("user_agent", &self.user_agent)
264 .finish()
265 }
266}
267
/// Configuration of the authentication strategy for HTTP requests.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(deny_unknown_fields, rename_all = "snake_case", tag = "strategy")]
#[configurable(metadata(docs::enum_tag_description = "The authentication strategy to use."))]
pub enum Auth {
    /// Basic authentication.
    ///
    /// The username and password are sent in the `Authorization` header using the `Basic` scheme.
    Basic {
        /// The basic authentication username.
        #[configurable(metadata(docs::examples = "${USERNAME}"))]
        #[configurable(metadata(docs::examples = "username"))]
        user: String,

        /// The basic authentication password.
        #[configurable(metadata(docs::examples = "${PASSWORD}"))]
        #[configurable(metadata(docs::examples = "password"))]
        password: SensitiveString,
    },

    /// Bearer authentication.
    ///
    /// The token is sent in the `Authorization` header using the `Bearer` scheme.
    Bearer {
        /// The bearer authentication token.
        token: SensitiveString,
    },

    /// AWS authentication.
    // NOTE(review): `Auth::apply_headers_map` writes no header for this variant; it is
    // presumably consumed elsewhere (e.g. for request signing) — confirm.
    #[cfg(feature = "aws-core")]
    Aws {
        /// The AWS authentication configuration.
        auth: AwsAuthentication,

        /// The AWS service name.
        service: String,
    },

    /// Custom authentication.
    ///
    /// The configured value is used verbatim as the value of the `Authorization` header.
    Custom {
        /// The full `Authorization` header value, including any scheme prefix.
        #[configurable(metadata(docs::examples = "${AUTH_HEADER_VALUE}"))]
        #[configurable(metadata(docs::examples = "CUSTOM_PREFIX ${TOKEN}"))]
        value: String,
    },
}
320
/// Selection between two optional authentication configurations.
pub trait MaybeAuth: Sized {
    /// Picks one of `self` and `other`.
    ///
    /// The `Option<Auth>` implementation returns whichever one is `Some` (or `None` when
    /// neither is) and returns an error when both are set.
    fn choose_one(&self, other: &Self) -> crate::Result<Self>;
}
324
325impl MaybeAuth for Option<Auth> {
326 fn choose_one(&self, other: &Self) -> crate::Result<Self> {
327 if self.is_some() && other.is_some() {
328 Err("Two authorization credentials was provided.".into())
329 } else {
330 Ok(self.clone().or_else(|| other.clone()))
331 }
332 }
333}
334
335impl Auth {
336 pub fn apply<B>(&self, req: &mut Request<B>) {
337 self.apply_headers_map(req.headers_mut())
338 }
339
340 pub fn apply_builder(&self, mut builder: Builder) -> Builder {
341 if let Some(map) = builder.headers_mut() {
342 self.apply_headers_map(map)
343 }
344 builder
345 }
346
347 pub fn apply_headers_map(&self, map: &mut HeaderMap) {
348 match &self {
349 Auth::Basic { user, password } => {
350 let auth = Authorization::basic(user.as_str(), password.inner());
351 map.typed_insert(auth);
352 }
353 Auth::Bearer { token } => match Authorization::bearer(token.inner()) {
354 Ok(auth) => map.typed_insert(auth),
355 Err(error) => error!(message = "Invalid bearer token.", token = %token, %error),
356 },
357 Auth::Custom { value } => {
358 match HeaderValue::from_str(value) {
361 Ok(header_val) => {
362 map.insert(http::header::AUTHORIZATION, header_val);
363 }
364 Err(error) => {
365 error!(message = "Invalid custom auth header value.", value = %value, %error)
366 }
367 }
368 }
369 #[cfg(feature = "aws-core")]
370 _ => {}
371 }
372 }
373}
374
375pub fn get_http_scheme_from_uri(uri: &Uri) -> &'static str {
376 uri.scheme_str().map_or("http", |scheme| match scheme {
379 "http" => "http",
380 "https" => "https",
381 s => panic!("invalid URI scheme for HTTP client: {s}"),
386 })
387}
388
389pub fn build_http_trace_layer<T, U>(
393 span: Span,
394) -> TraceLayer<
395 SharedClassifier<ServerErrorsAsFailures>,
396 impl Fn(&Request<T>) -> Span + Clone,
397 impl Fn(&Request<T>, &Span) + Clone,
398 impl Fn(&Response<U>, Duration, &Span) + Clone,
399 (),
400 (),
401 (),
402> {
403 TraceLayer::new_for_http()
404 .make_span_with(move |request: &Request<T>| {
405 error_span!(
407 parent: &span,
408 "http-request",
409 method = %request.method(),
410 path = %request.uri().path(),
411 )
412 })
413 .on_request(Box::new(|_request: &Request<T>, _span: &Span| {
414 emit!(HttpServerRequestReceived);
415 }))
416 .on_response(|response: &Response<U>, latency: Duration, _span: &Span| {
417 emit!(HttpServerResponseSent { response, latency });
418 })
419 .on_failure(())
420 .on_body_chunk(())
421 .on_eos(())
422}
423
/// Configuration of keepalive parameters for HTTP connections.
#[serde_as]
#[configurable_component]
#[derive(Clone, Debug, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct KeepaliveConfig {
    /// The maximum amount of time a connection may exist before it is closed.
    // Defaults to `Some(300)` via `default_max_connection_age`.
    // NOTE(review): presumably passed to `MaxConnectionAgeLayer::new` by the HTTP server
    // setup (not visible here) — confirm.
    #[serde(default = "default_max_connection_age")]
    #[configurable(metadata(docs::examples = 600))]
    #[configurable(metadata(docs::type_unit = "seconds"))]
    #[configurable(metadata(docs::human_name = "Maximum Connection Age"))]
    pub max_connection_age_secs: Option<u64>,

    /// The factor by which to jitter the `max_connection_age_secs` value.
    ///
    /// A value of 0.1 means that the effective maximum age is a random value between 90% and
    /// 110% of the configured one.
    // Defaults to 0.1 via `default_max_connection_age_jitter_factor`.
    #[serde(default = "default_max_connection_age_jitter_factor")]
    #[configurable(validation(range(min = 0.0, max = 1.0)))]
    pub max_connection_age_jitter_factor: f64,
}
453
/// Default for `KeepaliveConfig::max_connection_age_secs`: five minutes, in seconds.
const fn default_max_connection_age() -> Option<u64> {
    const FIVE_MINUTES_IN_SECS: u64 = 5 * 60;
    Some(FIVE_MINUTES_IN_SECS)
}
457
/// Default for `KeepaliveConfig::max_connection_age_jitter_factor`: ±10% jitter.
const fn default_max_connection_age_jitter_factor() -> f64 {
    0.1_f64
}
461
462impl Default for KeepaliveConfig {
463 fn default() -> Self {
464 Self {
465 max_connection_age_secs: default_max_connection_age(),
466 max_connection_age_jitter_factor: default_max_connection_age_jitter_factor(),
467 }
468 }
469}
470
/// A [`Layer`] that wraps a service in a [`MaxConnectionAgeService`].
///
/// The connection's age is measured from the moment the layer is constructed, so one
/// layer is expected per connection (it also records that connection's peer address).
pub struct MaxConnectionAgeLayer {
    // When the layer was created; connection age is measured from here.
    start_reference: Instant,
    // Maximum connection age, with jitter already applied.
    max_connection_age: Duration,
    // Remote address of the connection; used only in the log line when closing.
    peer_addr: SocketAddr,
}
485
486impl MaxConnectionAgeLayer {
487 pub fn new(max_connection_age: Duration, jitter_factor: f64, peer_addr: SocketAddr) -> Self {
488 Self {
489 start_reference: Instant::now(),
490 max_connection_age: Self::jittered_duration(max_connection_age, jitter_factor),
491 peer_addr,
492 }
493 }
494
495 fn jittered_duration(duration: Duration, jitter_factor: f64) -> Duration {
496 let jitter_factor = jitter_factor.clamp(0.0, 1.0);
498 let mut rng = rand::rng();
500 let random_jitter_factor = rng.random_range(-jitter_factor..=jitter_factor) + 1.;
501 duration.mul_f64(random_jitter_factor)
502 }
503}
504
505impl<S> Layer<S> for MaxConnectionAgeLayer
506where
507 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
508 S::Future: Send + 'static,
509{
510 type Service = MaxConnectionAgeService<S>;
511
512 fn layer(&self, service: S) -> Self::Service {
513 MaxConnectionAgeService {
514 service,
515 start_reference: self.start_reference,
516 max_connection_age: self.max_connection_age,
517 peer_addr: self.peer_addr,
518 }
519 }
520}
521
/// A [`Service`] that asks HTTP/1.x clients to close their connection once it is too old.
///
/// After the inner service responds, a `Connection: close` header is added to the
/// response when both of these hold:
/// - the request was HTTP/0.9, 1.0 or 1.1;
/// - at least `max_connection_age` has elapsed since `start_reference`.
///
/// HTTP/2 and HTTP/3 responses are passed through unchanged.
#[derive(Clone)]
pub struct MaxConnectionAgeService<S> {
    // The wrapped service.
    service: S,
    // When the owning connection's layer was created.
    start_reference: Instant,
    // Maximum connection age, with jitter already applied.
    max_connection_age: Duration,
    // Remote address of the connection; used only for logging.
    peer_addr: SocketAddr,
}
538
539impl<S, E> Service<Request<Body>> for MaxConnectionAgeService<S>
540where
541 S: Service<Request<Body>, Response = Response<Body>, Error = E> + Clone + Send + 'static,
542 S::Future: Send + 'static,
543{
544 type Response = S::Response;
545 type Error = E;
546 type Future = BoxFuture<'static, Result<Self::Response, E>>;
547
548 fn poll_ready(
549 &mut self,
550 cx: &mut std::task::Context<'_>,
551 ) -> std::task::Poll<Result<(), Self::Error>> {
552 self.service.poll_ready(cx)
553 }
554
555 fn call(&mut self, req: Request<Body>) -> Self::Future {
556 let start_reference = self.start_reference;
557 let max_connection_age = self.max_connection_age;
558 let peer_addr = self.peer_addr;
559 let version = req.version();
560 let future = self.service.call(req);
561 Box::pin(async move {
562 let mut response = future.await?;
563 match version {
564 Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => {
565 if start_reference.elapsed() >= max_connection_age {
566 debug!(
567 message = "Closing connection due to max connection age.",
568 ?max_connection_age,
569 connection_age = ?start_reference.elapsed(),
570 ?peer_addr,
571 );
572 response.headers_mut().insert(
575 hyper::header::CONNECTION,
576 hyper::header::HeaderValue::from_static("close"),
577 );
578 }
579 }
580 Version::HTTP_2 => (),
582 Version::HTTP_3 => (),
584 _ => (),
585 }
586 Ok(response)
587 })
588 }
589}
590
/// The type of a parameter value.
#[configurable_component]
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ParamType {
    /// The parameter value is a plain string.
    #[default]
    String,
    /// The parameter value is a VRL expression.
    // NOTE(review): how VRL-typed values are evaluated is not visible in this file.
    Vrl,
}
602
603impl ParamType {
604 fn is_default(&self) -> bool {
605 *self == Self::default()
606 }
607}
608
/// A parameter value: either a bare string or a value with an explicit type.
///
/// Deserialized untagged: a plain string becomes `String`, and an object with a `value`
/// field (plus an optional `type`) becomes `Typed`.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum ParameterValue {
    /// A plain string value.
    String(String),
    /// A value with an explicit type.
    Typed {
        /// The raw parameter value.
        value: String,
        /// The type of the value.
        ///
        /// Defaults to `string`, and is omitted from serialized output when it is the default.
        #[serde(
            default,
            skip_serializing_if = "ParamType::is_default",
            rename = "type"
        )]
        r#type: ParamType,
    },
}
630
631impl ParameterValue {
632 pub const fn is_vrl(&self) -> bool {
634 match self {
635 ParameterValue::String(_) => false,
636 ParameterValue::Typed { r#type, .. } => matches!(r#type, ParamType::Vrl),
637 }
638 }
639
640 #[allow(clippy::missing_const_for_fn)]
642 pub fn value(&self) -> &str {
643 match self {
644 ParameterValue::String(s) => s,
645 ParameterValue::Typed { value, .. } => value,
646 }
647 }
648
649 pub fn into_value(self) -> String {
651 match self {
652 ParameterValue::String(s) => s,
653 ParameterValue::Typed { value, .. } => value,
654 }
655 }
656}
657
/// A query parameter value: a single value or a list of values.
///
/// Deserialized untagged: a single value becomes `SingleParam`, and a list becomes
/// `MultiParams`.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
#[configurable(metadata(docs::enum_tag_description = "Query parameter value"))]
pub enum QueryParameterValue {
    /// A single query parameter value.
    SingleParam(ParameterValue),
    /// A list of query parameter values.
    MultiParams(Vec<ParameterValue>),
}
669
670impl QueryParameterValue {
671 pub fn iter(&self) -> impl Iterator<Item = &ParameterValue> {
673 match self {
674 QueryParameterValue::SingleParam(param) => std::slice::from_ref(param).iter(),
675 QueryParameterValue::MultiParams(params) => params.iter(),
676 }
677 }
678
679 fn into_vec(self) -> Vec<ParameterValue> {
681 match self {
682 QueryParameterValue::SingleParam(param) => vec![param],
683 QueryParameterValue::MultiParams(params) => params,
684 }
685 }
686}
687
688impl IntoIterator for QueryParameterValue {
690 type Item = ParameterValue;
691 type IntoIter = std::vec::IntoIter<ParameterValue>;
692
693 fn into_iter(self) -> Self::IntoIter {
694 self.into_vec().into_iter()
695 }
696}
697
/// Query parameters, keyed by parameter name.
pub type QueryParameters = HashMap<String, QueryParameterValue>;
699
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use hyper::{Server, server::conn::AddrStream, service::make_service_fn};
    use proptest::prelude::*;
    use tower::ServiceBuilder;

    use super::*;
    use crate::test_util::addr::next_addr;

    // With no headers set, both defaults are filled in.
    #[test]
    fn test_default_request_headers_defaults() {
        let user_agent = HeaderValue::from_static("vector");
        let mut request = Request::post("http://example.com").body(()).unwrap();
        default_request_headers(&mut request, &user_agent);
        assert_eq!(
            request.headers().get("Accept-Encoding"),
            Some(&HeaderValue::from_static("identity")),
        );
        assert_eq!(request.headers().get("User-Agent"), Some(&user_agent));
    }

    // Caller-provided headers win over the defaults.
    #[test]
    fn test_default_request_headers_does_not_overwrite() {
        let mut request = Request::post("http://example.com")
            .header("Accept-Encoding", "gzip")
            .header("User-Agent", "foo")
            .body(())
            .unwrap();
        default_request_headers(&mut request, &HeaderValue::from_static("vector"));
        assert_eq!(
            request.headers().get("Accept-Encoding"),
            Some(&HeaderValue::from_static("gzip")),
        );
        assert_eq!(
            request.headers().get("User-Agent"),
            Some(&HeaderValue::from_static("foo"))
        );
    }

    proptest! {
        // The jittered value must stay within [d * (1 - j), d * (1 + j)]; the zero-jitter
        // and zero-duration cases must be returned exactly.
        #[test]
        fn test_jittered_duration(duration_in_secs in 0u64..120, jitter_factor in 0.0..1.0) {
            let duration = Duration::from_secs(duration_in_secs);
            let jittered_duration = MaxConnectionAgeLayer::jittered_duration(duration, jitter_factor);

            if jitter_factor == 0.0 {
                // No jitter: the duration is unchanged.
                prop_assert_eq!(
                    jittered_duration,
                    duration,
                    "jittered_duration {:?} should be equal to duration {:?}",
                    jittered_duration,
                    duration,
                );
            } else if duration_in_secs > 0 {
                // Bounds computed the same way the implementation scales the duration.
                let lower_bound = duration.mul_f64(1.0 - jitter_factor);
                let upper_bound = duration.mul_f64(1.0 + jitter_factor);
                prop_assert!(
                    jittered_duration >= lower_bound && jittered_duration <= upper_bound,
                    "jittered_duration {:?} should be between {:?} and {:?}",
                    jittered_duration,
                    lower_bound,
                    upper_bound,
                );
            } else {
                // A zero duration stays zero whatever the factor.
                prop_assert_eq!(
                    jittered_duration,
                    Duration::from_secs(0),
                    "jittered_duration {:?} should be equal to zero",
                    jittered_duration,
                );
            }
        }
    }

    // HTTP/1.1 (the default version): no `Connection: close` until the age reaches
    // `max_connection_age`, then it is added. Time is paused, and `Instant` here is
    // `tokio::time::Instant`, so `advance` moves the clock deterministically.
    #[tokio::test]
    async fn test_max_connection_age_service() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(1);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        // Age 0s.
        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        // Age 0.5s: still below the limit.
        tokio::time::advance(Duration::from_millis(500)).await;
        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        // Age 1s: the limit is reached (`>=`), so the connection is closed.
        tokio::time::advance(Duration::from_millis(500)).await;
        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close"))
        );
    }

    // HTTP/2 requests never get `Connection: close`, even with a zero maximum age.
    #[tokio::test]
    async fn test_max_connection_age_service_http2() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let mut req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        *req.version_mut() = Version::HTTP_2;
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }

    // HTTP/3 requests never get `Connection: close`, even with a zero maximum age.
    #[tokio::test]
    async fn test_max_connection_age_service_http3() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let mut req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        *req.version_mut() = Version::HTTP_3;
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }

    // A zero maximum age closes the connection on the very first HTTP/1.x response.
    #[tokio::test]
    async fn test_max_connection_age_service_zero_duration() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_millis(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close"))
        );
    }

    // End-to-end against a real hyper server and `HttpClient`, using real (unpaused) time.
    // A new layer is built per accepted connection, so each connection has its own age.
    #[tokio::test]
    async fn test_max_connection_age_service_with_hyper_server() {
        let max_connection_age = Duration::from_secs(1);
        let (_guard, addr) = next_addr();
        let make_svc = make_service_fn(move |conn: &AddrStream| {
            let svc = ServiceBuilder::new()
                .layer(MaxConnectionAgeLayer::new(
                    max_connection_age,
                    0.,
                    conn.remote_addr(),
                ))
                .service(tower::service_fn(|_req: Request<Body>| async {
                    Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
                }));
            futures_util::future::ok::<_, Infallible>(svc)
        });

        tokio::spawn(async move {
            Server::bind(&addr).serve(make_svc).await.unwrap();
        });

        // Presumably gives the spawned server time to start accepting connections.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();

        // First two requests share a young connection: no close.
        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        // Once the connection is older than the maximum age, the server asks to close it.
        tokio::time::sleep(Duration::from_secs(1)).await;
        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close")),
        );

        // The next request goes over a fresh connection whose age starts at zero.
        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }
}