vector/aws/
mod.rs

1//! Shared functionality for the AWS components.
2pub mod auth;
3pub mod region;
4pub mod timeout;
5
6use std::{
7    error::Error,
8    pin::Pin,
9    sync::{
10        Arc, OnceLock,
11        atomic::{AtomicUsize, Ordering},
12    },
13    task::{Context, Poll},
14    time::{Duration, SystemTime},
15};
16
17pub use auth::{AwsAuthentication, ImdsAuthentication};
18use aws_config::{
19    Region, SdkConfig, meta::region::ProvideRegion, retry::RetryConfig, timeout::TimeoutConfig,
20};
21use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
22use aws_sigv4::{
23    http_request::{PayloadChecksumKind, SignableBody, SignableRequest, SigningSettings},
24    sign::v4,
25};
26use aws_smithy_async::rt::sleep::TokioSleep;
27use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder;
28use aws_smithy_runtime_api::client::{
29    http::{
30        HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
31    },
32    identity::Identity,
33    orchestrator::{HttpRequest, HttpResponse},
34    result::SdkError,
35    runtime_components::RuntimeComponents,
36};
37use aws_smithy_types::body::SdkBody;
38use aws_types::sdk_config::SharedHttpClient;
39use bytes::Bytes;
40use futures_util::FutureExt;
41use http::HeaderMap;
42use http_body::{Body, combinators::BoxBody};
43use pin_project::pin_project;
44use regex::RegexSet;
45pub use region::RegionOrEndpoint;
46use snafu::Snafu;
47pub use timeout::AwsTimeout;
48
49use crate::{
50    config::ProxyConfig,
51    http::{build_proxy_connector, build_tls_connector, status},
52    internal_events::AwsBytesSent,
53    tls::{MaybeTlsSettings, TlsConfig},
54};
55
/// Patterns for AWS error codes that mark an otherwise non-retriable 4xx response as
/// retriable. Compiled lazily on first use by `check_response`.
static RETRIABLE_CODES: OnceLock<RegexSet> = OnceLock::new();
57
58/// Checks if the request can be retried after the given error was returned.
59pub fn is_retriable_error<T>(error: &SdkError<T, HttpResponse>) -> bool {
60    match error {
61        SdkError::TimeoutError(_) | SdkError::DispatchFailure(_) => true,
62        SdkError::ConstructionFailure(_) => false,
63        SdkError::ResponseError(err) => check_response(err.raw()),
64        SdkError::ServiceError(err) => check_response(err.raw()),
65        _ => {
66            warn!("AWS returned unknown error, retrying request.");
67            true
68        }
69    }
70}
71
72fn check_response(res: &HttpResponse) -> bool {
73    // This header is a direct indication that we should retry the request. Eventually it'd
74    // be nice to actually schedule the retry after the given delay, but for now we just
75    // check that it contains a positive value.
76    let retry_header = res.headers().get("x-amz-retry-after").is_some();
77
78    // Certain 400-level responses will contain an error code indicating that the request
79    // should be retried. Since we don't retry 400-level responses by default, we'll look
80    // for these specifically before falling back to more general heuristics. Because AWS
81    // services use a mix of XML and JSON response bodies and the AWS SDK doesn't give us
82    // a parsed representation, we resort to a simple string match.
83    //
84    // S3: RequestTimeout
85    // SQS: RequestExpired, ThrottlingException
86    // ECS: RequestExpired, ThrottlingException
87    // Kinesis: RequestExpired, ThrottlingException
88    // Cloudwatch: RequestExpired, ThrottlingException
89    //
90    // Now just look for those when it's a client_error
91    let re = RETRIABLE_CODES.get_or_init(|| {
92        RegexSet::new(["RequestTimeout", "RequestExpired", "ThrottlingException"])
93            .expect("invalid regex")
94    });
95
96    let status = res.status();
97    let response_body = String::from_utf8_lossy(res.body().bytes().unwrap_or(&[]));
98
99    retry_header
100        || status.is_server_error()
101        || status.as_u16() == status::TOO_MANY_REQUESTS
102        || (status.is_client_error() && re.is_match(response_body.as_ref()))
103}
104
105/// Creates the http connector that has been configured to use the given proxy and TLS settings.
106/// All AWS requests should use this connector as the aws crates by default use RustTLS which we
107/// have turned off as we want to consistently use openssl.
108fn connector(
109    proxy: &ProxyConfig,
110    tls_options: Option<&TlsConfig>,
111) -> crate::Result<SharedHttpClient> {
112    let tls_settings = MaybeTlsSettings::tls_client(tls_options)?;
113
114    if proxy.enabled {
115        let proxy = build_proxy_connector(tls_settings, proxy)?;
116        Ok(HyperClientBuilder::new().build(proxy))
117    } else {
118        let tls_connector = build_tls_connector(tls_settings)?;
119        Ok(HyperClientBuilder::new().build(tls_connector))
120    }
121}
122
/// Implement for each AWS service to create the appropriate AWS sdk client.
///
/// [`create_client`] and [`create_client_and_region`] assemble a shared [`SdkConfig`]
/// (HTTP client, credentials, region, endpoint, timeouts) and hand it to
/// [`ClientBuilder::build`].
pub trait ClientBuilder {
    /// The type of the client in the SDK.
    type Client;

