vector/sinks/websocket_server/
buffering.rs1use crate::serde::default_decoding;
2use std::{collections::VecDeque, net::SocketAddr, num::NonZeroUsize};
3
4use bytes::Bytes;
5use derivative::Derivative;
6use tokio_tungstenite::tungstenite::{handshake::server::Request, Message};
7use url::Url;
8use uuid::Uuid;
9use vector_config::configurable_component;
10use vector_lib::{
11 codecs::decoding::{format::Deserializer as _, DeserializerConfig},
12 event::{Event, MaybeAsLogMut},
13 lookup::lookup_v2::ConfigValuePath,
14};
15use vrl::prelude::VrlValueConvert;
16
#[configurable_component]
/// Configuration for message buffering, which keeps recently sent messages so that they can be
/// replayed to clients that connect later.
#[derive(Clone, Debug)]
pub struct MessageBufferingConfig {
    /// Maximum number of messages to keep in the replay buffer.
    ///
    /// Defaults to 1000.
    #[serde(default = "default_max_events")]
    pub max_events: NonZeroUsize,

    /// Path at which each outgoing log event receives its message ID.
    ///
    /// When unset, message IDs are still generated for the buffer but are not inserted into the
    /// events sent to clients.
    #[serde(default, skip_serializing_if = "crate::serde::is_default")]
    pub message_id_path: Option<ConfigValuePath>,

    // When absent, client ACK tracking is disabled (no client key, ACK messages are ignored).
    #[configurable(derived)]
    pub client_ack_config: Option<BufferingAckConfig>,
}
38
#[configurable_component]
/// Configuration for client ACKs of buffered messages.
///
/// Messages received from a client are decoded with `ack_decoding`, and the acknowledged
/// message ID is read from `message_id_path`. Presumably the ACKed ID is stored per client key
/// and later used as the replay checkpoint; TODO confirm against the caller.
#[derive(Clone, Debug, Derivative)]
pub struct BufferingAckConfig {
    // Decoder for ACK messages sent by clients; defaults to `default_decoding()`.
    #[configurable(derived)]
    #[derivative(Default(value = "default_decoding()"))]
    #[serde(default = "default_decoding")]
    pub ack_decoding: DeserializerConfig,

    /// Path of the message ID field within a decoded ACK message.
    ///
    /// The field must hold a string containing a UUID.
    pub message_id_path: ConfigValuePath,

