1#![allow(missing_docs)]
2use std::{
3 collections::HashMap,
4 fmt,
5 net::SocketAddr,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10use futures::future::BoxFuture;
11use headers::{Authorization, HeaderMapExt};
12use http::{
13 HeaderMap, Request, Response, Uri, Version, header::HeaderValue, request::Builder,
14 uri::InvalidUri,
15};
16use hyper::{
17 body::{Body, HttpBody},
18 client,
19 client::{Client, HttpConnector},
20};
21use hyper_openssl::HttpsConnector;
22use hyper_proxy::ProxyConnector;
23use rand::Rng;
24use serde_with::serde_as;
25use snafu::{ResultExt, Snafu};
26use tokio::time::Instant;
27use tower::{Layer, Service};
28use tower_http::{
29 classify::{ServerErrorsAsFailures, SharedClassifier},
30 trace::TraceLayer,
31};
32use tracing::{Instrument, Span};
33use vector_lib::{configurable::configurable_component, sensitive_string::SensitiveString};
34
35#[cfg(feature = "aws-core")]
36use crate::aws::AwsAuthentication;
37use crate::{
38 config::ProxyConfig,
39 internal_events::{HttpServerRequestReceived, HttpServerResponseSent, http_client},
40 tls::{MaybeTlsSettings, TlsError, tls_connector_builder},
41};
42
/// Raw numeric HTTP status codes.
pub mod status {
    /// `403 Forbidden`.
    pub const FORBIDDEN: u16 = 403;

    /// `404 Not Found`.
    pub const NOT_FOUND: u16 = 404;

    /// `429 Too Many Requests`.
    pub const TOO_MANY_REQUESTS: u16 = 429;
}
48
49#[derive(Debug, Snafu)]
50#[snafu(visibility(pub(crate)))]
51pub enum HttpError {
52 #[snafu(display("Failed to build TLS connector: {}", source))]
53 BuildTlsConnector { source: TlsError },
54 #[snafu(display("Failed to build HTTPS connector: {}", source))]
55 MakeHttpsConnector { source: openssl::error::ErrorStack },
56 #[snafu(display("Failed to build Proxy connector: {}", source))]
57 MakeProxyConnector { source: InvalidUri },
58 #[snafu(display("Failed to make HTTP(S) request: {}", source))]
59 CallRequest { source: hyper::Error },
60 #[snafu(display("Failed to build HTTP request: {}", source))]
61 BuildRequest { source: http::Error },
62}
63
64impl HttpError {
65 pub const fn is_retriable(&self) -> bool {
66 match self {
67 HttpError::BuildRequest { .. } | HttpError::MakeProxyConnector { .. } => false,
68 HttpError::CallRequest { .. }
69 | HttpError::BuildTlsConnector { .. }
70 | HttpError::MakeHttpsConnector { .. } => true,
71 }
72 }
73}
74
75pub type HttpClientFuture = <HttpClient as Service<http::Request<Body>>>::Future;
76type HttpProxyConnector = ProxyConnector<HttpsConnector<HttpConnector>>;
77
78pub struct HttpClient<B = Body> {
79 client: Client<HttpProxyConnector, B>,
80 user_agent: HeaderValue,
81 proxy_connector: HttpProxyConnector,
82}
83
84impl<B> HttpClient<B>
85where
86 B: fmt::Debug + HttpBody + Send + 'static,
87 B::Data: Send,
88 B::Error: Into<crate::Error>,
89{
90 pub fn new(
91 tls_settings: impl Into<MaybeTlsSettings>,
92 proxy_config: &ProxyConfig,
93 ) -> Result<HttpClient<B>, HttpError> {
94 HttpClient::new_with_custom_client(tls_settings, proxy_config, &mut Client::builder())
95 }
96
97 pub fn new_with_custom_client(
98 tls_settings: impl Into<MaybeTlsSettings>,
99 proxy_config: &ProxyConfig,
100 client_builder: &mut client::Builder,
101 ) -> Result<HttpClient<B>, HttpError> {
102 let proxy_connector = build_proxy_connector(tls_settings.into(), proxy_config)?;
103 let client = client_builder.build(proxy_connector.clone());
104
105 let app_name = crate::get_app_name();
106 let version = crate::get_version();
107 let user_agent = HeaderValue::from_str(&format!("{app_name}/{version}"))
108 .expect("Invalid header value for user-agent!");
109
110 Ok(HttpClient {
111 client,
112 user_agent,
113 proxy_connector,
114 })
115 }
116
117 pub fn send(
118 &self,
119 mut request: Request<B>,
120 ) -> BoxFuture<'static, Result<http::Response<Body>, HttpError>> {
121 let span = tracing::info_span!("http");
122 let _enter = span.enter();
123
124 default_request_headers(&mut request, &self.user_agent);
125 self.maybe_add_proxy_headers(&mut request);
126
127 emit!(http_client::AboutToSendHttpRequest { request: &request });
128
129 let response = self.client.request(request);
130
131 let fut = async move {
132 let before = std::time::Instant::now();
135
136 let response_result = response.await;
138
139 let roundtrip = before.elapsed();
142
143 let response = response_result
145 .inspect_err(|error| {
146 emit!(http_client::GotHttpWarning { error, roundtrip });
148 })
149 .context(CallRequestSnafu)?;
150
151 emit!(http_client::GotHttpResponse {
153 response: &response,
154 roundtrip
155 });
156 Ok(response)
157 }
158 .instrument(span.clone().or_current());
159
160 Box::pin(fut)
161 }
162
163 fn maybe_add_proxy_headers(&self, request: &mut Request<B>) {
164 if let Some(proxy_headers) = self.proxy_connector.http_headers(request.uri()) {
165 for (k, v) in proxy_headers {
166 let request_headers = request.headers_mut();
167 if !request_headers.contains_key(k) {
168 request_headers.insert(k, v.into());
169 }
170 }
171 }
172 }
173}
174
175pub fn build_proxy_connector(
176 tls_settings: MaybeTlsSettings,
177 proxy_config: &ProxyConfig,
178) -> Result<ProxyConnector<HttpsConnector<HttpConnector>>, HttpError> {
179 let tls = tls_connector_builder(&tls_settings)
181 .context(BuildTlsConnectorSnafu)?
182 .build();
183 let https = build_tls_connector(tls_settings)?;
184 let mut proxy = ProxyConnector::new(https).unwrap();
185 proxy.set_tls(Some(tls));
188 proxy_config
189 .configure(&mut proxy)
190 .context(MakeProxyConnectorSnafu)?;
191 Ok(proxy)
192}
193
194pub fn build_tls_connector(
195 tls_settings: MaybeTlsSettings,
196) -> Result<HttpsConnector<HttpConnector>, HttpError> {
197 let mut http = HttpConnector::new();
198 http.enforce_http(false);
199
200 let tls = tls_connector_builder(&tls_settings).context(BuildTlsConnectorSnafu)?;
201 let mut https = HttpsConnector::with_connector(http, tls).context(MakeHttpsConnectorSnafu)?;
202
203 let settings = tls_settings.tls().cloned();
204 https.set_callback(move |c, _uri| {
205 if let Some(settings) = &settings {
206 settings.apply_connect_configuration(c)
207 } else {
208 Ok(())
209 }
210 });
211 Ok(https)
212}
213
214fn default_request_headers<B>(request: &mut Request<B>, user_agent: &HeaderValue) {
215 if !request.headers().contains_key("User-Agent") {
216 request
217 .headers_mut()
218 .insert("User-Agent", user_agent.clone());
219 }
220
221 if !request.headers().contains_key("Accept-Encoding") {
222 request
225 .headers_mut()
226 .insert("Accept-Encoding", HeaderValue::from_static("identity"));
227 }
228}
229
230impl<B> Service<Request<B>> for HttpClient<B>
231where
232 B: fmt::Debug + HttpBody + Send + 'static,
233 B::Data: Send,
234 B::Error: Into<crate::Error> + Send,
235{
236 type Response = http::Response<Body>;
237 type Error = HttpError;
238 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
239
240 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
241 Poll::Ready(Ok(()))
242 }
243
244 fn call(&mut self, request: Request<B>) -> Self::Future {
245 self.send(request)
246 }
247}
248
249impl<B> Clone for HttpClient<B> {
250 fn clone(&self) -> Self {
251 Self {
252 client: self.client.clone(),
253 user_agent: self.user_agent.clone(),
254 proxy_connector: self.proxy_connector.clone(),
255 }
256 }
257}
258
259impl<B> fmt::Debug for HttpClient<B> {
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 f.debug_struct("HttpClient")
262 .field("client", &self.client)
263 .field("user_agent", &self.user_agent)
264 .finish()
265 }
266}
267
268#[configurable_component]
273#[derive(Clone, Debug, Eq, PartialEq)]
274#[serde(deny_unknown_fields, rename_all = "snake_case", tag = "strategy")]
275#[configurable(metadata(docs::enum_tag_description = "The authentication strategy to use."))]
276pub enum Auth {
277 Basic {
283 #[configurable(metadata(docs::examples = "${USERNAME}"))]
285 #[configurable(metadata(docs::examples = "username"))]
286 user: String,
287
288 #[configurable(metadata(docs::examples = "${PASSWORD}"))]
290 #[configurable(metadata(docs::examples = "password"))]
291 password: SensitiveString,
292 },
293
294 Bearer {
298 token: SensitiveString,
300 },
301
302 #[cfg(feature = "aws-core")]
303 Aws {
305 auth: AwsAuthentication,
307
308 service: String,
310 },
311}
312
313pub trait MaybeAuth: Sized {
314 fn choose_one(&self, other: &Self) -> crate::Result<Self>;
315}
316
317impl MaybeAuth for Option<Auth> {
318 fn choose_one(&self, other: &Self) -> crate::Result<Self> {
319 if self.is_some() && other.is_some() {
320 Err("Two authorization credentials was provided.".into())
321 } else {
322 Ok(self.clone().or_else(|| other.clone()))
323 }
324 }
325}
326
327impl Auth {
328 pub fn apply<B>(&self, req: &mut Request<B>) {
329 self.apply_headers_map(req.headers_mut())
330 }
331
332 pub fn apply_builder(&self, mut builder: Builder) -> Builder {
333 if let Some(map) = builder.headers_mut() {
334 self.apply_headers_map(map)
335 }
336 builder
337 }
338
339 pub fn apply_headers_map(&self, map: &mut HeaderMap) {
340 match &self {
341 Auth::Basic { user, password } => {
342 let auth = Authorization::basic(user.as_str(), password.inner());
343 map.typed_insert(auth);
344 }
345 Auth::Bearer { token } => match Authorization::bearer(token.inner()) {
346 Ok(auth) => map.typed_insert(auth),
347 Err(error) => error!(message = "Invalid bearer token.", token = %token, %error),
348 },
349 #[cfg(feature = "aws-core")]
350 _ => {}
351 }
352 }
353}
354
355pub fn get_http_scheme_from_uri(uri: &Uri) -> &'static str {
356 uri.scheme_str().map_or("http", |scheme| match scheme {
359 "http" => "http",
360 "https" => "https",
361 s => panic!("invalid URI scheme for HTTP client: {s}"),
366 })
367}
368
369pub fn build_http_trace_layer<T, U>(
373 span: Span,
374) -> TraceLayer<
375 SharedClassifier<ServerErrorsAsFailures>,
376 impl Fn(&Request<T>) -> Span + Clone,
377 impl Fn(&Request<T>, &Span) + Clone,
378 impl Fn(&Response<U>, Duration, &Span) + Clone,
379 (),
380 (),
381 (),
382> {
383 TraceLayer::new_for_http()
384 .make_span_with(move |request: &Request<T>| {
385 error_span!(
387 parent: &span,
388 "http-request",
389 method = %request.method(),
390 path = %request.uri().path(),
391 )
392 })
393 .on_request(Box::new(|_request: &Request<T>, _span: &Span| {
394 emit!(HttpServerRequestReceived);
395 }))
396 .on_response(|response: &Response<U>, latency: Duration, _span: &Span| {
397 emit!(HttpServerResponseSent { response, latency });
398 })
399 .on_failure(())
400 .on_body_chunk(())
401 .on_eos(())
402}
403
404#[serde_as]
406#[configurable_component]
407#[derive(Clone, Debug, PartialEq)]
408#[serde(deny_unknown_fields)]
409pub struct KeepaliveConfig {
410 #[serde(default = "default_max_connection_age")]
420 #[configurable(metadata(docs::examples = 600))]
421 #[configurable(metadata(docs::type_unit = "seconds"))]
422 #[configurable(metadata(docs::human_name = "Maximum Connection Age"))]
423 pub max_connection_age_secs: Option<u64>,
424
425 #[serde(default = "default_max_connection_age_jitter_factor")]
430 #[configurable(validation(range(min = 0.0, max = 1.0)))]
431 pub max_connection_age_jitter_factor: f64,
432}
433
/// Default for `max_connection_age_secs`: five minutes.
const fn default_max_connection_age() -> Option<u64> {
    Some(5 * 60)
}