    /// Build the client using the given config settings.
    fn build(&self, config: &SdkConfig) -> Self::Client;
}
131
132/// Provides the configured AWS region.
133pub fn region_provider(
134    proxy: &ProxyConfig,
135    tls_options: Option<&TlsConfig>,
136) -> crate::Result<impl ProvideRegion + use<>> {
137    let config = aws_config::provider_config::ProviderConfig::default()
138        .with_http_client(connector(proxy, tls_options)?);
139
140    Ok(aws_config::meta::region::RegionProviderChain::first_try(
141        aws_config::environment::EnvironmentVariableRegionProvider::new(),
142    )
143    .or_else(aws_config::profile::ProfileFileRegionProvider::builder().build())
144    .or_else(
145        aws_config::imds::region::ImdsRegionProvider::builder()
146            .configure(&config)
147            .build(),
148    ))
149}
150
151async fn resolve_region(
152    proxy: &ProxyConfig,
153    tls_options: Option<&TlsConfig>,
154    region: Option<Region>,
155) -> crate::Result<Region> {
156    match region {
157        Some(region) => Ok(region),
158        None => region_provider(proxy, tls_options)?
159            .region()
160            .await
161            .ok_or_else(|| {
162                "Could not determine region from Vector configuration or default providers".into()
163            }),
164    }
165}
166
167/// Create the SDK client using the provided settings.
168pub async fn create_client<T>(
169    builder: &T,
170    auth: &AwsAuthentication,
171    region: Option<Region>,
172    endpoint: Option<String>,
173    proxy: &ProxyConfig,
174    tls_options: Option<&TlsConfig>,
175    timeout: Option<&AwsTimeout>,
176) -> crate::Result<T::Client>
177where
178    T: ClientBuilder,
179{
180    create_client_and_region::<T>(builder, auth, region, endpoint, proxy, tls_options, timeout)
181        .await
182        .map(|(client, _)| client)
183}
184
185/// Create the SDK client and resolve the region using the provided settings.
186pub async fn create_client_and_region<T>(
187    builder: &T,
188    auth: &AwsAuthentication,
189    region: Option<Region>,
190    endpoint: Option<String>,
191    proxy: &ProxyConfig,
192    tls_options: Option<&TlsConfig>,
193    timeout: Option<&AwsTimeout>,
194) -> crate::Result<(T::Client, Region)>
195where
196    T: ClientBuilder,
197{
198    let retry_config = RetryConfig::disabled();
199
200    // The default credentials chains will look for a region if not given but we'd like to
201    // error up front if later SDK calls will fail due to lack of region configuration
202    let region = resolve_region(proxy, tls_options, region).await?;
203
204    let provider_config =
205        aws_config::provider_config::ProviderConfig::empty().with_region(Some(region.clone()));
206
207    let connector = connector(proxy, tls_options)?;
208
209    // Create a custom http connector that will emit the required metrics for us.
210    let connector = AwsHttpClient {
211        http: connector,
212        region: region.clone(),
213    };
214
215    // Build the configuration first.
216    let mut config_builder = SdkConfig::builder()
217        .http_client(connector)
218        .sleep_impl(Arc::new(TokioSleep::new()))
219        .identity_cache(auth.credentials_cache().await?)
220        .credentials_provider(
221            auth.credentials_provider(region.clone(), proxy, tls_options)
222                .await?,
223        )
224        .region(region.clone())
225        .retry_config(retry_config.clone());
226
227    if let Some(endpoint_override) = endpoint {
228        config_builder = config_builder.endpoint_url(endpoint_override);
229    } else if let Some(endpoint_from_config) =
230        aws_config::default_provider::endpoint_url::endpoint_url_provider(&provider_config).await
231    {
232        config_builder = config_builder.endpoint_url(endpoint_from_config);
233    }
234
235    if let Some(use_fips) =
236        aws_config::default_provider::use_fips::use_fips_provider(&provider_config).await
237    {
238        config_builder = config_builder.use_fips(use_fips);
239    }
240
241    if let Some(timeout) = timeout {
242        let mut timeout_config_builder = TimeoutConfig::builder();
243
244        let operation_timeout = timeout.operation_timeout();
245        let connect_timeout = timeout.connect_timeout();
246        let read_timeout = timeout.read_timeout();
247
248        timeout_config_builder
249            .set_operation_timeout(operation_timeout.map(Duration::from_secs))
250            .set_connect_timeout(connect_timeout.map(Duration::from_secs))
251            .set_read_timeout(read_timeout.map(Duration::from_secs));
252
253        config_builder = config_builder.timeout_config(timeout_config_builder.build());
254    }
255
256    let config = config_builder.build();
257
258    Ok((T::build(builder, &config), region))
259}
260
/// Errors raised while preparing a request for SigV4 signing in `sign_request`.
#[derive(Snafu, Debug)]
enum SigningError {
    /// A request header value is not valid UTF-8, so it cannot be handed to the signer.
    #[snafu(display("cannot sign the request because the headers are not valid utf-8"))]
    NotUTF8Header,
}
266
267/// Sign the request prior to sending to AWS.
268/// The signature is added to the provided `request`.
269pub async fn sign_request(
270    service_name: &str,
271    request: &mut http::Request<Bytes>,
272    credentials_provider: &SharedCredentialsProvider,
273    region: Option<&Region>,
274    payload_checksum_sha256: bool,
275) -> crate::Result<()> {
276    let headers = request
277        .headers()
278        .iter()
279        .map(|(k, v)| {
280            Ok((
281                k.as_str(),
282                std::str::from_utf8(v.as_bytes()).map_err(|_| SigningError::NotUTF8Header)?,
283            ))
284        })
285        .collect::<Result<Vec<_>, SigningError>>()?;
286
287    let signable_request = SignableRequest::new(
288        request.method().as_str(),
289        request.uri().to_string(),
290        headers.into_iter(),
291        SignableBody::Bytes(request.body().as_ref()),
292    )?;
293
294    let credentials = credentials_provider.provide_credentials().await?;
295    let identity = Identity::new(credentials, None);
296
297    let mut signing_settings = SigningSettings::default();
298
299    // Include the x-amz-content-sha256 header when calculating the AWS v4 signature;
300    // this is required by some AWS services, e.g. S3 and OpenSearch Serverless
301    if payload_checksum_sha256 {
302        signing_settings.payload_checksum_kind = PayloadChecksumKind::XAmzSha256;
303    }
304
305    let signing_params_builder = v4::SigningParams::builder()
306        .identity(&identity)
307        .region(region.as_ref().map(|r| r.as_ref()).unwrap_or(""))
308        .name(service_name)
309        .time(SystemTime::now())
310        .settings(signing_settings);
311
312    let signing_params = signing_params_builder
313        .build()
314        .expect("all signing params set");
315
316    let (signing_instructions, _signature) =
317        aws_sigv4::http_request::sign(signable_request, &signing_params.into())?.into_parts();
318    signing_instructions.apply_to_request_http0x(request);
319
320    Ok(())
321}
322
323#[derive(Debug)]
324struct AwsHttpClient<T> {
325    http: T,
326    region: Region,
327}
328
329impl<T> HttpClient for AwsHttpClient<T>
330where
331    T: HttpClient,
332{
333    fn http_connector(
334        &self,
335        settings: &HttpConnectorSettings,
336        components: &RuntimeComponents,
337    ) -> SharedHttpConnector {
338        let http_connector = self.http.http_connector(settings, components);
339
340        SharedHttpConnector::new(AwsConnector {
341            region: self.region.clone(),
342            http: http_connector,
343        })
344    }
345}
346
347#[derive(Clone, Debug)]
348struct AwsConnector<T> {
349    http: T,
350    region: Region,
351}
352
353impl<T> HttpConnector for AwsConnector<T>
354where
355    T: HttpConnector,
356{
357    fn call(&self, req: HttpRequest) -> HttpConnectorFuture {
358        let bytes_sent = Arc::new(std::sync::atomic::AtomicUsize::new(0));
359        let req = req.map(|body| {
360            let bytes_sent = Arc::clone(&bytes_sent);
361            body.map_preserve_contents(move |body| {
362                let body = MeasuredBody::new(body, Arc::clone(&bytes_sent));
363                SdkBody::from_body_0_4(BoxBody::new(body))
364            })
365        });
366
367        let fut = self.http.call(req);
368        let region = self.region.clone();
369
370        HttpConnectorFuture::new(fut.inspect(move |result| {
371            let byte_size = bytes_sent.load(Ordering::Relaxed);
372            if let Ok(result) = result
373                && result.status().is_success()
374            {
375                emit!(AwsBytesSent {
376                    byte_size,
377                    region: Some(region),
378                });
379            }
380        }))
381    }
382}
383
384#[pin_project]
385struct MeasuredBody {
386    #[pin]
387    inner: SdkBody,
388    shared_bytes_sent: Arc<AtomicUsize>,
389}
390
391impl MeasuredBody {
392    const fn new(body: SdkBody, shared_bytes_sent: Arc<AtomicUsize>) -> Self {
393        Self {
394            inner: body,
395            shared_bytes_sent,
396        }
397    }
398}
399
400impl Body for MeasuredBody {
401    type Data = Bytes;
402    type Error = Box<dyn Error + Send + Sync>;
403
404    fn poll_data(
405        self: Pin<&mut Self>,
406        cx: &mut Context<'_>,
407    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
408        let this = self.project();
409
410        match this.inner.poll_data(cx) {
411            Poll::Ready(Some(Ok(data))) => {
412                this.shared_bytes_sent
413                    .fetch_add(data.len(), Ordering::Release);
414                Poll::Ready(Some(Ok(data)))
415            }
416            Poll::Ready(None) => Poll::Ready(None),
417            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
418            Poll::Pending => Poll::Pending,
419        }
420    }
421
422    fn poll_trailers(
423        self: Pin<&mut Self>,
424        _cx: &mut Context<'_>,
425    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
426        Poll::Ready(Ok(None))
427    }
428}