vector/sinks/azure_common/shared_key_policy.rs

1use std::{collections::BTreeMap, fmt::Write as _, sync::Arc};
2
3use async_trait::async_trait;
4use azure_core::http::policies::{Policy, PolicyResult};
5use azure_core::http::{Context, Request, Url};
6use azure_core::{
7    Result as AzureResult, base64,
8    error::Error as AzureError,
9    time::{OffsetDateTime, to_rfc7231},
10};
11
12use openssl::{hash::MessageDigest, pkey::PKey, sign::Signer};
13
/// Shared Key authorization policy for Azure Blob Storage requests.
///
/// On every request this policy sets `x-ms-date` to the current UTC time and
/// `x-ms-version` to the configured storage version (overwriting any existing values),
/// then adds the `Authorization: SharedKey {account}:{signature}` header. The signature
/// is a base64-encoded HMAC-SHA256, keyed with the decoded account key, over the
/// string-to-sign defined by the "Authorize with Shared Key" rules for the Blob service:
///
/// ```text
/// StringToSign =
///   VERB + "\n" +
///   Content-Encoding + "\n" +
///   Content-Language + "\n" +
///   Content-Length + "\n" +
///   Content-MD5 + "\n" +
///   Content-Type + "\n" +
///   Date + "\n" +
///   If-Modified-Since + "\n" +
///   If-Match + "\n" +
///   If-None-Match + "\n" +
///   If-Unmodified-Since + "\n" +
///   Range + "\n" +
///   CanonicalizedHeaders +
///   CanonicalizedResource
/// ```
///
/// Notes:
/// - `x-ms-date` is always sent, so the standard `Date` field is signed as empty.
/// - Per the spec (x-ms-version 2015-02-21 and later), a `Content-Length` of `0` must be
///   signed as the empty string.
/// - Canonicalized headers include all `x-ms-*` headers (lowercased, sorted by name),
///   one `name:value\n` line each.
/// - Canonicalized resource is `/{account}{path}` followed by `\n{name}:{values}` lines
///   for the query parameters (lowercased names, sorted; multiple values sorted and
///   comma-joined). There is no newline directly after the path.
///
/// The `Debug` implementation redacts the account key so it cannot leak into logs.
pub struct SharedKeyAuthorizationPolicy {
    /// Storage account name, used in the `Authorization` header and canonicalized resource.
    account_name: String,
    /// Account key decoded from base64. Secret: never printed by `Debug`.
    account_key: Vec<u8>,
    /// Value sent as the `x-ms-version` header.
    storage_version: String,
}

