1use std::{
2 io,
3 num::NonZeroU64,
4 time::{Duration, Instant},
5};
6
7use crate::{
8 codecs::{Encoder, Transformer},
9 common::websocket::{PingInterval, WebSocketConnector, is_closed},
10 event::{Event, EventStatus, Finalizable},
11 internal_events::{
12 ConnectionOpen, OpenGauge, WebSocketConnectionError, WebSocketConnectionShutdown,
13 },
14 sinks::{util::StreamSink, websocket::config::WebSocketSinkConfig},
15};
16use async_trait::async_trait;
17use bytes::BytesMut;
18use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt, stream::BoxStream};
19use tokio_tungstenite::tungstenite::{error::Error as TungsteniteError, protocol::Message};
20use tokio_util::codec::Encoder as _;
21use vector_lib::{
22 EstimatedJsonEncodedSizeOf, emit,
23 internal_event::{
24 ByteSize, BytesSent, CountByteSize, EventsSent, InternalEventHandle as _, Output, Protocol,
25 },
26};
27
28pub struct WebSocketSink {
29 transformer: Transformer,
30 encoder: Encoder<()>,
31 connector: WebSocketConnector,
32 ping_interval: Option<NonZeroU64>,
33 ping_timeout: Option<NonZeroU64>,
34}
35
36impl WebSocketSink {
37 pub(crate) fn new(
38 config: &WebSocketSinkConfig,
39 connector: WebSocketConnector,
40 ) -> crate::Result<Self> {
41 let transformer = config.encoding.transformer();
42 let serializer = config.encoding.build()?;
43 let encoder = Encoder::<()>::new(serializer);
44
45 Ok(Self {
46 transformer,
47 encoder,
48 connector,
49 ping_interval: config.common.ping_interval,
50 ping_timeout: config.common.ping_timeout,
51 })
52 }
53
54 async fn create_sink_and_stream(
55 &self,
56 ) -> (
57 impl Sink<Message, Error = TungsteniteError> + use<>,
58 impl Stream<Item = Result<Message, TungsteniteError>> + use<>,
59 ) {
60 let ws_stream = self.connector.connect_backoff().await;
61 ws_stream.split()
62 }
63
64 fn check_received_pong_time(&self, last_pong: Instant) -> Result<(), TungsteniteError> {
65 if let Some(ping_timeout) = self.ping_timeout
66 && last_pong.elapsed() > Duration::from_secs(ping_timeout.into())
67 {
68 return Err(TungsteniteError::Io(io::Error::new(
69 io::ErrorKind::TimedOut,
70 "Pong not received in time",
71 )));
72 }
73
74 Ok(())
75 }
76
77 async fn handle_events<I, WS, O>(
78 &mut self,
79 input: &mut I,
80 ws_stream: &mut WS,
81 ws_sink: &mut O,
82 ) -> Result<(), ()>
83 where
84 I: Stream<Item = Event> + Unpin,
85 WS: Stream<Item = Result<Message, TungsteniteError>> + Unpin,
86 O: Sink<Message, Error = TungsteniteError> + Unpin,
87 {
88 const PING: &[u8] = b"PING";
89
90 let mut ping_interval = PingInterval::new(self.ping_interval.map(u64::from));
93
94 if let Err(error) = ws_sink.send(Message::Ping(PING.to_vec())).await {
95 emit!(WebSocketConnectionError { error });
96 return Err(());
97 }
98 let mut last_pong = Instant::now();
99
100 let bytes_sent = register!(BytesSent::from(Protocol("websocket".into())));
101 let events_sent = register!(EventsSent::from(Output(None)));
102 let encode_as_binary = self.encoder.serializer().is_binary();
103
104 loop {
105 let result = tokio::select! {
106 _ = ping_interval.tick() => {
107 match self.check_received_pong_time(last_pong) {
108 Ok(()) => ws_sink.send(Message::Ping(PING.to_vec())).await.map(|_| ()),
109 Err(e) => Err(e)
110 }
111 },
112
113 Some(msg) = ws_stream.next() => {
114 match msg {
116 Ok(Message::Pong(_)) => {
117 last_pong = Instant::now();
118 Ok(())
119 },
120 Ok(_) => Ok(()),
121 Err(e) => Err(e)
122 }
123 },
124
125 event = input.next() => {
126 let mut event = if let Some(event) = event {
127 event
128 } else {
129 break;
130 };
131
132 let finalizers = event.take_finalizers();
133
134 self.transformer.transform(&mut event);
135
136 let event_byte_size = event.estimated_json_encoded_size_of();
137
138 let mut bytes = BytesMut::new();
139 match self.encoder.encode(event, &mut bytes) {
140 Ok(()) => {
141 finalizers.update_status(EventStatus::Delivered);
142
143 let message = if encode_as_binary {
144 Message::binary(bytes)
145 }
146 else {
147 Message::text(String::from_utf8_lossy(&bytes))
148 };
149 let message_len = message.len();
150
151 ws_sink.send(message).await.map(|_| {
152 events_sent.emit(CountByteSize(1, event_byte_size));
153 bytes_sent.emit(ByteSize(message_len));
154 })
155 },
156 Err(_) => {
157 finalizers.update_status(EventStatus::Errored);
159 Ok(())
160 }
161 }
162 },
163 else => break,
164 };
165
166 if let Err(error) = result {
167 if is_closed(&error) {
168 emit!(WebSocketConnectionShutdown);
169 } else {
170 emit!(WebSocketConnectionError { error });
171 }
172 return Err(());
173 }
174 }
175
176 Ok(())
177 }
178}
179
180#[async_trait]
181impl StreamSink<Event> for WebSocketSink {
182 async fn run(mut self: Box<Self>, input: BoxStream<'_, Event>) -> Result<(), ()> {
183 let input = input.fuse().peekable();
184 pin_mut!(input);
185
186 while input.as_mut().peek().await.is_some() {
187 let (ws_sink, ws_stream) = self.create_sink_and_stream().await;
188 pin_mut!(ws_sink);
189 pin_mut!(ws_stream);
190
191 let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
192
193 if self
194 .handle_events(&mut input, &mut ws_stream, &mut ws_sink)
195 .await
196 .is_ok()
197 {
198 _ = ws_sink.close().await;
199 }
200 }
201
202 Ok(())
203 }
204}
205
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use futures::{FutureExt, StreamExt, future};
    use serde_json::Value as JsonValue;
    use tokio::{time, time::timeout};
    use tokio_tungstenite::{
        accept_async, accept_hdr_async,
        tungstenite::{
            error::ProtocolError,
            handshake::server::{Request, Response},
        },
    };
    use vector_lib::codecs::JsonSerializerConfig;

    use super::*;
    use crate::{
        common::websocket::WebSocketCommonConfig,
        config::{SinkConfig, SinkContext},
        http::Auth,
        test_util::{
            CountReceiver,
            addr::next_addr,
            components::{SINK_TAGS, run_and_assert_sink_compliance},
            random_lines_with_stream, trace_init,
        },
        tls::{self, MaybeTlsSettings, TlsConfig, TlsEnableableConfig},
    };

    #[tokio::test(flavor = "multi_thread")]
    async fn test_websocket() {
        trace_init();

        let (_guard, addr) = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, None).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_websocket() {
        trace_init();

        let auth = Some(Auth::Bearer {
            token: "OiJIUzI1NiIsInR5cCI6IkpXVCJ".to_string().into(),
        });
        let (_guard, addr) = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                // The sink must actually send the credentials the test server checks;
                // otherwise this test would pass without exercising auth at all.
                auth: auth.clone(),
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, auth).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_tls_websocket() {
        trace_init();

        let (_guard, addr) = next_addr();
        let tls_config = Some(TlsEnableableConfig::test_config());
        let tls = MaybeTlsSettings::from_config(tls_config.as_ref(), true).unwrap();

        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("wss://{addr}"),
                tls: Some(TlsEnableableConfig {
                    enabled: Some(true),
                    options: TlsConfig {
                        verify_certificate: Some(false),
                        verify_hostname: Some(true),
                        ca_file: Some(tls::TEST_PEM_CRT_PATH.into()),
                        ..Default::default()
                    },
                }),
                ping_timeout: None,
                ping_interval: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };

        send_events_and_assert(addr, config, tls, None).await;
    }

    #[tokio::test]
    async fn test_websocket_reconnect() {
        trace_init();

        let (_guard, addr) = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        let mut receiver = create_count_receiver(addr, tls.clone(), true, None);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (_lines, events) = random_lines_with_stream(10, 100, None);
        let events = events.then(|event| async move {
            time::sleep(Duration::from_millis(10)).await;
            event
        });
        drop(tokio::spawn(sink.run(events)));

        receiver.connected().await;
        time::sleep(Duration::from_millis(500)).await;
        assert!(!receiver.await.is_empty());

        let mut receiver = create_count_receiver(addr, tls, false, None);
        assert!(
            timeout(Duration::from_secs(10), receiver.connected())
                .await
                .is_ok()
        );
    }

    /// Runs the sink built from `config` against a local test server and asserts
    /// that every generated line arrives, in order, as the JSON message field.
    async fn send_events_and_assert(
        addr: SocketAddr,
        config: WebSocketSinkConfig,
        tls: MaybeTlsSettings,
        auth: Option<Auth>,
    ) {
        let mut receiver = create_count_receiver(addr, tls, false, auth);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (lines, events) = random_lines_with_stream(10, 100, None);
        run_and_assert_sink_compliance(sink, events, &SINK_TAGS).await;

        receiver.connected().await;

        let output = receiver.await;
        assert_eq!(lines.len(), output.len());
        let message_key = crate::config::log_schema()
            .message_key()
            .expect("global log_schema.message_key to be valid path")
            .to_string();
        for (source, received) in lines.iter().zip(output) {
            let json = serde_json::from_str::<JsonValue>(&received).expect("Invalid JSON");
            let received = json.get(message_key.as_str()).unwrap().as_str().unwrap();
            assert_eq!(source, received);
        }
    }

    /// Starts a WebSocket test server on `addr` that collects received text frames.
    ///
    /// When `auth` is a bearer token, handshakes whose `Authorization` header is
    /// missing or does not match are rejected. When `interrupt_stream` is set,
    /// the collected stream is additionally cut off by the receiver's tripwire.
    fn create_count_receiver(
        addr: SocketAddr,
        tls: MaybeTlsSettings,
        interrupt_stream: bool,
        auth: Option<Auth>,
    ) -> CountReceiver<String> {
        CountReceiver::receive_items_stream(move |tripwire, connected| async move {
            let listener = tls.bind(&addr).await.unwrap();
            let stream = listener.accept_stream();

            let tripwire = tripwire.map(|_| ()).shared();
            let stream_tripwire = tripwire.clone();
            let mut connected = Some(connected);

            let stream = stream
                .take_until(tripwire)
                .filter_map(move |maybe_tls_stream| {
                    let au = auth.clone();
                    async move {
                        let maybe_tls_stream = maybe_tls_stream.unwrap();
                        let ws_stream = match au {
                            Some(a) => {
                                let auth_callback = |req: &Request, res: Response| {
                                    // A missing or non-ASCII header counts as "no credentials".
                                    let header = req
                                        .headers()
                                        .get("Authorization")
                                        .and_then(|value| value.to_str().ok());
                                    let authorized = match a {
                                        Auth::Bearer { token } => {
                                            let expected = format!("Bearer {}", token.inner());
                                            header == Some(expected.as_str())
                                        }
                                        // Other schemes are not verified by this test server.
                                        Auth::Basic { .. } => true,
                                        Auth::Custom { .. } => true,
                                        #[cfg(feature = "aws-core")]
                                        _ => true,
                                    };
                                    if authorized {
                                        Ok(res)
                                    } else {
                                        Err(http::Response::<Option<String>>::new(None))
                                    }
                                };
                                accept_hdr_async(maybe_tls_stream, auth_callback)
                                    .await
                                    .unwrap()
                            }
                            None => accept_async(maybe_tls_stream).await.unwrap(),
                        };

                        Some(
                            ws_stream
                                .filter_map(|msg| {
                                    future::ready(match msg {
                                        Ok(msg) if msg.is_text() => {
                                            Some(Ok(msg.into_text().unwrap()))
                                        }
                                        Err(TungsteniteError::Protocol(
                                            ProtocolError::ResetWithoutClosingHandshake,
                                        )) => None,
                                        Err(e) => Some(Err(e)),
                                        _ => None,
                                    })
                                })
                                .take_while(|msg| future::ready(msg.is_ok()))
                                .filter_map(|msg| future::ready(msg.ok())),
                        )
                    }
                })
                .map(move |ws_stream| {
                    connected.take().map(|trigger| trigger.send(()));
                    ws_stream
                })
                .flatten();

            match interrupt_stream {
                false => stream.boxed(),
                true => stream.take_until(stream_tripwire).boxed(),
            }
        })
    }
}