/// Default for `max_connection_age_jitter_factor`: up to ±10% of the configured age.
const fn default_max_connection_age_jitter_factor() -> f64 {
    0.1
}
441
442impl Default for KeepaliveConfig {
443 fn default() -> Self {
444 Self {
445 max_connection_age_secs: default_max_connection_age(),
446 max_connection_age_jitter_factor: default_max_connection_age_jitter_factor(),
447 }
448 }
449}
450
451pub struct MaxConnectionAgeLayer {
461 start_reference: Instant,
462 max_connection_age: Duration,
463 peer_addr: SocketAddr,
464}
465
466impl MaxConnectionAgeLayer {
467 pub fn new(max_connection_age: Duration, jitter_factor: f64, peer_addr: SocketAddr) -> Self {
468 Self {
469 start_reference: Instant::now(),
470 max_connection_age: Self::jittered_duration(max_connection_age, jitter_factor),
471 peer_addr,
472 }
473 }
474
475 fn jittered_duration(duration: Duration, jitter_factor: f64) -> Duration {
476 let jitter_factor = jitter_factor.clamp(0.0, 1.0);
478 let mut rng = rand::rng();
480 let random_jitter_factor = rng.random_range(-jitter_factor..=jitter_factor) + 1.;
481 duration.mul_f64(random_jitter_factor)
482 }
483}
484
485impl<S> Layer<S> for MaxConnectionAgeLayer
486where
487 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
488 S::Future: Send + 'static,
489{
490 type Service = MaxConnectionAgeService<S>;
491
492 fn layer(&self, service: S) -> Self::Service {
493 MaxConnectionAgeService {
494 service,
495 start_reference: self.start_reference,
496 max_connection_age: self.max_connection_age,
497 peer_addr: self.peer_addr,
498 }
499 }
500}
501
502#[derive(Clone)]
512pub struct MaxConnectionAgeService<S> {
513 service: S,
514 start_reference: Instant,
515 max_connection_age: Duration,
516 peer_addr: SocketAddr,
517}
518
519impl<S, E> Service<Request<Body>> for MaxConnectionAgeService<S>
520where
521 S: Service<Request<Body>, Response = Response<Body>, Error = E> + Clone + Send + 'static,
522 S::Future: Send + 'static,
523{
524 type Response = S::Response;
525 type Error = E;
526 type Future = BoxFuture<'static, Result<Self::Response, E>>;
527
528 fn poll_ready(
529 &mut self,
530 cx: &mut std::task::Context<'_>,
531 ) -> std::task::Poll<Result<(), Self::Error>> {
532 self.service.poll_ready(cx)
533 }
534
535 fn call(&mut self, req: Request<Body>) -> Self::Future {
536 let start_reference = self.start_reference;
537 let max_connection_age = self.max_connection_age;
538 let peer_addr = self.peer_addr;
539 let version = req.version();
540 let future = self.service.call(req);
541 Box::pin(async move {
542 let mut response = future.await?;
543 match version {
544 Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => {
545 if start_reference.elapsed() >= max_connection_age {
546 debug!(
547 message = "Closing connection due to max connection age.",
548 ?max_connection_age,
549 connection_age = ?start_reference.elapsed(),
550 ?peer_addr,
551 );
552 response.headers_mut().insert(
555 hyper::header::CONNECTION,
556 hyper::header::HeaderValue::from_static("close"),
557 );
558 }
559 }
560 Version::HTTP_2 => (),
562 Version::HTTP_3 => (),
564 _ => (),
565 }
566 Ok(response)
567 })
568 }
569}
570
571#[configurable_component]
573#[derive(Clone, Debug, Default, Eq, PartialEq)]
574#[serde(rename_all = "snake_case")]
575pub enum ParamType {
576 #[default]
578 String,
579 Vrl,
581}
582
583impl ParamType {
584 fn is_default(&self) -> bool {
585 *self == Self::default()
586 }
587}
588
589#[configurable_component]
592#[derive(Clone, Debug, Eq, PartialEq)]
593#[serde(untagged)]
594pub enum ParameterValue {
595 String(String),
597 Typed {
599 value: String,
601 #[serde(
603 default,
604 skip_serializing_if = "ParamType::is_default",
605 rename = "type"
606 )]
607 r#type: ParamType,
608 },
609}
610
611impl ParameterValue {
612 pub const fn is_vrl(&self) -> bool {
614 match self {
615 ParameterValue::String(_) => false,
616 ParameterValue::Typed { r#type, .. } => matches!(r#type, ParamType::Vrl),
617 }
618 }
619
620 #[allow(clippy::missing_const_for_fn)]
622 pub fn value(&self) -> &str {
623 match self {
624 ParameterValue::String(s) => s,
625 ParameterValue::Typed { value, .. } => value,
626 }
627 }
628
629 pub fn into_value(self) -> String {
631 match self {
632 ParameterValue::String(s) => s,
633 ParameterValue::Typed { value, .. } => value,
634 }
635 }
636}
637
638#[configurable_component]
640#[derive(Clone, Debug, Eq, PartialEq)]
641#[serde(untagged)]
642#[configurable(metadata(docs::enum_tag_description = "Query parameter value"))]
643pub enum QueryParameterValue {
644 SingleParam(ParameterValue),
646 MultiParams(Vec<ParameterValue>),
648}
649
650impl QueryParameterValue {
651 pub fn iter(&self) -> impl Iterator<Item = &ParameterValue> {
653 match self {
654 QueryParameterValue::SingleParam(param) => std::slice::from_ref(param).iter(),
655 QueryParameterValue::MultiParams(params) => params.iter(),
656 }
657 }
658
659 fn into_vec(self) -> Vec<ParameterValue> {
661 match self {
662 QueryParameterValue::SingleParam(param) => vec![param],
663 QueryParameterValue::MultiParams(params) => params,
664 }
665 }
666}
667
668impl IntoIterator for QueryParameterValue {
670 type Item = ParameterValue;
671 type IntoIter = std::vec::IntoIter<ParameterValue>;
672
673 fn into_iter(self) -> Self::IntoIter {
674 self.into_vec().into_iter()
675 }
676}
677
678pub type QueryParameters = HashMap<String, QueryParameterValue>;
679
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use hyper::{Server, server::conn::AddrStream, service::make_service_fn};
    use proptest::prelude::*;
    use tower::ServiceBuilder;

    use super::*;
    use crate::test_util::next_addr;

    // A request with neither header gets the client defaults filled in.
    #[test]
    fn test_default_request_headers_defaults() {
        let user_agent = HeaderValue::from_static("vector");
        let mut request = Request::post("http://example.com").body(()).unwrap();
        default_request_headers(&mut request, &user_agent);
        assert_eq!(
            request.headers().get("Accept-Encoding"),
            Some(&HeaderValue::from_static("identity")),
        );
        assert_eq!(request.headers().get("User-Agent"), Some(&user_agent));
    }

    // Headers the caller already set must be left untouched.
    #[test]
    fn test_default_request_headers_does_not_overwrite() {
        let mut request = Request::post("http://example.com")
            .header("Accept-Encoding", "gzip")
            .header("User-Agent", "foo")
            .body(())
            .unwrap();
        default_request_headers(&mut request, &HeaderValue::from_static("vector"));
        assert_eq!(
            request.headers().get("Accept-Encoding"),
            Some(&HeaderValue::from_static("gzip")),
        );
        assert_eq!(
            request.headers().get("User-Agent"),
            Some(&HeaderValue::from_static("foo"))
        );
    }

    proptest! {
        // The jittered duration must stay within [1 - jitter, 1 + jitter] times the input.
        #[test]
        fn test_jittered_duration(duration_in_secs in 0u64..120, jitter_factor in 0.0..1.0) {
            let duration = Duration::from_secs(duration_in_secs);
            let jittered_duration = MaxConnectionAgeLayer::jittered_duration(duration, jitter_factor);

            if jitter_factor == 0.0 {
                // No jitter: the input must come back unchanged.
                prop_assert_eq!(
                    jittered_duration,
                    duration,
                    "jittered_duration {:?} should be equal to duration {:?}",
                    jittered_duration,
                    duration,
                );
            } else if duration_in_secs > 0 {
                let lower_bound = duration.mul_f64(1.0 - jitter_factor);
                let upper_bound = duration.mul_f64(1.0 + jitter_factor);
                prop_assert!(
                    jittered_duration >= lower_bound && jittered_duration <= upper_bound,
                    "jittered_duration {:?} should be between {:?} and {:?}",
                    jittered_duration,
                    lower_bound,
                    upper_bound,
                );
            } else {
                // A zero duration stays zero whatever the factor.
                prop_assert_eq!(
                    jittered_duration,
                    Duration::from_secs(0),
                    "jittered_duration {:?} should be equal to zero",
                    jittered_duration,
                );
            }
        }
    }

    // With tokio's clock paused, `Connection: close` appears exactly when the elapsed time
    // reaches the maximum age (at 1s here, not at 0s or 0.5s).
    #[tokio::test]
    async fn test_max_connection_age_service() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(1);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        tokio::time::advance(Duration::from_millis(500)).await;
        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        // Elapsed time is now equal to the max age; the check is `>=`, so this one closes.
        tokio::time::advance(Duration::from_millis(500)).await;
        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close"))
        );
    }

    // HTTP/2 requests never get `Connection: close`, even with a zero max age.
    #[tokio::test]
    async fn test_max_connection_age_service_http2() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let mut req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        *req.version_mut() = Version::HTTP_2;
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }

    // HTTP/3 requests never get `Connection: close`, even with a zero max age.
    #[tokio::test]
    async fn test_max_connection_age_service_http3() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_secs(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let mut req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        *req.version_mut() = Version::HTTP_3;
        let response = service.call(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }

    // A zero max age closes the very first HTTP/1.1 response (the default request version).
    #[tokio::test]
    async fn test_max_connection_age_service_zero_duration() {
        tokio::time::pause();

        let start_reference = Instant::now();
        let max_connection_age = Duration::from_millis(0);
        let mut service = MaxConnectionAgeService {
            service: tower::service_fn(|_req: Request<Body>| async {
                Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
            }),
            start_reference,
            max_connection_age,
            peer_addr: "1.2.3.4:1234".parse().unwrap(),
        };

        let req = Request::get("http://example.com")
            .body(Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close"))
        );
    }

    // End-to-end test against a real hyper server, using real (unpaused) time.
    // `make_service_fn` runs once per accepted connection, so each connection gets its own
    // `MaxConnectionAgeLayer` and its own start time.
    #[tokio::test]
    async fn test_max_connection_age_service_with_hyper_server() {
        let max_connection_age = Duration::from_secs(1);
        let addr = next_addr();
        let make_svc = make_service_fn(move |conn: &AddrStream| {
            let svc = ServiceBuilder::new()
                .layer(MaxConnectionAgeLayer::new(
                    max_connection_age,
                    0.,
                    conn.remote_addr(),
                ))
                .service(tower::service_fn(|_req: Request<Body>| async {
                    Ok::<Response<Body>, hyper::Error>(Response::new(Body::empty()))
                }));
            futures_util::future::ok::<_, Infallible>(svc)
        });

        tokio::spawn(async move {
            Server::bind(&addr).serve(make_svc).await.unwrap();
        });

        // Give the spawned server a moment to start listening.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let client = HttpClient::new(None, &ProxyConfig::default()).unwrap();

        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);

        // Let the (presumably reused) connection exceed its 1s max age.
        tokio::time::sleep(Duration::from_secs(1)).await;
        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(
            response.headers().get("Connection"),
            Some(&HeaderValue::from_static("close")),
        );

        // NOTE(review): expects the client to have dropped the closed connection and opened a
        // new one, whose fresh layer has not reached its max age yet.
        let req = Request::get(format!("http://{addr}/"))
            .body(Body::empty())
            .unwrap();
        let response = client.send(req).await.unwrap();
        assert_eq!(response.headers().get("Connection"), None);
    }
}