vector/sinks/websocket_server/
buffering.rs1use std::{collections::VecDeque, net::SocketAddr, num::NonZeroUsize};
2
3use bytes::Bytes;
4use derivative::Derivative;
5use tokio_tungstenite::tungstenite::{Message, handshake::server::Request};
6use url::Url;
7use uuid::Uuid;
8use vector_config::configurable_component;
9use vector_lib::{
10 codecs::decoding::{DeserializerConfig, format::Deserializer as _},
11 event::{Event, MaybeAsLogMut},
12 lookup::lookup_v2::ConfigValuePath,
13};
14use vrl::prelude::VrlValueConvert;
15
16use crate::serde::default_decoding;
17
/// Configuration for buffering sent messages so that they can be replayed to clients.
///
/// A client can request a replay by passing the `last_received` query parameter (the UUID of
/// the last message it received) when it opens its connection.
#[configurable_component]
#[derive(Clone, Debug)]
pub struct MessageBufferingConfig {
    /// Maximum number of events to keep in the message buffer.
    #[serde(default = "default_max_events")]
    pub max_events: NonZeroUsize,

    /// Log event path at which to store the generated ID of each buffered message.
    ///
    /// The ID is a version 7 UUID written as a lowercase, hyphenated string. When unset, the ID
    /// is not added to events.
    #[serde(default, skip_serializing_if = "crate::serde::is_default")]
    pub message_id_path: Option<ConfigValuePath>,

    // Client ACK handling for buffered messages. When absent, ACK messages from clients are not
    // parsed and no client key is derived.
    #[configurable(derived)]
    pub client_ack_config: Option<BufferingAckConfig>,
}
39
/// Configuration for ACKs that clients send back for buffered messages.
#[configurable_component]
#[derive(Clone, Debug, Derivative)]
pub struct BufferingAckConfig {
    // Decoding used to parse incoming ACK messages; defaults to `default_decoding()`.
    #[configurable(derived)]
    #[derivative(Default(value = "default_decoding()"))]
    #[serde(default = "default_decoding")]
    pub ack_decoding: DeserializerConfig,

    /// Path of the message ID field in a decoded ACK message.
    ///
    /// The field must hold a string containing a UUID; surrounding whitespace is ignored. Only
    /// the first event decoded from an ACK message is inspected.
    pub message_id_path: ConfigValuePath,

