1pub mod auth;
3pub mod region;
4pub mod timeout;
5
6use std::{
7 error::Error,
8 pin::Pin,
9 sync::{
10 Arc, OnceLock,
11 atomic::{AtomicUsize, Ordering},
12 },
13 task::{Context, Poll},
14 time::{Duration, SystemTime},
15};
16
17pub use auth::{AwsAuthentication, ImdsAuthentication};
18use aws_config::{
19 Region, SdkConfig, meta::region::ProvideRegion, retry::RetryConfig, timeout::TimeoutConfig,
20};
21use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
22use aws_sigv4::{
23 http_request::{PayloadChecksumKind, SignableBody, SignableRequest, SigningSettings},
24 sign::v4,
25};
26use aws_smithy_async::rt::sleep::TokioSleep;
27use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder;
28use aws_smithy_runtime_api::client::{
29 http::{
30 HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
31 },
32 identity::Identity,
33 orchestrator::{HttpRequest, HttpResponse},
34 result::SdkError,
35 runtime_components::RuntimeComponents,
36};
37use aws_smithy_types::body::SdkBody;
38use aws_types::sdk_config::SharedHttpClient;
39use bytes::Bytes;
40use futures_util::FutureExt;
41use http::HeaderMap;
42use http_body::{Body, combinators::BoxBody};
43use pin_project::pin_project;
44use regex::RegexSet;
45pub use region::RegionOrEndpoint;
46use snafu::Snafu;
47pub use timeout::AwsTimeout;
48
49use crate::{
50 config::ProxyConfig,
51 http::{build_proxy_connector, build_tls_connector, status},
52 internal_events::AwsBytesSent,
53 tls::{MaybeTlsSettings, TlsConfig},
54};
55
/// Lazily compiled patterns that `check_response` matches against the body of a 4xx
/// response; a match marks the request as retriable.
static RETRIABLE_CODES: OnceLock<RegexSet> = OnceLock::new();
57
58pub fn is_retriable_error<T>(error: &SdkError<T, HttpResponse>) -> bool {
60 match error {
61 SdkError::TimeoutError(_) | SdkError::DispatchFailure(_) => true,
62 SdkError::ConstructionFailure(_) => false,
63 SdkError::ResponseError(err) => check_response(err.raw()),
64 SdkError::ServiceError(err) => check_response(err.raw()),
65 _ => {
66 warn!("AWS returned unknown error, retrying request.");
67 true
68 }
69 }
70}
71
72fn check_response(res: &HttpResponse) -> bool {
73 let retry_header = res.headers().get("x-amz-retry-after").is_some();
77
78 let re = RETRIABLE_CODES.get_or_init(|| {
92 RegexSet::new(["RequestTimeout", "RequestExpired", "ThrottlingException"])
93 .expect("invalid regex")
94 });
95
96 let status = res.status();
97 let response_body = String::from_utf8_lossy(res.body().bytes().unwrap_or(&[]));
98
99 retry_header
100 || status.is_server_error()
101 || status.as_u16() == status::TOO_MANY_REQUESTS
102 || (status.is_client_error() && re.is_match(response_body.as_ref()))
103}
104
105fn connector(
109 proxy: &ProxyConfig,
110 tls_options: Option<&TlsConfig>,
111) -> crate::Result<SharedHttpClient> {
112 let tls_settings = MaybeTlsSettings::tls_client(tls_options)?;
113
114 if proxy.enabled {
115 let proxy = build_proxy_connector(tls_settings, proxy)?;
116 Ok(HyperClientBuilder::new().build(proxy))
117 } else {
118 let tls_connector = build_tls_connector(tls_settings)?;
119 Ok(HyperClientBuilder::new().build(tls_connector))
120 }
121}
122
/// Turns a fully assembled [`SdkConfig`] into a concrete client.
///
/// `create_client` and `create_client_and_region` call [`ClientBuilder::build`] once
/// the shared SDK configuration (HTTP client, credentials, region, endpoint, timeouts) is ready.
pub trait ClientBuilder {
    /// The client type produced by this builder.
    type Client;

