1use std::collections::{HashMap, HashSet};
2
3use bytes::BytesMut;
4use futures_util::StreamExt;
5use serde::{Deserialize, Serialize};
6use tokio::{io::AsyncWriteExt, process::Command, time};
7use tokio_util::codec;
8use vector_lib::configurable::{component::GenerateConfig, configurable_component};
9use vrl::value::Value;
10
11use crate::{config::SecretBackend, signal};
12
/// Protocol version spoken with the `exec` secrets command.
///
/// Serialized as an internally tagged enum: the variant name, in snake_case
/// (`v1` / `v1_1`), is stored under the `version` key.
#[configurable_component(secrets("exec"))]
#[configurable(metadata(docs::enum_tag_description = "The protocol version."))]
#[derive(Clone, Debug)]
#[serde(rename_all = "snake_case", tag = "version")]
pub enum ExecVersion {
    /// Protocol 1.0: the query sent to the command carries only the requested secret names.
    V1,

    /// Protocol 1.1: the query additionally carries a backend type and its configuration.
    V1_1 {
        /// Backend name, sent to the command as the `type` field of the query.
        backend_type: String,

        /// Backend configuration, sent to the command as the `config` field of the query.
        backend_config: Value,
    },
}
32
33impl ExecVersion {
34 fn new_query(&self, secrets: HashSet<String>) -> ExecQuery {
35 match &self {
36 ExecVersion::V1 => ExecQuery {
37 version: "1.0".to_string(),
38 secrets,
39 r#type: None,
40 config: None,
41 },
42 ExecVersion::V1_1 {
43 backend_type,
44 backend_config,
45 ..
46 } => ExecQuery {
47 version: "1.1".to_string(),
48 secrets,
49 r#type: Some(backend_type.clone()),
50 config: Some(backend_config.clone()),
51 },
52 }
53 }
54}
55
56impl GenerateConfig for ExecVersion {
57 fn generate_config() -> toml::Value {
58 toml::Value::try_from(ExecVersion::V1).unwrap()
59 }
60}
61
/// Configuration for the `exec` secrets backend, which retrieves secrets by running an
/// external command.
#[configurable_component(secrets("exec"))]
#[derive(Clone, Debug)]
pub struct ExecBackend {
    /// Command to run: the first element is the program, the remaining elements are
    /// passed to it as arguments.
    pub command: Vec<String>,

    /// Maximum number of seconds to wait for the command to produce its output
    /// (that is, to close its stdout) before giving up.
    #[serde(default = "default_timeout_secs")]
    pub timeout: u64,