impl std::fmt::Debug for SharedKeyAuthorizationPolicy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Hand-written instead of derived so the HMAC secret never appears in
        // debug/log output.
        f.debug_struct("SharedKeyAuthorizationPolicy")
            .field("account_name", &self.account_name)
            .field("account_key", &"<redacted>")
            .field("storage_version", &self.storage_version)
            .finish()
    }
}
48
49impl SharedKeyAuthorizationPolicy {
50    /// Create a new shared key policy.
51    ///
52    /// - `account_name`: The storage account name.
53    /// - `account_key_b64`: Base64-encoded storage account key.
54    /// - `storage_version`: x-ms-version value to send (e.g. "2025-11-05").
55    pub fn new(
56        account_name: String,
57        account_key_b64: String,
58        storage_version: String,
59    ) -> AzureResult<Self> {
60        let account_key = base64::decode(account_key_b64.as_bytes()).map_err(|e| {
61            AzureError::with_message(
62                azure_core::error::ErrorKind::Other,
63                format!("invalid account key base64: {e}"),
64            )
65        })?;
66        Ok(Self {
67            account_name,
68            account_key,
69            storage_version,
70        })
71    }
72
73    fn ensure_ms_headers(&self, request: &mut Request) -> AzureResult<(String, String)> {
74        // Always set x-ms-date and x-ms-version explicitly to known values for signing.
75        let now = OffsetDateTime::now_utc();
76        let ms_date = to_rfc7231(&now);
77        request.insert_header("x-ms-date", ms_date.clone());
78        let ms_version = self.storage_version.clone();
79        request.insert_header("x-ms-version", ms_version.clone());
80        Ok((ms_date, ms_version))
81    }
82
83    fn build_string_to_sign(
84        &self,
85        req: &Request,
86        ms_date: &str,
87        ms_version: &str,
88    ) -> AzureResult<String> {
89        let method = req.method().as_str();
90        let url = req.url();
91
92        let mut s = String::with_capacity(512);
93
94        // VERB
95        s.push_str(method);
96        s.push('\n');
97
98        // Resolve standard headers (case-insensitive) and write them in order required by the spec.
99        // https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#shared-key-format-for-2009-09-19-and-later
100        let header = |name: &str| -> Option<&str> {
101            for (n, v) in req.headers().iter() {
102                if n.as_str().eq_ignore_ascii_case(name) {
103                    return Some(v.as_str());
104                }
105            }
106            None
107        };
108
109        // Content-Encoding
110        if let Some(v) = header("Content-Encoding") {
111            s.push_str(v);
112        }
113        s.push('\n');
114
115        // Content-Language
116        if let Some(v) = header("Content-Language") {
117            s.push_str(v);
118        }
119        s.push('\n');
120
121        // Content-Length (include value if present; keep "0")
122        if let Some(v) = header("Content-Length") {
123            s.push_str(v);
124        }
125        s.push('\n');
126
127        // Content-MD5
128        if let Some(v) = header("Content-MD5") {
129            s.push_str(v);
130        }
131        s.push('\n');
132
133        // Content-Type
134        if let Some(v) = header("Content-Type") {
135            s.push_str(v);
136        }
137        s.push('\n');
138
139        // Date (unused when x-ms-date is used)
140        s.push('\n');
141
142        // If-Modified-Since
143        if let Some(v) = header("If-Modified-Since") {
144            s.push_str(v);
145        }
146        s.push('\n');
147
148        // If-Match
149        if let Some(v) = header("If-Match") {
150            s.push_str(v);
151        }
152        s.push('\n');
153
154        // If-None-Match
155        if let Some(v) = header("If-None-Match") {
156            s.push_str(v);
157        }
158        s.push('\n');
159
160        // If-Unmodified-Since
161        if let Some(v) = header("If-Unmodified-Since") {
162            s.push_str(v);
163        }
164        s.push('\n');
165
166        // Range
167        if let Some(v) = header("Range") {
168            s.push_str(v);
169        }
170        s.push('\n');
171
172        // CanonicalizedHeaders: include all x-ms-* headers, lowercased, sorted by name.
173        // If multiple values for the same header exist, sort values and join with commas.
174        let mut xms: BTreeMap<String, Vec<String>> = BTreeMap::new();
175        for (name, value) in req.headers().iter() {
176            let key = name.as_str().to_ascii_lowercase();
177            if key.starts_with("x-ms-") {
178                xms.entry(key)
179                    .or_default()
180                    .push(value.as_str().trim().to_string());
181            }
182        }
183        // Ensure required headers are present (they should have been inserted).
184        xms.entry("x-ms-date".to_string())
185            .or_default()
186            .push(ms_date.to_string());
187        xms.entry("x-ms-version".to_string())
188            .or_default()
189            .push(ms_version.to_string());
190
191        for (k, mut vals) in xms {
192            vals.sort();
193            vals.dedup();
194            let joined = vals.join(",");
195            let _ = writeln!(s, "{}:{}", k, joined);
196        }
197
198        // CanonicalizedResource
199        append_canonicalized_resource(&mut s, &self.account_name, url)?;
200
201        Ok(s)
202    }
203
204    fn sign(&self, string_to_sign: &str) -> AzureResult<String> {
205        let pkey = PKey::hmac(&self.account_key).map_err(|e| {
206            AzureError::with_message(
207                azure_core::error::ErrorKind::Other,
208                format!("failed to create HMAC key: {e}"),
209            )
210        })?;
211        let mut signer = Signer::new(MessageDigest::sha256(), &pkey).map_err(|e| {
212            AzureError::with_message(
213                azure_core::error::ErrorKind::Other,
214                format!("failed to create signer: {e}"),
215            )
216        })?;
217        signer.update(string_to_sign.as_bytes()).map_err(|e| {
218            AzureError::with_message(
219                azure_core::error::ErrorKind::Other,
220                format!("signer update failed: {e}"),
221            )
222        })?;
223        let mac = signer.sign_to_vec().map_err(|e| {
224            AzureError::with_message(
225                azure_core::error::ErrorKind::Other,
226                format!("signer sign failed: {e}"),
227            )
228        })?;
229        Ok(base64::encode(&mac))
230    }
231}
232
233#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
234#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
235impl Policy for SharedKeyAuthorizationPolicy {
236    async fn send(
237        &self,
238        ctx: &Context,
239        request: &mut Request,
240        next: &[Arc<dyn Policy>],
241    ) -> PolicyResult {
242        // Ensure required x-ms headers are present
243        let (ms_date, ms_version) = self.ensure_ms_headers(request)?;
244        // Build string to sign
245        let sts = self.build_string_to_sign(request, &ms_date, &ms_version)?;
246        let signature = self.sign(&sts)?;
247
248        // Authorization: SharedKey {account}:{signature}
249        request.insert_header(
250            "authorization",
251            format!("SharedKey {}:{}", self.account_name, signature),
252        );
253
254        // Continue pipeline
255        next[0].send(ctx, request, &next[1..]).await
256    }
257}
258
259// ---------- Helpers ----------
260
261fn append_canonicalized_resource(s: &mut String, account: &str, url: &Url) -> AzureResult<()> {
262    // "/{account_name}{path}\n"
263    s.push('/');
264    s.push_str(account);
265    // Append the URL path exactly as-is (per spec).
266    s.push_str(url.path());
267
268    // Canonicalized query: lowercase names, sort by name, join multi-values by comma, each line "name:value\n"
269    // https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#shared-key-format-for-2009-09-19-and-later
270    if url.query().is_some() {
271        let mut qp_map: BTreeMap<String, Vec<String>> = BTreeMap::new();
272        for (name, value) in url.query_pairs() {
273            let key_l = name.to_ascii_lowercase();
274            let v = value.to_string();
275            if v.is_empty() {
276                continue;
277            }
278            qp_map.entry(key_l).or_default().push(v);
279        }
280        for (k, mut vals) in qp_map {
281            vals.sort();
282            let mut line = String::new();
283            let _ = write!(&mut line, "\n{}:", k);
284            let joined = vals.join(",");
285            line.push_str(&joined);
286            s.push_str(&line);
287        }
288    }
289
290    Ok(())
291}