    // How clients are identified for ACK tracking; defaults to the IP address without the port.
    #[configurable(derived)]
    #[serde(default = "default_client_key_config")]
    pub client_key: ClientKeyConfig,
}
59
#[configurable_component]
/// How to derive the key that identifies a client when tracking its ACKed messages.
#[derive(Clone, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
#[configurable(metadata(
    docs::enum_tag_description = "The type of client key to use, when tracking ACKed message for message buffering."
))]
pub enum ClientKeyConfig {
    /// Use the client's IP address as its key.
    IpAddress {
        /// Whether to include the client's port in the key.
        ///
        /// When `true` the key is `ip:port`; otherwise only the IP address is used.
        #[serde(default = "crate::serde::default_false")]
        with_port: bool,
    },
    /// Use the value of a header from the connection request as the client's key.
    Header {
        /// Name of the header whose value is used as the key.
        name: String,
    },
}
82
83const fn default_client_key_config() -> ClientKeyConfig {
84 ClientKeyConfig::IpAddress { with_port: false }
85}
86
/// Default for `max_events`: buffer up to 1000 messages.
const fn default_max_events() -> NonZeroUsize {
    // Safe construction instead of `new_unchecked`: no `unsafe` block and no SAFETY argument
    // needed. The `None` arm is unreachable because the literal is non-zero.
    match NonZeroUsize::new(1000) {
        Some(max_events) => max_events,
        None => panic!("default max_events must be non-zero"),
    }
}
90
/// Query parameter a connecting client can set to the UUID of the last message it received,
/// requesting replay of the buffered messages whose IDs are greater than it.
const LAST_RECEIVED_QUERY_PARAM_NAME: &str = "last_received";
92
/// Describes which buffered messages, if any, should be replayed to a client.
pub struct BufferReplayRequest {
    // When `false`, nothing is replayed and `replay_from` is irrelevant.
    should_replay: bool,
    // When `Some(id)`, only messages whose ID is greater than `id` are replayed;
    // when `None` (and `should_replay` is true), every buffered message is replayed.
    replay_from: Option<Uuid>,
}
97
98impl BufferReplayRequest {
99 pub const NO_REPLAY: Self = Self {
100 should_replay: false,
101 replay_from: None,
102 };
103 pub const REPLAY_ALL: Self = Self {
104 should_replay: true,
105 replay_from: None,
106 };
107
108 pub const fn with_replay_from(replay_from: Uuid) -> Self {
109 Self {
110 should_replay: true,
111 replay_from: Some(replay_from),
112 }
113 }
114
115 pub fn replay_messages(
116 &self,
117 buffer: &VecDeque<(Uuid, Message)>,
118 replay: impl FnMut(&(Uuid, Message)),
119 ) {
120 if self.should_replay {
121 buffer
122 .iter()
123 .filter(|(id, _)| Some(*id) > self.replay_from)
124 .for_each(replay);
125 }
126 }
127}
128
/// Message-buffering behavior for the websocket server sink.
///
/// In this file it is implemented for `Option<MessageBufferingConfig>`, where `None` means
/// buffering is disabled.
pub trait WsMessageBufferConfig {
    /// Returns whether outgoing messages should be buffered at all.
    fn should_buffer(&self) -> bool;
    /// Returns the key identifying this client for ACK tracking, or `None` when it cannot be
    /// derived (including when ACK tracking is not configured).
    fn client_key(&self, request: &Request, client_address: &SocketAddr) -> Option<String>;
    /// Returns the number of messages the replay buffer should hold.
    fn buffer_capacity(&self) -> usize;
    /// Decides which buffered messages to replay to a newly connected client, based on the
    /// connection `request` and the client's last known checkpoint, if any.
    fn extract_message_replay_request(
        &self,
        request: &Request,
        client_checkpoint: Option<Uuid>,
    ) -> BufferReplayRequest;
    /// Generates an ID for an outgoing message, possibly writing it into `event`, and returns it.
    fn add_replay_message_id_to_event(&self, event: &mut Event) -> Uuid;
    /// Extracts the acknowledged message ID from a client's ACK message, if one can be found.
    fn handle_ack_request(&self, request: Message) -> Option<Uuid>;
}
149
150impl WsMessageBufferConfig for Option<MessageBufferingConfig> {
151 fn should_buffer(&self) -> bool {
152 self.is_some()
153 }
154
155 fn client_key(&self, request: &Request, client_address: &SocketAddr) -> Option<String> {
156 self.as_ref()
157 .and_then(|mb| mb.client_ack_config.as_ref())
158 .and_then(|ack| match &ack.client_key {
159 ClientKeyConfig::IpAddress { with_port } => Some(if *with_port {
160 client_address.to_string()
161 } else {
162 client_address.ip().to_string()
163 }),
164 ClientKeyConfig::Header { name } => request
165 .headers()
166 .get(name)
167 .and_then(|h| h.to_str().ok())
168 .map(ToString::to_string),
169 })
170 }
171
172 fn buffer_capacity(&self) -> usize {
173 self.as_ref().map_or(0, |mb| mb.max_events.get())
174 }
175
176 fn extract_message_replay_request(
177 &self,
178 request: &Request,
179 client_checkpoint: Option<Uuid>,
180 ) -> BufferReplayRequest {
181 if self.is_none() {
183 return BufferReplayRequest::NO_REPLAY;
184 }
185
186 let default_request = client_checkpoint
187 .map(BufferReplayRequest::with_replay_from)
188 .unwrap_or(BufferReplayRequest::NO_REPLAY);
191
192 let Some(query_params) = request.uri().query() else {
194 return default_request;
195 };
196
197 if !query_params.contains(LAST_RECEIVED_QUERY_PARAM_NAME) {
199 return default_request;
200 }
201
202 let base_url = Url::parse("ws://localhost").ok();
204 match Url::options()
205 .base_url(base_url.as_ref())
206 .parse(request.uri().to_string().as_str())
207 {
208 Ok(url) => {
209 if let Some((_, last_received_param_value)) = url
210 .query_pairs()
211 .find(|(k, _)| k == LAST_RECEIVED_QUERY_PARAM_NAME)
212 {
213 match Uuid::parse_str(&last_received_param_value) {
214 Ok(last_received_val) => {
215 return BufferReplayRequest::with_replay_from(last_received_val)
216 }
217 Err(err) => {
218 warn!(message = "Parsing last received message UUID failed.", %err)
219 }
220 }
221 }
222 }
223 Err(err) => {
224 warn!(message = "Parsing request URL for websocket connection request failed.", %err)
225 }
226 }
227
228 BufferReplayRequest::REPLAY_ALL
231 }
232
233 fn add_replay_message_id_to_event(&self, event: &mut Event) -> Uuid {
234 let message_id = Uuid::now_v7();
235 if let Some(MessageBufferingConfig {
236 message_id_path: Some(message_id_path),
237 ..
238 }) = self
239 {
240 if let Some(log) = event.maybe_as_log_mut() {
241 let mut buffer = [0; 36];
242 let uuid = message_id.hyphenated().encode_lower(&mut buffer);
243 log.value_mut()
244 .insert(message_id_path, Bytes::copy_from_slice(uuid.as_bytes()));
245 }
246 }
247 message_id
248 }
249
250 fn handle_ack_request(&self, request: Message) -> Option<Uuid> {
251 let ack_config = self.as_ref().and_then(|mb| mb.client_ack_config.as_ref())?;
252
253 let parsed_message = ack_config
254 .ack_decoding
255 .build()
256 .expect("Invalid `ack_decoding` config.")
257 .parse(request.into_data().into(), Default::default())
258 .inspect_err(|err| {
259 debug!(message = "Parsing ACK request failed.", %err);
260 })
261 .ok()?;
262
263 let Some(message_id_field) = parsed_message
264 .first()?
265 .maybe_as_log()?
266 .value()
267 .get(&ack_config.message_id_path)
268 else {
269 debug!("Couldn't find message ID in ACK request.");
270 return None;
271 };
272
273 message_id_field
274 .try_bytes_utf8_lossy()
275 .map_err(|_| "Message ID is not a valid string.")
276 .and_then(|id| {
277 Uuid::parse_str(id.trim()).map_err(|_| "Message ID is not a valid UUID.")
278 })
279 .inspect_err(|err| debug!(message = "Parsing message ID in ACK request failed.", %err))
280 .ok()
281 }
282}