    /// Protocol version used to build the query sent to the command.
    #[serde(default = "default_protocol_version")]
    pub protocol: ExecVersion,
}
79
80impl GenerateConfig for ExecBackend {
81 fn generate_config() -> toml::Value {
82 toml::Value::try_from(ExecBackend {
83 command: vec![String::from("/path/to/script")],
84 timeout: 5,
85 protocol: ExecVersion::V1,
86 })
87 .unwrap()
88 }
89}
90
/// Serde default for `ExecBackend::timeout`, in seconds.
const fn default_timeout_secs() -> u64 {
    5
}
94
/// Serde default for `ExecBackend::protocol`: the version 1.0 protocol.
const fn default_protocol_version() -> ExecVersion {
    ExecVersion::V1
}
98
/// Query serialized as JSON and written to the exec command's stdin.
#[derive(Clone, Debug, Serialize)]
struct ExecQuery {
    /// Protocol version string (`ExecVersion::new_query` sets `"1.0"` or `"1.1"`).
    version: String,
    /// Names of the secrets being requested.
    secrets: HashSet<String>,
    /// Backend type (serialized as `type`); omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    r#type: Option<String>,
    /// Backend configuration; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    config: Option<Value>,
}
110
/// One entry of the JSON object the exec command prints on stdout, keyed by secret name.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct ExecResponse {
    /// The secret's value; a missing or empty value is rejected by `retrieve`.
    value: Option<String>,
    /// Error reported by the command for this secret, if any.
    error: Option<String>,
}
116
117impl SecretBackend for ExecBackend {
118 async fn retrieve(
119 &mut self,
120 secret_keys: HashSet<String>,
121 signal_rx: &mut signal::SignalRx,
122 ) -> crate::Result<HashMap<String, String>> {
123 let mut output = query_backend(
124 &self.command,
125 self.protocol.new_query(secret_keys.clone()),
126 self.timeout,
127 signal_rx,
128 )
129 .await?;
130 let mut secrets = HashMap::new();
131 for k in secret_keys.into_iter() {
132 if let Some(secret) = output.get_mut(&k) {
133 if let Some(e) = &secret.error {
134 return Err(format!("secret for key '{k}' was not retrieved: {e}").into());
135 }
136 if let Some(v) = secret.value.take() {
137 if v.is_empty() {
138 return Err(format!("secret for key '{k}' was empty").into());
139 }
140 secrets.insert(k.to_string(), v);
141 } else {
142 return Err(format!("secret for key '{k}' was empty").into());
143 }
144 } else {
145 return Err(format!("secret for key '{k}' was not retrieved").into());
146 }
147 }
148 Ok(secrets)
149 }
150}
151
152async fn query_backend(
153 cmd: &[String],
154 query: ExecQuery,
155 timeout: u64,
156 signal_rx: &mut signal::SignalRx,
157) -> crate::Result<HashMap<String, ExecResponse>> {
158 let command = &cmd[0];
159 let mut command = Command::new(command);
160
161 if cmd.len() > 1 {
162 command.args(&cmd[1..]);
163 };
164
165 command.kill_on_drop(true);
166 command.stderr(std::process::Stdio::piped());
167 command.stdin(std::process::Stdio::piped());
168 command.stdout(std::process::Stdio::piped());
169
170 let mut child = command.spawn()?;
171 let mut stdin = child.stdin.take().ok_or("unable to acquire stdin")?;
172 let mut stderr_stream = child
173 .stderr
174 .map(|s| codec::FramedRead::new(s, codec::LinesCodec::new()))
175 .ok_or("unable to acquire stderr")?;
176 let mut stdout_stream = child
177 .stdout
178 .map(|s| codec::FramedRead::new(s, codec::BytesCodec::new()))
179 .ok_or("unable to acquire stdout")?;
180
181 let query = serde_json::to_vec(&query)?;
182 tokio::spawn(async move { stdin.write_all(&query).await });
183
184 let timeout = time::sleep(time::Duration::from_secs(timeout));
185 tokio::pin!(timeout);
186 let mut output = BytesMut::new();
187 loop {
188 tokio::select! {
189 biased;
190 Ok(signal::SignalTo::Shutdown(_) | signal::SignalTo::Quit) = signal_rx.recv() => {
191 drop(command);
192 return Err("Secret retrieval was interrupted.".into());
193 }
194 Some(stderr) = stderr_stream.next() => {
195 match stderr {
196 Ok(l) => warn!("An exec backend generated message on stderr: {}.", l),
197 Err(e) => warn!("Error while reading from an exec backend stderr: {}.", e),
198 }
199 }
200 stdout = stdout_stream.next() => {
201 match stdout {
202 None => break,
203 Some(Ok(b)) => output.extend(b),
204 Some(Err(e)) => return Err(format!("Error while reading from an exec backend stdout: {e}.").into()),
205 }
206 }
207 _ = &mut timeout => {
208 drop(command);
209 return Err("Command timed-out".into());
210 }
211 }
212 }
213
214 let response = serde_json::from_slice::<HashMap<String, ExecResponse>>(&output)?;
215 Ok(response)
216}
217
#[cfg(test)]
mod tests {
    use std::{
        collections::{HashMap, HashSet},
        path::PathBuf,
    };

    use rstest::rstest;
    use tokio::sync::broadcast;
    use vrl::value;

    use crate::{
        config::SecretBackend,
        secrets::exec::{ExecBackend, ExecVersion},
    };

    /// Returns a backend that runs the mock secrets script from the behavior tests via
    /// `python`, with a 5 second timeout and the given `protocol`.
    fn make_test_backend(protocol: ExecVersion) -> ExecBackend {
        let script_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/behavior/secrets/mock_secrets_exec.py");
        let script = script_path.to_str().unwrap();
        ExecBackend {
            command: vec![String::from("python"), String::from(script)],
            timeout: 5,
            protocol,
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    #[rstest(
        protocol,
        case(ExecVersion::V1),
        case(ExecVersion::V1_1 {
            backend_type: "file.json".to_string(),
            backend_config: value!({"file_path": "/abc.json"}),
        })
    )]
    async fn test_exec_backend(protocol: ExecVersion) {
        let mut backend = make_test_backend(protocol);
        let (_tx, mut rx) = broadcast::channel(1);
        let expected: HashMap<String, String> = [
            ("fake_secret_1", "123456"),
            ("fake_secret_2", "123457"),
            ("fake_secret_3", "123458"),
            ("fake_secret_4", "123459"),
            ("fake_secret_5", "123460"),
        ]
        .iter()
        .map(|&(key, value)| (key.to_owned(), value.to_owned()))
        .collect();

        let requested: HashSet<String> = expected.keys().cloned().collect();
        let fetched = backend.retrieve(requested, &mut rx).await.unwrap();

        // Same five keys, each mapped to its expected value.
        assert_eq!(fetched, expected);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_exec_backend_missing_secrets() {
        let mut backend = make_test_backend(ExecVersion::V1);
        let (_tx, mut rx) = broadcast::channel(1);
        let query_secrets = HashSet::from([String::from("fake_secret_900")]);

        let error = backend.retrieve(query_secrets, &mut rx).await.unwrap_err();

        // The mock script reports an error entry for keys it does not know.
        assert_eq!(
            error.to_string(),
            "secret for key 'fake_secret_900' was not retrieved: backend does not provide secret key"
        );
    }
}