1#![allow(missing_docs)]
2use futures::future::BoxFuture;
3use headers::{Authorization, HeaderMapExt};
4use http::{
5 header::HeaderValue, request::Builder, uri::InvalidUri, HeaderMap, Request, Response, Uri,
6 Version,
7};
8use hyper::{
9 body::{Body, HttpBody},
10 client,
11 client::{Client, HttpConnector},
12};
13use hyper_openssl::HttpsConnector;
14use hyper_proxy::ProxyConnector;
15use rand::Rng;
16use serde_with::serde_as;
17use snafu::{ResultExt, Snafu};
18use std::{
19 collections::HashMap,
20 fmt,
21 net::SocketAddr,
22 task::{Context, Poll},
23 time::Duration,
24};
25use tokio::time::Instant;
26use tower::{Layer, Service};
27use tower_http::{
28 classify::{ServerErrorsAsFailures, SharedClassifier},
29 trace::TraceLayer,
30};
31use tracing::{Instrument, Span};
32use vector_lib::configurable::configurable_component;
33use vector_lib::sensitive_string::SensitiveString;
34
35#[cfg(feature = "aws-core")]
36use crate::aws::AwsAuthentication;
37
38use crate::{
39 config::ProxyConfig,
40 internal_events::{http_client, HttpServerRequestReceived, HttpServerResponseSent},
41 tls::{tls_connector_builder, MaybeTlsSettings, TlsError},
42};
43
/// Numeric HTTP status codes exposed as named constants.
pub mod status {
    /// `403 Forbidden`.
    pub const FORBIDDEN: u16 = 403;
    /// `404 Not Found`.
    pub const NOT_FOUND: u16 = 404;
    /// `429 Too Many Requests`.
    pub const TOO_MANY_REQUESTS: u16 = 429;
}
49
/// Errors raised while building an HTTP client or while issuing a request with it.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum HttpError {
    /// `tls_connector_builder` rejected the supplied TLS settings.
    #[snafu(display("Failed to build TLS connector: {}", source))]
    BuildTlsConnector { source: TlsError },
    /// The HTTPS connector could not be created on top of the plain HTTP connector.
    #[snafu(display("Failed to build HTTPS connector: {}", source))]
    MakeHttpsConnector { source: openssl::error::ErrorStack },
    /// Applying the proxy configuration to the proxy connector failed with an invalid URI.
    #[snafu(display("Failed to build Proxy connector: {}", source))]
    MakeProxyConnector { source: InvalidUri },
    /// The underlying hyper client returned an error for the request.
    #[snafu(display("Failed to make HTTP(S) request: {}", source))]
    CallRequest { source: hyper::Error },
    /// An `http::Error` occurred while assembling a request.
    #[snafu(display("Failed to build HTTP request: {}", source))]
    BuildRequest { source: http::Error },
}
64
65impl HttpError {
66 pub const fn is_retriable(&self) -> bool {
67 match self {
68 HttpError::BuildRequest { .. } | HttpError::MakeProxyConnector { .. } => false,
69 HttpError::CallRequest { .. }
70 | HttpError::BuildTlsConnector { .. }
71 | HttpError::MakeHttpsConnector { .. } => true,
72 }
73 }
74}
75
/// The future type returned by [`HttpClient`] (with the default `Body`) when used as a `Service`.
pub type HttpClientFuture = <HttpClient as Service<http::Request<Body>>>::Future;
/// Connector stack used by [`HttpClient`]: a proxy connector wrapping an HTTPS connector
/// wrapping the plain HTTP connector.
type HttpProxyConnector = ProxyConnector<HttpsConnector<HttpConnector>>;
78
/// HTTP client wrapper that adds default headers, proxy headers and internal
/// event emission around a hyper `Client`.
pub struct HttpClient<B = Body> {
    /// The hyper client, built on top of a clone of `proxy_connector`.
    client: Client<HttpProxyConnector, B>,
    /// `User-Agent` value inserted into requests that do not already carry one.
    user_agent: HeaderValue,
    /// Retained so that proxy headers can be looked up per request URI.
    proxy_connector: HttpProxyConnector,
}
84
85impl<B> HttpClient<B>
86where
87 B: fmt::Debug + HttpBody + Send + 'static,
88 B::Data: Send,
89 B::Error: Into<crate::Error>,
90{
91 pub fn new(
92 tls_settings: impl Into<MaybeTlsSettings>,
93 proxy_config: &ProxyConfig,
94 ) -> Result<HttpClient<B>, HttpError> {
95 HttpClient::new_with_custom_client(tls_settings, proxy_config, &mut Client::builder())
96 }
97
98 pub fn new_with_custom_client(
99 tls_settings: impl Into<MaybeTlsSettings>,
100 proxy_config: &ProxyConfig,
101 client_builder: &mut client::Builder,
102 ) -> Result<HttpClient<B>, HttpError> {
103 let proxy_connector = build_proxy_connector(tls_settings.into(), proxy_config)?;
104 let client = client_builder.build(proxy_connector.clone());
105
106 let app_name = crate::get_app_name();
107 let version = crate::get_version();
108 let user_agent = HeaderValue::from_str(&format!("{app_name}/{version}"))
109 .expect("Invalid header value for user-agent!");
110
111 Ok(HttpClient {
112 client,
113 user_agent,
114 proxy_connector,
115 })
116 }
117
118 pub fn send(
119 &self,
120 mut request: Request<B>,
121 ) -> BoxFuture<'static, Result<http::Response<Body>, HttpError>> {
122 let span = tracing::info_span!("http");
123 let _enter = span.enter();
124
125 default_request_headers(&mut request, &self.user_agent);
126 self.maybe_add_proxy_headers(&mut request);
127
128 emit!(http_client::AboutToSendHttpRequest { request: &request });
129
130 let response = self.client.request(request);
131
132 let fut = async move {
133 let before = std::time::Instant::now();
136
137 let response_result = response.await;
139
140 let roundtrip = before.elapsed();
143
144 let response = response_result
146 .inspect_err(|error| {
147 emit!(http_client::GotHttpWarning { error, roundtrip });
149 })
150 .context(CallRequestSnafu)?;
151
152 emit!(http_client::GotHttpResponse {
154 response: &response,
155 roundtrip
156 });
157 Ok(response)
158 }
159 .instrument(span.clone().or_current());
160
161 Box::pin(fut)
162 }
163
164 fn maybe_add_proxy_headers(&self, request: &mut Request<B>) {
165 if let Some(proxy_headers) = self.proxy_connector.http_headers(request.uri()) {
166 for (k, v) in proxy_headers {
167 let request_headers = request.headers_mut();
168 if !request_headers.contains_key(k) {
169 request_headers.insert(k, v.into());
170 }
171 }
172 }
173 }
174}
175
176pub fn build_proxy_connector(
177 tls_settings: MaybeTlsSettings,
178 proxy_config: &ProxyConfig,
179) -> Result<ProxyConnector<HttpsConnector<HttpConnector>>, HttpError> {
180 let tls = tls_connector_builder(&tls_settings)
182 .context(BuildTlsConnectorSnafu)?
183 .build();
184 let https = build_tls_connector(tls_settings)?;
185 let mut proxy = ProxyConnector::new(https).unwrap();
186 proxy.set_tls(Some(tls));
189 proxy_config
190 .configure(&mut proxy)
191 .context(MakeProxyConnectorSnafu)?;
192 Ok(proxy)
193}
194
195pub fn build_tls_connector(
196 tls_settings: MaybeTlsSettings,
197) -> Result<HttpsConnector<HttpConnector>, HttpError> {
198 let mut http = HttpConnector::new();
199 http.enforce_http(false);
200
201 let tls = tls_connector_builder(&tls_settings).context(BuildTlsConnectorSnafu)?;
202 let mut https = HttpsConnector::with_connector(http, tls).context(MakeHttpsConnectorSnafu)?;
203
204 let settings = tls_settings.tls().cloned();
205 https.set_callback(move |c, _uri| {
206 if let Some(settings) = &settings {
207 settings.apply_connect_configuration(c)
208 } else {
209 Ok(())
210 }
211 });
212 Ok(https)
213}
214
215fn default_request_headers<B>(request: &mut Request<B>, user_agent: &HeaderValue) {
216 if !request.headers().contains_key("User-Agent") {
217 request
218 .headers_mut()
219 .insert("User-Agent", user_agent.clone());
220 }
221
222 if !request.headers().contains_key("Accept-Encoding") {
223 request
226 .headers_mut()
227 .insert("Accept-Encoding", HeaderValue::from_static("identity"));
228 }
229}
230
231impl<B> Service<Request<B>> for HttpClient<B>
232where
233 B: fmt::Debug + HttpBody + Send + 'static,
234 B::Data: Send,
235 B::Error: Into<crate::Error> + Send,
236{
237 type Response = http::Response<Body>;
238 type Error = HttpError;
239 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
240
241 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
242 Poll::Ready(Ok(()))
243 }
244
245 fn call(&mut self, request: Request<B>) -> Self::Future {
246 self.send(request)
247 }
248}
249
250impl<B> Clone for HttpClient<B> {
251 fn clone(&self) -> Self {
252 Self {
253 client: self.client.clone(),
254 user_agent: self.user_agent.clone(),
255 proxy_connector: self.proxy_connector.clone(),
256 }
257 }
258}
259
260impl<B> fmt::Debug for HttpClient<B> {
261 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
262 f.debug_struct("HttpClient")
263 .field("client", &self.client)
264 .field("user_agent", &self.user_agent)
265 .finish()
266 }
267}
268
/// Authentication strategy applied to outgoing HTTP requests.
///
/// The strategy is selected by the `strategy` tag.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(deny_unknown_fields, rename_all = "snake_case", tag = "strategy")]
#[configurable(metadata(docs::enum_tag_description = "The authentication strategy to use."))]
pub enum Auth {
    /// Basic authentication.
    ///
    /// The user and password are sent in the `Authorization` header.
    Basic {
        /// The basic authentication username.
        #[configurable(metadata(docs::examples = "${USERNAME}"))]
        #[configurable(metadata(docs::examples = "username"))]
        user: String,

        /// The basic authentication password.
        #[configurable(metadata(docs::examples = "${PASSWORD}"))]
        #[configurable(metadata(docs::examples = "password"))]
        password: SensitiveString,
    },

