1pub mod auth;
3pub mod region;
4pub mod timeout;
5
6pub use auth::{AwsAuthentication, ImdsAuthentication};
7use aws_config::{
8 meta::region::ProvideRegion, retry::RetryConfig, timeout::TimeoutConfig, Region, SdkConfig,
9};
10use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
11use aws_sigv4::{
12 http_request::{PayloadChecksumKind, SignableBody, SignableRequest, SigningSettings},
13 sign::v4,
14};
15use aws_smithy_async::rt::sleep::TokioSleep;
16use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder;
17use aws_smithy_runtime_api::client::{
18 http::{
19 HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
20 },
21 identity::Identity,
22 orchestrator::{HttpRequest, HttpResponse},
23 result::SdkError,
24 runtime_components::RuntimeComponents,
25};
26use aws_smithy_types::body::SdkBody;
27use aws_types::sdk_config::SharedHttpClient;
28use bytes::Bytes;
29use futures_util::FutureExt;
30use http::HeaderMap;
31use http_body::{combinators::BoxBody, Body};
32use pin_project::pin_project;
33use regex::RegexSet;
34pub use region::RegionOrEndpoint;
35use snafu::Snafu;
36use std::{
37 error::Error,
38 pin::Pin,
39 sync::{
40 atomic::{AtomicUsize, Ordering},
41 Arc, OnceLock,
42 },
43 task::{Context, Poll},
44 time::{Duration, SystemTime},
45};
46pub use timeout::AwsTimeout;
47
48use crate::config::ProxyConfig;
49use crate::http::{build_proxy_connector, build_tls_connector, status};
50use crate::internal_events::AwsBytesSent;
51use crate::tls::{MaybeTlsSettings, TlsConfig};
52
// Lazily compiled set of AWS error codes ("RequestTimeout", "RequestExpired",
// "ThrottlingException") whose presence in a 4xx response body marks the request as
// retriable. Initialized on first use by `check_response`.
static RETRIABLE_CODES: OnceLock<RegexSet> = OnceLock::new();
54
55pub fn is_retriable_error<T>(error: &SdkError<T, HttpResponse>) -> bool {
57 match error {
58 SdkError::TimeoutError(_) | SdkError::DispatchFailure(_) => true,
59 SdkError::ConstructionFailure(_) => false,
60 SdkError::ResponseError(err) => check_response(err.raw()),
61 SdkError::ServiceError(err) => check_response(err.raw()),
62 _ => {
63 warn!("AWS returned unknown error, retrying request.");
64 true
65 }
66 }
67}
68
69fn check_response(res: &HttpResponse) -> bool {
70 let retry_header = res.headers().get("x-amz-retry-after").is_some();
74
75 let re = RETRIABLE_CODES.get_or_init(|| {
89 RegexSet::new(["RequestTimeout", "RequestExpired", "ThrottlingException"])
90 .expect("invalid regex")
91 });
92
93 let status = res.status();
94 let response_body = String::from_utf8_lossy(res.body().bytes().unwrap_or(&[]));
95
96 retry_header
97 || status.is_server_error()
98 || status.as_u16() == status::TOO_MANY_REQUESTS
99 || (status.is_client_error() && re.is_match(response_body.as_ref()))
100}
101
102fn connector(
106 proxy: &ProxyConfig,
107 tls_options: Option<&TlsConfig>,
108) -> crate::Result<SharedHttpClient> {
109 let tls_settings = MaybeTlsSettings::tls_client(tls_options)?;
110
111 if proxy.enabled {
112 let proxy = build_proxy_connector(tls_settings, proxy)?;
113 Ok(HyperClientBuilder::new().build(proxy))
114 } else {
115 let tls_connector = build_tls_connector(tls_settings)?;
116 Ok(HyperClientBuilder::new().build(tls_connector))
117 }
118}
119
120pub trait ClientBuilder {
122 type Client;
124
125 fn build(&self, config: &SdkConfig) -> Self::Client;
127}
128
129pub fn region_provider(
131 proxy: &ProxyConfig,
132 tls_options: Option<&TlsConfig>,
133) -> crate::Result<impl ProvideRegion + use<>> {
134 let config = aws_config::provider_config::ProviderConfig::default()
135 .with_http_client(connector(proxy, tls_options)?);
136
137 Ok(aws_config::meta::region::RegionProviderChain::first_try(
138 aws_config::environment::EnvironmentVariableRegionProvider::new(),
139 )
140 .or_else(aws_config::profile::ProfileFileRegionProvider::builder().build())
141 .or_else(
142 aws_config::imds::region::ImdsRegionProvider::builder()
143 .configure(&config)
144 .build(),
145 ))
146}
147
148async fn resolve_region(
149 proxy: &ProxyConfig,
150 tls_options: Option<&TlsConfig>,
151 region: Option<Region>,
152) -> crate::Result<Region> {
153 match region {
154 Some(region) => Ok(region),
155 None => region_provider(proxy, tls_options)?
156 .region()
157 .await
158 .ok_or_else(|| {
159 "Could not determine region from Vector configuration or default providers".into()
160 }),
161 }
162}
163
164pub async fn create_client<T>(
166 builder: &T,
167 auth: &AwsAuthentication,
168 region: Option<Region>,
169 endpoint: Option<String>,
170 proxy: &ProxyConfig,
171 tls_options: Option<&TlsConfig>,
172 timeout: Option<&AwsTimeout>,
173) -> crate::Result<T::Client>
174where
175 T: ClientBuilder,
176{
177 create_client_and_region::<T>(builder, auth, region, endpoint, proxy, tls_options, timeout)
178 .await
179 .map(|(client, _)| client)
180}
181
182pub async fn create_client_and_region<T>(
184 builder: &T,
185 auth: &AwsAuthentication,
186 region: Option<Region>,
187 endpoint: Option<String>,
188 proxy: &ProxyConfig,
189 tls_options: Option<&TlsConfig>,
190 timeout: Option<&AwsTimeout>,
191) -> crate::Result<(T::Client, Region)>
192where
193 T: ClientBuilder,
194{
195 let retry_config = RetryConfig::disabled();
196
197 let region = resolve_region(proxy, tls_options, region).await?;
200
201 let provider_config =
202 aws_config::provider_config::ProviderConfig::empty().with_region(Some(region.clone()));
203
204 let connector = connector(proxy, tls_options)?;
205
206 let connector = AwsHttpClient {
208 http: connector,
209 region: region.clone(),
210 };
211
212 let mut config_builder = SdkConfig::builder()
214 .http_client(connector)
215 .sleep_impl(Arc::new(TokioSleep::new()))
216 .identity_cache(auth.credentials_cache().await?)
217 .credentials_provider(
218 auth.credentials_provider(region.clone(), proxy, tls_options)
219 .await?,
220 )
221 .region(region.clone())
222 .retry_config(retry_config.clone());
223
224 if let Some(endpoint_override) = endpoint {
225 config_builder = config_builder.endpoint_url(endpoint_override);
226 } else if let Some(endpoint_from_config) =
227 aws_config::default_provider::endpoint_url::endpoint_url_provider(&provider_config).await
228 {
229 config_builder = config_builder.endpoint_url(endpoint_from_config);
230 }
231
232 if let Some(use_fips) =
233 aws_config::default_provider::use_fips::use_fips_provider(&provider_config).await
234 {
235 config_builder = config_builder.use_fips(use_fips);
236 }
237
238 if let Some(timeout) = timeout {
239 let mut timeout_config_builder = TimeoutConfig::builder();
240
241 let operation_timeout = timeout.operation_timeout();
242 let connect_timeout = timeout.connect_timeout();
243 let read_timeout = timeout.read_timeout();
244
245 timeout_config_builder
246 .set_operation_timeout(operation_timeout.map(Duration::from_secs))
247 .set_connect_timeout(connect_timeout.map(Duration::from_secs))
248 .set_read_timeout(read_timeout.map(Duration::from_secs));
249
250 config_builder = config_builder.timeout_config(timeout_config_builder.build());
251 }
252
253 let config = config_builder.build();
254
255 Ok((T::build(builder, &config), region))
256}
257
258#[derive(Snafu, Debug)]
259enum SigningError {
260 #[snafu(display("cannot sign the request because the headers are not valid utf-8"))]
261 NotUTF8Header,
262}
263
264pub async fn sign_request(
267 service_name: &str,
268 request: &mut http::Request<Bytes>,
269 credentials_provider: &SharedCredentialsProvider,
270 region: Option<&Region>,
271 payload_checksum_sha256: bool,
272) -> crate::Result<()> {
273 let headers = request
274 .headers()
275 .iter()
276 .map(|(k, v)| {
277 Ok((
278 k.as_str(),
279 std::str::from_utf8(v.as_bytes()).map_err(|_| SigningError::NotUTF8Header)?,
280 ))
281 })
282 .collect::<Result<Vec<_>, SigningError>>()?;
283
284 let signable_request = SignableRequest::new(
285 request.method().as_str(),
286 request.uri().to_string(),
287 headers.into_iter(),
288 SignableBody::Bytes(request.body().as_ref()),
289 )?;
290
291 let credentials = credentials_provider.provide_credentials().await?;
292 let identity = Identity::new(credentials, None);
293
294 let mut signing_settings = SigningSettings::default();
295
296 if payload_checksum_sha256 {
299 signing_settings.payload_checksum_kind = PayloadChecksumKind::XAmzSha256;
300 }
301
302 let signing_params_builder = v4::SigningParams::builder()
303 .identity(&identity)
304 .region(region.as_ref().map(|r| r.as_ref()).unwrap_or(""))
305 .name(service_name)
306 .time(SystemTime::now())
307 .settings(signing_settings);
308
309 let signing_params = signing_params_builder
310 .build()
311 .expect("all signing params set");
312
313 let (signing_instructions, _signature) =
314 aws_sigv4::http_request::sign(signable_request, &signing_params.into())?.into_parts();
315 signing_instructions.apply_to_request_http0x(request);
316
317 Ok(())
318}
319
320#[derive(Debug)]
321struct AwsHttpClient<T> {
322 http: T,
323 region: Region,
324}
325
326impl<T> HttpClient for AwsHttpClient<T>
327where
328 T: HttpClient,
329{
330 fn http_connector(
331 &self,
332 settings: &HttpConnectorSettings,
333 components: &RuntimeComponents,
334 ) -> SharedHttpConnector {
335 let http_connector = self.http.http_connector(settings, components);
336
337 SharedHttpConnector::new(AwsConnector {
338 region: self.region.clone(),
339 http: http_connector,
340 })
341 }
342}
343
344#[derive(Clone, Debug)]
345struct AwsConnector<T> {
346 http: T,
347 region: Region,
348}
349
350impl<T> HttpConnector for AwsConnector<T>
351where
352 T: HttpConnector,
353{
354 fn call(&self, req: HttpRequest) -> HttpConnectorFuture {
355 let bytes_sent = Arc::new(std::sync::atomic::AtomicUsize::new(0));
356 let req = req.map(|body| {
357 let bytes_sent = Arc::clone(&bytes_sent);
358 body.map_preserve_contents(move |body| {
359 let body = MeasuredBody::new(body, Arc::clone(&bytes_sent));
360 SdkBody::from_body_0_4(BoxBody::new(body))
361 })
362 });
363
364 let fut = self.http.call(req);
365 let region = self.region.clone();
366
367 HttpConnectorFuture::new(fut.inspect(move |result| {
368 let byte_size = bytes_sent.load(Ordering::Relaxed);
369 if let Ok(result) = result {
370 if result.status().is_success() {
371 emit!(AwsBytesSent {
372 byte_size,
373 region: Some(region),
374 });
375 }
376 }
377 }))
378 }
379}
380
381#[pin_project]
382struct MeasuredBody {
383 #[pin]
384 inner: SdkBody,
385 shared_bytes_sent: Arc<AtomicUsize>,
386}
387
388impl MeasuredBody {
389 const fn new(body: SdkBody, shared_bytes_sent: Arc<AtomicUsize>) -> Self {
390 Self {
391 inner: body,
392 shared_bytes_sent,
393 }
394 }
395}
396
397impl Body for MeasuredBody {
398 type Data = Bytes;
399 type Error = Box<dyn Error + Send + Sync>;
400
401 fn poll_data(
402 self: Pin<&mut Self>,
403 cx: &mut Context<'_>,
404 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
405 let this = self.project();
406
407 match this.inner.poll_data(cx) {
408 Poll::Ready(Some(Ok(data))) => {
409 this.shared_bytes_sent
410 .fetch_add(data.len(), Ordering::Release);
411 Poll::Ready(Some(Ok(data)))
412 }
413 Poll::Ready(None) => Poll::Ready(None),
414 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
415 Poll::Pending => Poll::Pending,
416 }
417 }
418
419 fn poll_trailers(
420 self: Pin<&mut Self>,
421 _cx: &mut Context<'_>,
422 ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
423 Poll::Ready(Ok(None))
424 }
425}