vector/aws/
mod.rs

1//! Shared functionality for the AWS components.
2pub mod auth;
3pub mod region;
4pub mod timeout;
5
6pub use auth::{AwsAuthentication, ImdsAuthentication};
7use aws_config::{
8    meta::region::ProvideRegion, retry::RetryConfig, timeout::TimeoutConfig, Region, SdkConfig,
9};
10use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
11use aws_sigv4::{
12    http_request::{PayloadChecksumKind, SignableBody, SignableRequest, SigningSettings},
13    sign::v4,
14};
15use aws_smithy_async::rt::sleep::TokioSleep;
16use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder;
17use aws_smithy_runtime_api::client::{
18    http::{
19        HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
20    },
21    identity::Identity,
22    orchestrator::{HttpRequest, HttpResponse},
23    result::SdkError,
24    runtime_components::RuntimeComponents,
25};
26use aws_smithy_types::body::SdkBody;
27use aws_types::sdk_config::SharedHttpClient;
28use bytes::Bytes;
29use futures_util::FutureExt;
30use http::HeaderMap;
31use http_body::{combinators::BoxBody, Body};
32use pin_project::pin_project;
33use regex::RegexSet;
34pub use region::RegionOrEndpoint;
35use snafu::Snafu;
36use std::{
37    error::Error,
38    pin::Pin,
39    sync::{
40        atomic::{AtomicUsize, Ordering},
41        Arc, OnceLock,
42    },
43    task::{Context, Poll},
44    time::{Duration, SystemTime},
45};
46pub use timeout::AwsTimeout;
47
48use crate::config::ProxyConfig;
49use crate::http::{build_proxy_connector, build_tls_connector, status};
50use crate::internal_events::AwsBytesSent;
51use crate::tls::{MaybeTlsSettings, TlsConfig};
52
53static RETRIABLE_CODES: OnceLock<RegexSet> = OnceLock::new();
54
55/// Checks if the request can be retried after the given error was returned.
56pub fn is_retriable_error<T>(error: &SdkError<T, HttpResponse>) -> bool {
57    match error {
58        SdkError::TimeoutError(_) | SdkError::DispatchFailure(_) => true,
59        SdkError::ConstructionFailure(_) => false,
60        SdkError::ResponseError(err) => check_response(err.raw()),
61        SdkError::ServiceError(err) => check_response(err.raw()),
62        _ => {
63            warn!("AWS returned unknown error, retrying request.");
64            true
65        }
66    }
67}
68
69fn check_response(res: &HttpResponse) -> bool {
70    // This header is a direct indication that we should retry the request. Eventually it'd
71    // be nice to actually schedule the retry after the given delay, but for now we just
72    // check that it contains a positive value.
73    let retry_header = res.headers().get("x-amz-retry-after").is_some();
74
75    // Certain 400-level responses will contain an error code indicating that the request
76    // should be retried. Since we don't retry 400-level responses by default, we'll look
77    // for these specifically before falling back to more general heuristics. Because AWS
78    // services use a mix of XML and JSON response bodies and the AWS SDK doesn't give us
79    // a parsed representation, we resort to a simple string match.
80    //
81    // S3: RequestTimeout
82    // SQS: RequestExpired, ThrottlingException
83    // ECS: RequestExpired, ThrottlingException
84    // Kinesis: RequestExpired, ThrottlingException
85    // Cloudwatch: RequestExpired, ThrottlingException
86    //
87    // Now just look for those when it's a client_error
88    let re = RETRIABLE_CODES.get_or_init(|| {
89        RegexSet::new(["RequestTimeout", "RequestExpired", "ThrottlingException"])
90            .expect("invalid regex")
91    });
92
93    let status = res.status();
94    let response_body = String::from_utf8_lossy(res.body().bytes().unwrap_or(&[]));
95
96    retry_header
97        || status.is_server_error()
98        || status.as_u16() == status::TOO_MANY_REQUESTS
99        || (status.is_client_error() && re.is_match(response_body.as_ref()))
100}
101
102/// Creates the http connector that has been configured to use the given proxy and TLS settings.
103/// All AWS requests should use this connector as the aws crates by default use RustTLS which we
104/// have turned off as we want to consistently use openssl.
105fn connector(
106    proxy: &ProxyConfig,
107    tls_options: Option<&TlsConfig>,
108) -> crate::Result<SharedHttpClient> {
109    let tls_settings = MaybeTlsSettings::tls_client(tls_options)?;
110
111    if proxy.enabled {
112        let proxy = build_proxy_connector(tls_settings, proxy)?;
113        Ok(HyperClientBuilder::new().build(proxy))
114    } else {
115        let tls_connector = build_tls_connector(tls_settings)?;
116        Ok(HyperClientBuilder::new().build(tls_connector))
117    }
118}
119
120/// Implement for each AWS service to create the appropriate AWS sdk client.
121pub trait ClientBuilder {
122    /// The type of the client in the SDK.
123    type Client;
124
125    /// Build the client using the given config settings.
126    fn build(&self, config: &SdkConfig) -> Self::Client;
127}
128
129/// Provides the configured AWS region.
130pub fn region_provider(
131    proxy: &ProxyConfig,
132    tls_options: Option<&TlsConfig>,
133) -> crate::Result<impl ProvideRegion + use<>> {
134    let config = aws_config::provider_config::ProviderConfig::default()
135        .with_http_client(connector(proxy, tls_options)?);
136
137    Ok(aws_config::meta::region::RegionProviderChain::first_try(
138        aws_config::environment::EnvironmentVariableRegionProvider::new(),
139    )
140    .or_else(aws_config::profile::ProfileFileRegionProvider::builder().build())
141    .or_else(
142        aws_config::imds::region::ImdsRegionProvider::builder()
143            .configure(&config)
144            .build(),
145    ))
146}
147
148async fn resolve_region(
149    proxy: &ProxyConfig,
150    tls_options: Option<&TlsConfig>,
151    region: Option<Region>,
152) -> crate::Result<Region> {
153    match region {
154        Some(region) => Ok(region),
155        None => region_provider(proxy, tls_options)?
156            .region()
157            .await
158            .ok_or_else(|| {
159                "Could not determine region from Vector configuration or default providers".into()
160            }),
161    }
162}
163
164/// Create the SDK client using the provided settings.
165pub async fn create_client<T>(
166    builder: &T,
167    auth: &AwsAuthentication,
168    region: Option<Region>,
169    endpoint: Option<String>,
170    proxy: &ProxyConfig,
171    tls_options: Option<&TlsConfig>,
172    timeout: Option<&AwsTimeout>,
173) -> crate::Result<T::Client>
174where
175    T: ClientBuilder,
176{
177    create_client_and_region::<T>(builder, auth, region, endpoint, proxy, tls_options, timeout)
178        .await
179        .map(|(client, _)| client)
180}
181
182/// Create the SDK client and resolve the region using the provided settings.
183pub async fn create_client_and_region<T>(
184    builder: &T,
185    auth: &AwsAuthentication,
186    region: Option<Region>,
187    endpoint: Option<String>,
188    proxy: &ProxyConfig,
189    tls_options: Option<&TlsConfig>,
190    timeout: Option<&AwsTimeout>,
191) -> crate::Result<(T::Client, Region)>
192where
193    T: ClientBuilder,
194{
195    let retry_config = RetryConfig::disabled();
196
197    // The default credentials chains will look for a region if not given but we'd like to
198    // error up front if later SDK calls will fail due to lack of region configuration
199    let region = resolve_region(proxy, tls_options, region).await?;
200
201    let provider_config =
202        aws_config::provider_config::ProviderConfig::empty().with_region(Some(region.clone()));
203
204    let connector = connector(proxy, tls_options)?;
205
206    // Create a custom http connector that will emit the required metrics for us.
207    let connector = AwsHttpClient {
208        http: connector,
209        region: region.clone(),
210    };
211
212    // Build the configuration first.
213    let mut config_builder = SdkConfig::builder()
214        .http_client(connector)
215        .sleep_impl(Arc::new(TokioSleep::new()))
216        .identity_cache(auth.credentials_cache().await?)
217        .credentials_provider(
218            auth.credentials_provider(region.clone(), proxy, tls_options)
219                .await?,
220        )
221        .region(region.clone())
222        .retry_config(retry_config.clone());
223
224    if let Some(endpoint_override) = endpoint {
225        config_builder = config_builder.endpoint_url(endpoint_override);
226    } else if let Some(endpoint_from_config) =
227        aws_config::default_provider::endpoint_url::endpoint_url_provider(&provider_config).await
228    {
229        config_builder = config_builder.endpoint_url(endpoint_from_config);
230    }
231
232    if let Some(use_fips) =
233        aws_config::default_provider::use_fips::use_fips_provider(&provider_config).await
234    {
235        config_builder = config_builder.use_fips(use_fips);
236    }
237
238    if let Some(timeout) = timeout {
239        let mut timeout_config_builder = TimeoutConfig::builder();
240
241        let operation_timeout = timeout.operation_timeout();
242        let connect_timeout = timeout.connect_timeout();
243        let read_timeout = timeout.read_timeout();
244
245        timeout_config_builder
246            .set_operation_timeout(operation_timeout.map(Duration::from_secs))
247            .set_connect_timeout(connect_timeout.map(Duration::from_secs))
248            .set_read_timeout(read_timeout.map(Duration::from_secs));
249
250        config_builder = config_builder.timeout_config(timeout_config_builder.build());
251    }
252
253    let config = config_builder.build();
254
255    Ok((T::build(builder, &config), region))
256}
257
258#[derive(Snafu, Debug)]
259enum SigningError {
260    #[snafu(display("cannot sign the request because the headers are not valid utf-8"))]
261    NotUTF8Header,
262}
263
264/// Sign the request prior to sending to AWS.
265/// The signature is added to the provided `request`.
266pub async fn sign_request(
267    service_name: &str,
268    request: &mut http::Request<Bytes>,
269    credentials_provider: &SharedCredentialsProvider,
270    region: Option<&Region>,
271    payload_checksum_sha256: bool,
272) -> crate::Result<()> {
273    let headers = request
274        .headers()
275        .iter()
276        .map(|(k, v)| {
277            Ok((
278                k.as_str(),
279                std::str::from_utf8(v.as_bytes()).map_err(|_| SigningError::NotUTF8Header)?,
280            ))
281        })
282        .collect::<Result<Vec<_>, SigningError>>()?;
283
284    let signable_request = SignableRequest::new(
285        request.method().as_str(),
286        request.uri().to_string(),
287        headers.into_iter(),
288        SignableBody::Bytes(request.body().as_ref()),
289    )?;
290
291    let credentials = credentials_provider.provide_credentials().await?;
292    let identity = Identity::new(credentials, None);
293
294    let mut signing_settings = SigningSettings::default();
295
296    // Include the x-amz-content-sha256 header when calculating the AWS v4 signature;
297    // this is required by some AWS services, e.g. S3 and OpenSearch Serverless
298    if payload_checksum_sha256 {
299        signing_settings.payload_checksum_kind = PayloadChecksumKind::XAmzSha256;
300    }
301
302    let signing_params_builder = v4::SigningParams::builder()
303        .identity(&identity)
304        .region(region.as_ref().map(|r| r.as_ref()).unwrap_or(""))
305        .name(service_name)
306        .time(SystemTime::now())
307        .settings(signing_settings);
308
309    let signing_params = signing_params_builder
310        .build()
311        .expect("all signing params set");
312
313    let (signing_instructions, _signature) =
314        aws_sigv4::http_request::sign(signable_request, &signing_params.into())?.into_parts();
315    signing_instructions.apply_to_request_http0x(request);
316
317    Ok(())
318}
319
320#[derive(Debug)]
321struct AwsHttpClient<T> {
322    http: T,
323    region: Region,
324}
325
326impl<T> HttpClient for AwsHttpClient<T>
327where
328    T: HttpClient,
329{
330    fn http_connector(
331        &self,
332        settings: &HttpConnectorSettings,
333        components: &RuntimeComponents,
334    ) -> SharedHttpConnector {
335        let http_connector = self.http.http_connector(settings, components);
336
337        SharedHttpConnector::new(AwsConnector {
338            region: self.region.clone(),
339            http: http_connector,
340        })
341    }
342}
343
344#[derive(Clone, Debug)]
345struct AwsConnector<T> {
346    http: T,
347    region: Region,
348}
349
350impl<T> HttpConnector for AwsConnector<T>
351where
352    T: HttpConnector,
353{
354    fn call(&self, req: HttpRequest) -> HttpConnectorFuture {
355        let bytes_sent = Arc::new(std::sync::atomic::AtomicUsize::new(0));
356        let req = req.map(|body| {
357            let bytes_sent = Arc::clone(&bytes_sent);
358            body.map_preserve_contents(move |body| {
359                let body = MeasuredBody::new(body, Arc::clone(&bytes_sent));
360                SdkBody::from_body_0_4(BoxBody::new(body))
361            })
362        });
363
364        let fut = self.http.call(req);
365        let region = self.region.clone();
366
367        HttpConnectorFuture::new(fut.inspect(move |result| {
368            let byte_size = bytes_sent.load(Ordering::Relaxed);
369            if let Ok(result) = result {
370                if result.status().is_success() {
371                    emit!(AwsBytesSent {
372                        byte_size,
373                        region: Some(region),
374                    });
375                }
376            }
377        }))
378    }
379}
380
381#[pin_project]
382struct MeasuredBody {
383    #[pin]
384    inner: SdkBody,
385    shared_bytes_sent: Arc<AtomicUsize>,
386}
387
388impl MeasuredBody {
389    const fn new(body: SdkBody, shared_bytes_sent: Arc<AtomicUsize>) -> Self {
390        Self {
391            inner: body,
392            shared_bytes_sent,
393        }
394    }
395}
396
397impl Body for MeasuredBody {
398    type Data = Bytes;
399    type Error = Box<dyn Error + Send + Sync>;
400
401    fn poll_data(
402        self: Pin<&mut Self>,
403        cx: &mut Context<'_>,
404    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
405        let this = self.project();
406
407        match this.inner.poll_data(cx) {
408            Poll::Ready(Some(Ok(data))) => {
409                this.shared_bytes_sent
410                    .fetch_add(data.len(), Ordering::Release);
411                Poll::Ready(Some(Ok(data)))
412            }
413            Poll::Ready(None) => Poll::Ready(None),
414            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
415            Poll::Pending => Poll::Pending,
416        }
417    }
418
419    fn poll_trailers(
420        self: Pin<&mut Self>,
421        _cx: &mut Context<'_>,
422    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
423        Poll::Ready(Ok(None))
424    }
425}