    /// Bearer authentication.
    ///
    /// The token is sent in the `Authorization` header.
    Bearer {
        /// The bearer authentication token.
        token: SensitiveString,
    },

    /// AWS authentication.
    // NOTE(review): `apply_headers_map` adds no header for this variant; presumably
    // request signing happens elsewhere — confirm against callers.
    #[cfg(feature = "aws-core")]
    Aws {
        /// The AWS authentication configuration.
        auth: AwsAuthentication,

        /// The AWS service name.
        service: String,
    },
}
313
/// Selection between two optional credential sources.
pub trait MaybeAuth: Sized {
    /// Picks one of `self` and `other`, or reports an error when the choice is ambiguous.
    fn choose_one(&self, other: &Self) -> crate::Result<Self>;
}
317
318impl MaybeAuth for Option<Auth> {
319 fn choose_one(&self, other: &Self) -> crate::Result<Self> {
320 if self.is_some() && other.is_some() {
321 Err("Two authorization credentials was provided.".into())
322 } else {
323 Ok(self.clone().or_else(|| other.clone()))
324 }
325 }
326}
327
328impl Auth {
329 pub fn apply<B>(&self, req: &mut Request<B>) {
330 self.apply_headers_map(req.headers_mut())
331 }
332
333 pub fn apply_builder(&self, mut builder: Builder) -> Builder {
334 if let Some(map) = builder.headers_mut() {
335 self.apply_headers_map(map)
336 }
337 builder
338 }
339
340 pub fn apply_headers_map(&self, map: &mut HeaderMap) {
341 match &self {
342 Auth::Basic { user, password } => {
343 let auth = Authorization::basic(user.as_str(), password.inner());
344 map.typed_insert(auth);
345 }
346 Auth::Bearer { token } => match Authorization::bearer(token.inner()) {
347 Ok(auth) => map.typed_insert(auth),
348 Err(error) => error!(message = "Invalid bearer token.", token = %token, %error),
349 },
350 #[cfg(feature = "aws-core")]
351 _ => {}
352 }
353 }
354}
355
356pub fn get_http_scheme_from_uri(uri: &Uri) -> &'static str {
357 uri.scheme_str().map_or("http", |scheme| match scheme {
360 "http" => "http",
361 "https" => "https",
362 s => panic!("invalid URI scheme for HTTP client: {s}"),
367 })
368}
369
/// Builds a `TraceLayer` for HTTP servers that emits internal events on request
/// and response.
///
/// Each request gets an `http-request` span, parented to `span`, that records
/// the request method and path. The failure, body-chunk and end-of-stream
/// hooks are set to `()`.
pub fn build_http_trace_layer<T, U>(
    span: Span,
) -> TraceLayer<
    SharedClassifier<ServerErrorsAsFailures>,
    impl Fn(&Request<T>) -> Span + Clone,
    impl Fn(&Request<T>, &Span) + Clone,
    impl Fn(&Response<U>, Duration, &Span) + Clone,
    (),
    (),
    (),