    // How clients are identified when tracking ACKs; defaults to the client's IP address
    // without the port (see `default_client_key_config`).
    #[configurable(derived)]
    #[serde(default = "default_client_key_config")]
    pub client_key: ClientKeyConfig,
}
60
/// Strategy for deriving the key that identifies a client when tracking ACKed messages.
#[configurable_component]
#[derive(Clone, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
#[configurable(metadata(
    docs::enum_tag_description = "The type of client key to use, when tracking ACKed message for message buffering."
))]
pub enum ClientKeyConfig {
    /// Use the client's IP address as the key.
    IpAddress {
        /// Whether to include the client's port in the key.
        ///
        /// When `true`, the full socket address (IP address and port) is used.
        #[serde(default = "crate::serde::default_false")]
        with_port: bool,
    },
    /// Use the value of a request header as the key.
    ///
    /// If the header is missing, or its value cannot be read as a string, no key is produced
    /// for the client.
    Header {
        /// Name of the header whose value is used as the client key.
        name: String,
    },
}
83
84const fn default_client_key_config() -> ClientKeyConfig {
85 ClientKeyConfig::IpAddress { with_port: false }
86}
87
/// Default for `MessageBufferingConfig::max_events`: buffer up to 1000 events.
const fn default_max_events() -> NonZeroUsize {
    // Evaluated at compile time, so a zero literal would be a build error instead of the
    // undefined behavior `new_unchecked(0)` would cause; no `unsafe` is required.
    const DEFAULT_MAX_EVENTS: NonZeroUsize = match NonZeroUsize::new(1000) {
        Some(value) => value,
        None => panic!("default max events must be non-zero"),
    };
    DEFAULT_MAX_EVENTS
}
91
/// Name of the query parameter a client can set on its connection request to the UUID of the
/// last message it received, so that only buffered messages after it are replayed.
const LAST_RECEIVED_QUERY_PARAM_NAME: &str = "last_received";
93
/// Describes which buffered messages should be replayed to a client.
pub struct BufferReplayRequest {
    // Whether any messages should be replayed at all.
    should_replay: bool,
    // When set, only messages whose ID is strictly greater than this one are replayed; `None`
    // replays every buffered message.
    replay_from: Option<Uuid>,
}
98
99impl BufferReplayRequest {
100 pub const NO_REPLAY: Self = Self {
101 should_replay: false,
102 replay_from: None,
103 };
104 pub const REPLAY_ALL: Self = Self {
105 should_replay: true,
106 replay_from: None,
107 };
108
109 pub const fn with_replay_from(replay_from: Uuid) -> Self {
110 Self {
111 should_replay: true,
112 replay_from: Some(replay_from),
113 }
114 }
115
116 pub fn replay_messages(
117 &self,
118 buffer: &VecDeque<(Uuid, Message)>,
119 replay: impl FnMut(&(Uuid, Message)),
120 ) {
121 if self.should_replay {
122 buffer
123 .iter()
124 .filter(|(id, _)| Some(*id) > self.replay_from)
125 .for_each(replay);
126 }
127 }
128}
129
/// Message buffering operations used by the WebSocket server sink.
///
/// Implemented in this file for `Option<MessageBufferingConfig>`, where `None` means buffering
/// is disabled.
pub trait WsMessageBufferConfig {
    /// Returns `true` if sent messages should be buffered.
    fn should_buffer(&self) -> bool;
    /// Returns the key identifying the client behind `request` / `client_address` for ACK
    /// tracking, or `None` if no key can be derived.
    fn client_key(&self, request: &Request, client_address: &SocketAddr) -> Option<String>;
    /// Returns the maximum number of messages to keep buffered.
    fn buffer_capacity(&self) -> usize;
    /// Decides which buffered messages to replay for a new connection, based on the connection
    /// `request` and an optional `client_checkpoint` message ID.
    fn extract_message_replay_request(
        &self,
        request: &Request,
        client_checkpoint: Option<Uuid>,
    ) -> BufferReplayRequest;
    /// Generates a new message ID, writes it into `event` if configured to, and returns it.
    fn add_replay_message_id_to_event(&self, event: &mut Event) -> Uuid;
    /// Parses an ACK message received from a client and returns the acknowledged message ID,
    /// or `None` if ACKs are not configured or the message cannot be interpreted.
    fn handle_ack_request(&self, request: Message) -> Option<Uuid>;
}
150
151impl WsMessageBufferConfig for Option<MessageBufferingConfig> {
152 fn should_buffer(&self) -> bool {
153 self.is_some()
154 }
155
156 fn client_key(&self, request: &Request, client_address: &SocketAddr) -> Option<String> {
157 self.as_ref()
158 .and_then(|mb| mb.client_ack_config.as_ref())
159 .and_then(|ack| match &ack.client_key {
160 ClientKeyConfig::IpAddress { with_port } => Some(if *with_port {
161 client_address.to_string()
162 } else {
163 client_address.ip().to_string()
164 }),
165 ClientKeyConfig::Header { name } => request
166 .headers()
167 .get(name)
168 .and_then(|h| h.to_str().ok())
169 .map(ToString::to_string),
170 })
171 }
172
173 fn buffer_capacity(&self) -> usize {
174 self.as_ref().map_or(0, |mb| mb.max_events.get())
175 }
176
177 fn extract_message_replay_request(
178 &self,
179 request: &Request,
180 client_checkpoint: Option<Uuid>,
181 ) -> BufferReplayRequest {
182 if self.is_none() {
184 return BufferReplayRequest::NO_REPLAY;
185 }
186
187 let default_request = client_checkpoint
188 .map(BufferReplayRequest::with_replay_from)
189 .unwrap_or(BufferReplayRequest::NO_REPLAY);
192
193 let Some(query_params) = request.uri().query() else {
195 return default_request;
196 };
197
198 if !query_params.contains(LAST_RECEIVED_QUERY_PARAM_NAME) {
200 return default_request;
201 }
202
203 let base_url = Url::parse("ws://localhost").ok();
205 match Url::options()
206 .base_url(base_url.as_ref())
207 .parse(request.uri().to_string().as_str())
208 {
209 Ok(url) => {
210 if let Some((_, last_received_param_value)) = url
211 .query_pairs()
212 .find(|(k, _)| k == LAST_RECEIVED_QUERY_PARAM_NAME)
213 {
214 match Uuid::parse_str(&last_received_param_value) {
215 Ok(last_received_val) => {
216 return BufferReplayRequest::with_replay_from(last_received_val);
217 }
218 Err(err) => {
219 warn!(message = "Parsing last received message UUID failed.", %err)
220 }
221 }
222 }
223 }
224 Err(err) => {
225 warn!(message = "Parsing request URL for websocket connection request failed.", %err)
226 }
227 }
228
229 BufferReplayRequest::REPLAY_ALL
232 }
233
234 fn add_replay_message_id_to_event(&self, event: &mut Event) -> Uuid {
235 let message_id = Uuid::now_v7();
236 if let Some(MessageBufferingConfig {
237 message_id_path: Some(message_id_path),
238 ..
239 }) = self
240 && let Some(log) = event.maybe_as_log_mut()
241 {
242 let mut buffer = [0; 36];
243 let uuid = message_id.hyphenated().encode_lower(&mut buffer);
244 log.value_mut()
245 .insert(message_id_path, Bytes::copy_from_slice(uuid.as_bytes()));
246 }
247 message_id
248 }
249
250 fn handle_ack_request(&self, request: Message) -> Option<Uuid> {
251 let ack_config = self.as_ref().and_then(|mb| mb.client_ack_config.as_ref())?;
252
253 let parsed_message = ack_config
254 .ack_decoding
255 .build()
256 .expect("Invalid `ack_decoding` config.")
257 .parse(request.into_data().into(), Default::default())
258 .inspect_err(|err| {
259 debug!(message = "Parsing ACK request failed.", %err);
260 })
261 .ok()?;
262
263 let Some(message_id_field) = parsed_message
264 .first()?
265 .maybe_as_log()?
266 .value()
267 .get(&ack_config.message_id_path)
268 else {
269 debug!("Couldn't find message ID in ACK request.");
270 return None;
271 };
272
273 message_id_field
274 .try_bytes_utf8_lossy()
275 .map_err(|_| "Message ID is not a valid string.")
276 .and_then(|id| {
277 Uuid::parse_str(id.trim()).map_err(|_| "Message ID is not a valid UUID.")
278 })
279 .inspect_err(|err| debug!(message = "Parsing message ID in ACK request failed.", %err))
280 .ok()
281 }
282}