    /// Builds a client from `config`.
    fn build(&self, config: &SdkConfig) -> Self::Client;
}
131
132pub fn region_provider(
134 proxy: &ProxyConfig,
135 tls_options: Option<&TlsConfig>,
136) -> crate::Result<impl ProvideRegion + use<>> {
137 let config = aws_config::provider_config::ProviderConfig::default()
138 .with_http_client(connector(proxy, tls_options)?);
139
140 Ok(aws_config::meta::region::RegionProviderChain::first_try(
141 aws_config::environment::EnvironmentVariableRegionProvider::new(),
142 )
143 .or_else(aws_config::profile::ProfileFileRegionProvider::builder().build())
144 .or_else(
145 aws_config::imds::region::ImdsRegionProvider::builder()
146 .configure(&config)
147 .build(),
148 ))
149}
150
151async fn resolve_region(
152 proxy: &ProxyConfig,
153 tls_options: Option<&TlsConfig>,
154 region: Option<Region>,
155) -> crate::Result<Region> {
156 match region {
157 Some(region) => Ok(region),
158 None => region_provider(proxy, tls_options)?
159 .region()
160 .await
161 .ok_or_else(|| {
162 "Could not determine region from Vector configuration or default providers".into()
163 }),
164 }
165}
166
167pub async fn create_client<T>(
169 builder: &T,
170 auth: &AwsAuthentication,
171 region: Option<Region>,
172 endpoint: Option<String>,
173 proxy: &ProxyConfig,
174 tls_options: Option<&TlsConfig>,
175 timeout: Option<&AwsTimeout>,
176) -> crate::Result<T::Client>
177where
178 T: ClientBuilder,
179{
180 create_client_and_region::<T>(builder, auth, region, endpoint, proxy, tls_options, timeout)
181 .await
182 .map(|(client, _)| client)
183}
184
185pub async fn create_client_and_region<T>(
187 builder: &T,
188 auth: &AwsAuthentication,
189 region: Option<Region>,
190 endpoint: Option<String>,
191 proxy: &ProxyConfig,
192 tls_options: Option<&TlsConfig>,
193 timeout: Option<&AwsTimeout>,
194) -> crate::Result<(T::Client, Region)>
195where
196 T: ClientBuilder,
197{
198 let retry_config = RetryConfig::disabled();
199
200 let region = resolve_region(proxy, tls_options, region).await?;
203
204 let provider_config =
205 aws_config::provider_config::ProviderConfig::empty().with_region(Some(region.clone()));
206
207 let connector = connector(proxy, tls_options)?;
208
209 let connector = AwsHttpClient {
211 http: connector,
212 region: region.clone(),
213 };
214
215 let mut config_builder = SdkConfig::builder()
217 .http_client(connector)
218 .sleep_impl(Arc::new(TokioSleep::new()))
219 .identity_cache(auth.credentials_cache().await?)
220 .credentials_provider(
221 auth.credentials_provider(region.clone(), proxy, tls_options)
222 .await?,
223 )
224 .region(region.clone())
225 .retry_config(retry_config.clone());
226
227 if let Some(endpoint_override) = endpoint {
228 config_builder = config_builder.endpoint_url(endpoint_override);
229 } else if let Some(endpoint_from_config) =
230 aws_config::default_provider::endpoint_url::endpoint_url_provider(&provider_config).await
231 {
232 config_builder = config_builder.endpoint_url(endpoint_from_config);
233 }
234
235 if let Some(use_fips) =
236 aws_config::default_provider::use_fips::use_fips_provider(&provider_config).await
237 {
238 config_builder = config_builder.use_fips(use_fips);
239 }
240
241 if let Some(timeout) = timeout {
242 let mut timeout_config_builder = TimeoutConfig::builder();
243
244 let operation_timeout = timeout.operation_timeout();
245 let connect_timeout = timeout.connect_timeout();
246 let read_timeout = timeout.read_timeout();
247
248 timeout_config_builder
249 .set_operation_timeout(operation_timeout.map(Duration::from_secs))
250 .set_connect_timeout(connect_timeout.map(Duration::from_secs))
251 .set_read_timeout(read_timeout.map(Duration::from_secs));
252
253 config_builder = config_builder.timeout_config(timeout_config_builder.build());
254 }
255
256 let config = config_builder.build();
257
258 Ok((T::build(builder, &config), region))
259}
260
/// Errors raised by `sign_request` while preparing a request for SigV4 signing.
#[derive(Snafu, Debug)]
enum SigningError {
    /// A request header value is not valid UTF-8.
    #[snafu(display("cannot sign the request because the headers are not valid utf-8"))]
    NotUTF8Header,
}
266
267pub async fn sign_request(
270 service_name: &str,
271 request: &mut http::Request<Bytes>,
272 credentials_provider: &SharedCredentialsProvider,
273 region: Option<&Region>,
274 payload_checksum_sha256: bool,
275) -> crate::Result<()> {
276 let headers = request
277 .headers()
278 .iter()
279 .map(|(k, v)| {
280 Ok((
281 k.as_str(),
282 std::str::from_utf8(v.as_bytes()).map_err(|_| SigningError::NotUTF8Header)?,
283 ))
284 })
285 .collect::<Result<Vec<_>, SigningError>>()?;
286
287 let signable_request = SignableRequest::new(
288 request.method().as_str(),
289 request.uri().to_string(),
290 headers.into_iter(),
291 SignableBody::Bytes(request.body().as_ref()),
292 )?;
293
294 let credentials = credentials_provider.provide_credentials().await?;
295 let identity = Identity::new(credentials, None);
296
297 let mut signing_settings = SigningSettings::default();
298
299 if payload_checksum_sha256 {
302 signing_settings.payload_checksum_kind = PayloadChecksumKind::XAmzSha256;
303 }
304
305 let signing_params_builder = v4::SigningParams::builder()
306 .identity(&identity)
307 .region(region.as_ref().map(|r| r.as_ref()).unwrap_or(""))
308 .name(service_name)
309 .time(SystemTime::now())
310 .settings(signing_settings);
311
312 let signing_params = signing_params_builder
313 .build()
314 .expect("all signing params set");
315
316 let (signing_instructions, _signature) =
317 aws_sigv4::http_request::sign(signable_request, &signing_params.into())?.into_parts();
318 signing_instructions.apply_to_request_http0x(request);
319
320 Ok(())
321}
322
/// An [`HttpClient`] wrapper whose connectors are [`AwsConnector`]s tagged with `region`.
#[derive(Debug)]
struct AwsHttpClient<T> {
    /// The wrapped HTTP client whose connectors perform the actual requests.
    http: T,
    /// Region passed to every `AwsConnector` this client creates.
    region: Region,
}
328
329impl<T> HttpClient for AwsHttpClient<T>
330where
331 T: HttpClient,
332{
333 fn http_connector(
334 &self,
335 settings: &HttpConnectorSettings,
336 components: &RuntimeComponents,
337 ) -> SharedHttpConnector {
338 let http_connector = self.http.http_connector(settings, components);
339
340 SharedHttpConnector::new(AwsConnector {
341 region: self.region.clone(),
342 http: http_connector,
343 })
344 }
345}
346
/// An [`HttpConnector`] that counts request body bytes and emits `AwsBytesSent` on success.
#[derive(Clone, Debug)]
struct AwsConnector<T> {
    /// The wrapped connector that sends the request.
    http: T,
    /// Region reported in emitted `AwsBytesSent` events.
    region: Region,
}
352
353impl<T> HttpConnector for AwsConnector<T>
354where
355 T: HttpConnector,
356{
357 fn call(&self, req: HttpRequest) -> HttpConnectorFuture {
358 let bytes_sent = Arc::new(std::sync::atomic::AtomicUsize::new(0));
359 let req = req.map(|body| {
360 let bytes_sent = Arc::clone(&bytes_sent);
361 body.map_preserve_contents(move |body| {
362 let body = MeasuredBody::new(body, Arc::clone(&bytes_sent));
363 SdkBody::from_body_0_4(BoxBody::new(body))
364 })
365 });
366
367 let fut = self.http.call(req);
368 let region = self.region.clone();
369
370 HttpConnectorFuture::new(fut.inspect(move |result| {
371 let byte_size = bytes_sent.load(Ordering::Relaxed);
372 if let Ok(result) = result
373 && result.status().is_success()
374 {
375 emit!(AwsBytesSent {
376 byte_size,
377 region: Some(region),
378 });
379 }
380 }))
381 }
382}
383
/// Body wrapper that adds the length of every data chunk read from `inner` to a shared
/// counter.
#[pin_project]
struct MeasuredBody {
    /// The wrapped body, polled through a pinned projection.
    #[pin]
    inner: SdkBody,
    /// Running total of data bytes yielded so far, shared with `AwsConnector::call`.
    shared_bytes_sent: Arc<AtomicUsize>,
}
390
391impl MeasuredBody {
392 const fn new(body: SdkBody, shared_bytes_sent: Arc<AtomicUsize>) -> Self {
393 Self {
394 inner: body,
395 shared_bytes_sent,
396 }
397 }
398}
399
400impl Body for MeasuredBody {
401 type Data = Bytes;
402 type Error = Box<dyn Error + Send + Sync>;
403
404 fn poll_data(
405 self: Pin<&mut Self>,
406 cx: &mut Context<'_>,
407 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
408 let this = self.project();
409
410 match this.inner.poll_data(cx) {
411 Poll::Ready(Some(Ok(data))) => {
412 this.shared_bytes_sent
413 .fetch_add(data.len(), Ordering::Release);
414 Poll::Ready(Some(Ok(data)))
415 }
416 Poll::Ready(None) => Poll::Ready(None),
417 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
418 Poll::Pending => Poll::Pending,
419 }
420 }
421
422 fn poll_trailers(
423 self: Pin<&mut Self>,
424 _cx: &mut Context<'_>,
425 ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
426 Poll::Ready(Ok(None))
427 }
428}