> {
    TraceLayer::new_for_http()
        .make_span_with(move |request: &Request<T>| {
            // Created at error level and attached to the caller-supplied span.
            error_span!(
                parent: &span,
                "http-request",
                method = %request.method(),
                path = %request.uri().path(),
            )
        })
        .on_request(Box::new(|_request: &Request<T>, _span: &Span| {
            emit!(HttpServerRequestReceived);
        }))
        .on_response(|response: &Response<U>, latency: Duration, _span: &Span| {
            emit!(HttpServerResponseSent { response, latency });
        })
        .on_failure(())
        .on_body_chunk(())
        .on_eos(())
}
404
/// Keepalive settings: a maximum connection age and a jitter factor for it.
// NOTE(review): nothing in this file turns these fields into a `MaxConnectionAgeLayer`;
// presumably callers do that — confirm.
#[serde_as]
#[configurable_component]
#[derive(Clone, Debug, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct KeepaliveConfig {
    /// The maximum connection age, in seconds.
    ///
    /// Defaults to `300` when the field is omitted.
    #[serde(default = "default_max_connection_age")]
    #[configurable(metadata(docs::examples = 600))]
    #[configurable(metadata(docs::type_unit = "seconds"))]
    #[configurable(metadata(docs::human_name = "Maximum Connection Age"))]
    pub max_connection_age_secs: Option<u64>,

    /// The jitter factor for the maximum connection age, between `0.0` and `1.0`.
    ///
    /// Defaults to `0.1` when the field is omitted.
    #[serde(default = "default_max_connection_age_jitter_factor")]
    #[configurable(validation(range(min = 0.0, max = 1.0)))]
    pub max_connection_age_jitter_factor: f64,
}
434
/// Default for `KeepaliveConfig::max_connection_age_secs`: 300 seconds.
const fn default_max_connection_age() -> Option<u64> {
    const FIVE_MINUTES_SECS: u64 = 300;
    Some(FIVE_MINUTES_SECS)
}
438
/// Default for `KeepaliveConfig::max_connection_age_jitter_factor`: `0.1`.
const fn default_max_connection_age_jitter_factor() -> f64 {
    const DEFAULT_JITTER_FACTOR: f64 = 0.1;
    DEFAULT_JITTER_FACTOR
}
442
443impl Default for KeepaliveConfig {
444 fn default() -> Self {
445 Self {
446 max_connection_age_secs: default_max_connection_age(),
447 max_connection_age_jitter_factor: default_max_connection_age_jitter_factor(),
448 }
449 }
450}
451
/// A `tower` layer that wraps a service in [`MaxConnectionAgeService`].
pub struct MaxConnectionAgeLayer {
    /// Instant captured when the layer was created; connection age is measured from it.
    start_reference: Instant,
    /// The age threshold, with jitter already applied.
    max_connection_age: Duration,
    /// Peer address, used in the debug log when the connection is closed.
    peer_addr: SocketAddr,
}
466
467impl MaxConnectionAgeLayer {
468 pub fn new(max_connection_age: Duration, jitter_factor: f64, peer_addr: SocketAddr) -> Self {
469 Self {
470 start_reference: Instant::now(),
471 max_connection_age: Self::jittered_duration(max_connection_age, jitter_factor),
472 peer_addr,
473 }
474 }
475
476 fn jittered_duration(duration: Duration, jitter_factor: f64) -> Duration {
477 let jitter_factor = jitter_factor.clamp(0.0, 1.0);
479 let mut rng = rand::rng();
481 let random_jitter_factor = rng.random_range(-jitter_factor..=jitter_factor) + 1.;
482 duration.mul_f64(random_jitter_factor)
483 }
484}
485
486impl<S> Layer<S> for MaxConnectionAgeLayer
487where
488 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
489 S::Future: Send + 'static,
490{
491 type Service = MaxConnectionAgeService<S>;
492
493 fn layer(&self, service: S) -> Self::Service {
494 MaxConnectionAgeService {
495 service,
496 start_reference: self.start_reference,
497 max_connection_age: self.max_connection_age,
498 peer_addr: self.peer_addr,
499 }
500 }
501}
502
/// A service that adds `Connection: close` to HTTP/0.9, HTTP/1.0 and HTTP/1.1
/// responses once the connection age reaches the configured maximum.
#[derive(Clone)]
pub struct MaxConnectionAgeService<S> {
    /// The wrapped service that produces the response.
    service: S,
    /// Instant from which the connection age is measured.
    start_reference: Instant,
    /// Age at or beyond which responses are marked with `Connection: close`.
    max_connection_age: Duration,
    /// Peer address, included in the debug log.
    peer_addr: SocketAddr,
}
519
520impl<S, E> Service<Request<Body>> for MaxConnectionAgeService<S>
521where
522 S: Service<Request<Body>, Response = Response<Body>, Error = E> + Clone + Send + 'static,
523 S::Future: Send + 'static,
524{
525 type Response = S::Response;
526 type Error = E;
527 type Future = BoxFuture<'static, Result<Self::Response, E>>;
528
529 fn poll_ready(
530 &mut self,
531 cx: &mut std::task::Context<'_>,
532 ) -> std::task::Poll<Result<(), Self::Error>> {
533 self.service.poll_ready(cx)
534 }
535
536 fn call(&mut self, req: Request<Body>) -> Self::Future {
537 let start_reference = self.start_reference;
538 let max_connection_age = self.max_connection_age;
539 let peer_addr = self.peer_addr;
540 let version = req.version();
541 let future = self.service.call(req);
542 Box::pin(async move {
543 let mut response = future.await?;
544 match version {
545 Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => {
546 if start_reference.elapsed() >= max_connection_age {
547 debug!(
548 message = "Closing connection due to max connection age.",
549 ?max_connection_age,
550 connection_age = ?start_reference.elapsed(),
551 ?peer_addr,
552 );
553 response.headers_mut().insert(
556 hyper::header::CONNECTION,
557 hyper::header::HeaderValue::from_static("close"),
558 );
559 }
560 }
561 Version::HTTP_2 => (),
563 Version::HTTP_3 => (),
565 _ => (),
566 }
567 Ok(response)
568 })
569 }
570}
571
/// The type of a parameter value: a plain string (the default) or VRL.
#[configurable_component]
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ParamType {
    /// The parameter value is a plain string.
    #[default]
    String,
    /// The parameter value is marked as VRL.
    // NOTE(review): how a VRL value is evaluated is decided outside this file — confirm.
    Vrl,
}
583
584impl ParamType {
585 fn is_default(&self) -> bool {
586 *self == Self::default()
587 }
588}
589
/// A single parameter value, written either as a bare string or as an object
/// with a `value` and an optional `type`.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum ParameterValue {
    /// A bare string value.
    String(String),
    /// A value with an explicit type.
    Typed {
        /// The raw value of the parameter.
        value: String,
        /// The parameter type; not serialized when it equals the default.
        #[serde(
            default,
            skip_serializing_if = "ParamType::is_default",
            rename = "type"
        )]
        r#type: ParamType,
    },
}
611
612impl ParameterValue {
613 pub const fn is_vrl(&self) -> bool {
615 match self {
616 ParameterValue::String(_) => false,
617 ParameterValue::Typed { r#type, .. } => matches!(r#type, ParamType::Vrl),
618 }
619 }
620
621 #[allow(clippy::missing_const_for_fn)]
623 pub fn value(&self) -> &str {
624 match self {
625 ParameterValue::String(s) => s,
626 ParameterValue::Typed { value, .. } => value,
627 }
628 }
629
630 pub fn into_value(self) -> String {
632 match self {
633 ParameterValue::String(s) => s,
634 ParameterValue::Typed { value, .. } => value,
635 }
636 }
637}
638
/// A query parameter value: either a single value or a list of values.
#[configurable_component]
#[derive(Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
#[configurable(metadata(docs::enum_tag_description = "Query parameter value"))]
pub enum QueryParameterValue {
    /// A single value.
    SingleParam(ParameterValue),
    /// Multiple values.
    MultiParams(Vec<ParameterValue>),
}
650
651impl QueryParameterValue {
652 pub fn iter(&self) -> impl Iterator<Item = &ParameterValue> {
654 match self {
655 QueryParameterValue::SingleParam(param) => std::slice::from_ref(param).iter(),
656 QueryParameterValue::MultiParams(params) => params.iter(),
657 }
658 }
659
660 fn into_vec(self) -> Vec<ParameterValue> {
662 match self {
663 QueryParameterValue::SingleParam(param) => vec![param],
664 QueryParameterValue::MultiParams(params) => params,
665 }
666 }
667}
668
669impl IntoIterator for QueryParameterValue {
671 type Item = ParameterValue;
672 type IntoIter = std::vec::IntoIter<ParameterValue>;
673
674 fn into_iter(self) -> Self::IntoIter {
675 self.into_vec().into_iter()
676 }
677}
678
/// Map from a `String` key (presumably the query parameter name) to its value or values.
pub type QueryParameters = HashMap<String, QueryParameterValue>;
680
// Unit tests for default request headers, jitter computation and the
// max-connection-age service.
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use hyper::{server::conn::AddrStream, service::make_service_fn, Server};
    use proptest::prelude::*;
    use tower::ServiceBuilder;

    use crate::test_util::next_addr;

    use super::*;

    // A request without the default headers gets both of them.
    #[test]
    fn test_default_request_headers_defaults() {
        let user_agent = HeaderValue::from_static("vector");
        let mut request = Request::post("http://example.com").body(()).unwrap();
        default_request_headers(&mut request, &user_agent);
        assert_eq!(
            request.headers().get("Accept-Encoding"),
            Some(&HeaderValue::from_static("identity")),
        );
        assert_eq!(request.headers().get("User-Agent"), Some(&user_agent));
    }

    // Headers already present on the request are left untouched.
    #[test]
    fn test_default_request_headers_does_not_overwrite() {
        let mut request = Request::post("http://example.com")
            .header("Accept-Encoding", "gzip")
            .header("User-Agent", "foo")
            .body(())
            .unwrap();
        default_request_headers(&mut request, &HeaderValue::from_static("vector"));
        assert_eq!(
            request.headers().get("Accept-Encoding"),
            Some(&HeaderValue::from_static("gzip")),
        );
        assert_eq!(
            request.headers().get("User-Agent"),
            Some(&HeaderValue::from_static("foo"))
        );
    }

    proptest! {
        // The jittered duration must stay within `duration * (1 ± jitter_factor)`.
        #[test]
        fn test_jittered_duration(duration_in_secs in 0u64..120, jitter_factor in 0.0..1.0) {
            let duration = Duration::from_secs(duration_in_secs);
            let jittered_duration = MaxConnectionAgeLayer::jittered_duration(duration, jitter_factor);

            if jitter_factor == 0.0 {
                // With no jitter the duration is returned unchanged.
                prop_assert_eq!(
                    jittered_duration,
                    duration,
                    "jittered_duration {:?} should be equal to duration {:?}",
                    jittered_duration,
                    duration,
                );
            } else if duration_in_secs > 0 {
                // With jitter, the result lies between the scaled bounds.
                let lower_bound = duration.mul_f64(1.0 - jitter_factor);
                let upper_bound = duration.mul_f64(1.0 + jitter_factor);
                prop_assert!(
                    jittered_duration >= lower_bound && jittered_duration <= upper_bound,
                    "jittered_duration {:?} should be between {:?} and {:?}",
                    jittered_duration,
                    lower_bound,
                    upper_bound,
                );
            } else {
                // A zero duration stays zero whatever the jitter.
                prop_assert_eq!(
                    jittered_duration,
                    Duration::from_secs(0),
                    "jittered_duration {:?} should be equal to zero",
                    jittered_duration,
                );
            }
        }
    }

    // Uses paused tokio time: the header appears only once the full second has elapsed.
    #[tokio::test]
    async fn test_max_connection_age_service() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(1);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        // Age 0: below the maximum, no header.
        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        // Age 500ms: still below the maximum, no header.
        tokio::time::advance(Duration::from_millis(500)).await;
        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        // Age 1s: the maximum is reached, so the response carries `Connection: close`.
        tokio::time::advance(Duration::from_millis(500)).await;
        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close"))
        );
    }

    // HTTP/2 requests never get the header, even with a zero maximum age.
    #[tokio::test]
    async fn test_max_connection_age_service_http2() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let mut req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        *req.version_mut() = Version::HTTP_2;
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }

    // HTTP/3 requests never get the header, even with a zero maximum age.
    #[tokio::test]
    async fn test_max_connection_age_service_http3() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let mut req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        *req.version_mut() = Version::HTTP_3;
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }

    // With a zero maximum age, the very first response carries `Connection: close`.
    #[tokio::test]
    async fn test_max_connection_age_service_zero_duration() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_millis(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close"))
        );
    }

    // End-to-end test against a real hyper server; time is not paused here,
    // so the sleeps below are real.
    #[tokio::test]
    async fn test_max_connection_age_service_with_hyper_server() {
        let max_connection_age = Duration::from_secs(1);
        let addr = next_addr();
        // A new layer (and so a new age timer) is created for every accepted connection.
        let make_svc = make_service_fn(move |conn: &AddrStream| {
            let svc = ServiceBuilder::new()
                .layer(MaxConnectionAgeLayer::new(
                    max_connection_age,
                    0.,
                    conn.remote_addr(),
                ))
                .service(tower::service_fn(|_req: Request<Body>| async {
                    Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
                }));
            futures_util::future::ok::<_, Infallible>(svc)
        });

        tokio::spawn(async move {
            Server::bind(&addr).serve(make_svc).await.unwrap();
        });

        // Give the server a moment to start listening.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();

        // Responses produced before the maximum age has elapsed carry no
        // `Connection: close` header.
        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        // The first response after the maximum age has elapsed carries the header.
        tokio::time::sleep(Duration::from_secs(1)).await;
        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close")),
        );

        // The next request is expected to use a new connection, which gets its own
        // fresh age timer, so the header is absent again.
        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }
}