1use std::pin::Pin;
2
3use chrono::Utc;
4use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt};
5use snafu::Snafu;
6use tokio::time;
7use tokio_tungstenite::tungstenite::{
8 Message, error::Error as TungsteniteError, protocol::CloseFrame,
9};
10use tokio_util::codec::FramedRead;
11use vector_lib::{
12 EstimatedJsonEncodedSizeOf,
13 config::LogNamespace,
14 event::{Event, LogEvent},
15 internal_event::{CountByteSize, EventsReceived, InternalEventHandle as _},
16};
17
18use crate::{
19 SourceSender,
20 codecs::Decoder,
21 common::websocket::{PingInterval, WebSocketConnector, is_closed},
22 config::SourceContext,
23 internal_events::{
24 ConnectionOpen, OpenGauge, PROTOCOL, WebSocketBytesReceived, WebSocketConnectionError,
25 WebSocketConnectionEstablished, WebSocketConnectionFailedError,
26 WebSocketConnectionShutdown, WebSocketKind, WebSocketMessageReceived,
27 WebSocketReceiveError, WebSocketSendError,
28 },
29 sources::websocket::config::WebSocketConfig,
30 vector_lib::codecs::StreamDecodingError,
31};
32
/// Reports a failed connection attempt and returns from the enclosing function.
///
/// `$selector` is an error context selector (for example
/// `InitialMessageTimeoutSnafu`). The error it builds is boxed into a
/// `WebSocketConnectionFailedError` internal event, after which the enclosing
/// function returns the result of the selector's `fail()`.
macro_rules! fail_with_event {
    ($selector:expr_2021) => {{
        emit!(WebSocketConnectionFailedError {
            error: Box::new($selector.build()),
        });
        return $selector.fail();
    }};
}
41
/// Pinned, boxed write half of a WebSocket connection (produced by `split()`
/// in `try_create_sink_and_stream`).
type WebSocketSink = Pin<Box<dyn Sink<Message, Error = TungsteniteError> + Send>>;
/// Pinned, boxed read half of a WebSocket connection; yields received messages
/// or the transport error that interrupted them.
type WebSocketStream = Pin<Box<dyn Stream<Item = Result<Message, TungsteniteError>> + Send>>;
44
/// Pre-built runtime dependencies handed to [`WebSocketSource`].
pub(crate) struct WebSocketSourceParams {
    // Used to (re)establish the connection via `connect_backoff()`.
    pub connector: WebSocketConnector,
    // Cloned for every received payload to turn its bytes into events.
    pub decoder: Decoder,
    // Decides where the standard source metadata is written on each log event.
    pub log_namespace: LogNamespace,
}
50
/// WebSocket client source: connects to a server, decodes incoming messages
/// into events and forwards them downstream.
pub(crate) struct WebSocketSource {
    // User-facing configuration (URI, timeouts, ping/pong and initial messages).
    config: WebSocketConfig,
    // Connector, decoder and log namespace built ahead of time.
    params: WebSocketSourceParams,
}
55
/// Errors that either terminate the WebSocket source or trigger a reconnect.
#[derive(Debug, Snafu)]
pub enum WebSocketSourceError {
    // Establishing the connection took longer than `connect_timeout_secs`.
    #[snafu(display("Connection attempt timed out"))]
    ConnectTimeout,

    // No reply to the initial message within `initial_message_timeout_secs`.
    #[snafu(display("Server did not respond to the initial message in time"))]
    InitialMessageTimeout,

    // The stream ended while waiting for the reply to the initial message.
    #[snafu(display(
        "The connection was closed after sending the initial message, but before a response."
    ))]
    ConnectionClosedPrematurely,

    // The server sent a Close message that carried a frame.
    #[snafu(display("Connection closed by server with code '{}' and reason: '{}'", frame.code, frame.reason))]
    RemoteClosed { frame: CloseFrame<'static> },

    // The server sent a Close message without a frame.
    #[snafu(display("Connection closed by server without a close frame"))]
    RemoteClosedEmpty,

    // A ping tick fired while the previous ping was still unanswered.
    #[snafu(display("Connection timed out while waiting for a pong response"))]
    PongTimeout,

    // Any send/receive error reported by the underlying WebSocket library.
    #[snafu(display("A WebSocket error occurred: {}", source))]
    Tungstenite { source: TungsteniteError },
}
81
82impl WebSocketSource {
83 pub const fn new(config: WebSocketConfig, params: WebSocketSourceParams) -> Self {
84 Self { config, params }
85 }
86
87 pub async fn run(self, cx: SourceContext) -> Result<(), WebSocketSourceError> {
88 let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
89 let mut ping_manager = PingManager::new(&self.config);
90
91 let mut out = cx.out;
92
93 let (ws_sink, ws_source) = self.connect(&mut out).await?;
94
95 pin_mut!(ws_sink, ws_source);
96
97 loop {
98 let result = tokio::select! {
99 _ = cx.shutdown.clone() => {
100 info!(internal_log_rate_limit = true, "Received shutdown signal.");
101 break;
102 },
103
104 res = ping_manager.tick(&mut ws_sink) => res,
105
106 Some(msg_result) = ws_source.next() => {
107 match msg_result {
108 Ok(msg) => self.handle_message(msg, &mut ping_manager, &mut out).await,
109 Err(error) => {
110 emit!(WebSocketReceiveError { error: &error });
111 Err(WebSocketSourceError::Tungstenite { source: error })
112 }
113 }
114 }
115 };
116
117 if let Err(error) = result {
118 match error {
119 WebSocketSourceError::RemoteClosed { frame } => {
120 warn!(
121 message = "Connection closed by server.",
122 code = %frame.code,
123 reason = %frame.reason,
124 internal_log_rate_limit = true
125 );
126 emit!(WebSocketConnectionShutdown);
127 }
128 WebSocketSourceError::RemoteClosedEmpty => {
129 warn!(
130 internal_log_rate_limit = true,
131 "Connection closed by server without a close frame."
132 );
133 emit!(WebSocketConnectionShutdown);
134 }
135 WebSocketSourceError::PongTimeout => {
136 error!(
137 internal_log_rate_limit = true,
138 "Disconnecting due to pong timeout."
139 );
140 emit!(WebSocketReceiveError {
141 error: &TungsteniteError::Io(std::io::Error::new(
142 std::io::ErrorKind::TimedOut,
143 "Pong timeout"
144 ))
145 });
146 emit!(WebSocketConnectionShutdown);
147 return Err(error);
148 }
149 WebSocketSourceError::Tungstenite { source: ws_err } => {
150 if is_closed(&ws_err) {
151 emit!(WebSocketConnectionShutdown);
152 }
153 error!(message = "WebSocket connection error.", error = %ws_err, internal_log_rate_limit = true);
154 }
155 WebSocketSourceError::ConnectTimeout
158 | WebSocketSourceError::InitialMessageTimeout
159 | WebSocketSourceError::ConnectionClosedPrematurely => {
160 unreachable!(
161 "Encountered a connection-time error during runtime: {:?}",
162 error
163 );
164 }
165 }
166 if self
167 .reconnect(&mut out, &mut ws_sink, &mut ws_source)
168 .await
169 .is_err()
170 {
171 break;
172 }
173 }
174 }
175 Ok(())
176 }
177
178 async fn handle_message(
179 &self,
180 msg: Message,
181 ping_manager: &mut PingManager,
182 out: &mut SourceSender,
183 ) -> Result<(), WebSocketSourceError> {
184 match msg {
185 Message::Pong(_) => {
186 ping_manager.record_pong();
187 Ok(())
188 }
189 Message::Text(msg_txt) => {
190 if self.is_custom_pong(&msg_txt) {
191 ping_manager.record_pong();
192 debug!("Received custom pong response.");
193 } else {
194 self.process_message(&msg_txt, WebSocketKind::Text, out)
195 .await;
196 }
197 Ok(())
198 }
199 Message::Binary(msg_bytes) => {
200 self.process_message(&msg_bytes, WebSocketKind::Binary, out)
201 .await;
202 Ok(())
203 }
204 Message::Ping(_) => Ok(()),
205 Message::Close(frame) => self.handle_close_frame(frame),
206 Message::Frame(_) => {
207 warn!(
208 internal_log_rate_limit = true,
209 "Unsupported message type received: frame."
210 );
211 Ok(())
212 }
213 }
214 }
215
216 async fn process_message<T>(&self, payload: &T, kind: WebSocketKind, out: &mut SourceSender)
217 where
218 T: AsRef<[u8]> + ?Sized,
219 {
220 let payload_bytes = payload.as_ref();
221
222 emit!(WebSocketBytesReceived {
223 byte_size: payload_bytes.len(),
224 url: &self.config.common.uri,
225 protocol: PROTOCOL,
226 kind,
227 });
228 let mut stream = FramedRead::new(payload_bytes, self.params.decoder.clone());
229
230 while let Some(result) = stream.next().await {
231 match result {
232 Ok((events, _)) => {
233 if events.is_empty() {
234 continue;
235 }
236
237 let event_count = events.len();
238 let byte_size = events.estimated_json_encoded_size_of();
239
240 register!(EventsReceived).emit(CountByteSize(event_count, byte_size));
241 emit!(WebSocketMessageReceived {
242 count: event_count,
243 byte_size,
244 url: &self.config.common.uri,
245 protocol: PROTOCOL,
246 kind,
247 });
248
249 let events_with_meta = events.into_iter().map(|mut event| {
250 if let Event::Log(event) = &mut event {
251 self.add_metadata(event);
252 }
253 event
254 });
255
256 if let Err(error) = out.send_batch(events_with_meta).await {
257 error!(message = "Error sending events.", %error, internal_log_rate_limit = true);
258 }
259 }
260 Err(error) => {
261 if !error.can_continue() {
262 break;
263 }
264 }
265 }
266 }
267 }
268
269 fn add_metadata(&self, event: &mut LogEvent) {
270 self.params
271 .log_namespace
272 .insert_standard_vector_source_metadata(event, WebSocketConfig::NAME, Utc::now());
273 }
274
275 async fn reconnect(
276 &self,
277 out: &mut SourceSender,
278 ws_sink: &mut WebSocketSink,
279 ws_source: &mut WebSocketStream,
280 ) -> Result<(), WebSocketSourceError> {
281 info!(
282 internal_log_rate_limit = true,
283 "Reconnecting to WebSocket..."
284 );
285
286 let (new_sink, new_source) = self.connect(out).await?;
287
288 *ws_sink = new_sink;
289 *ws_source = new_source;
290
291 info!(internal_log_rate_limit = true, "Reconnected to Websocket.");
292
293 Ok(())
294 }
295
296 async fn connect(
297 &self,
298 out: &mut SourceSender,
299 ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
300 let (mut ws_sink, mut ws_source) = self.try_create_sink_and_stream().await?;
301
302 if self.config.initial_message.is_some() {
303 self.send_initial_message(&mut ws_sink, &mut ws_source, out)
304 .await?;
305 }
306
307 Ok((ws_sink, ws_source))
308 }
309
310 async fn try_create_sink_and_stream(
311 &self,
312 ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
313 let connect_future = self.params.connector.connect_backoff();
314 let timeout = self.config.connect_timeout_secs;
315
316 let ws_stream = match time::timeout(timeout, connect_future).await {
317 Ok(ws) => ws,
318 Err(_) => {
319 emit!(WebSocketConnectionError {
320 error: TungsteniteError::Io(std::io::Error::new(
321 std::io::ErrorKind::TimedOut,
322 "Connection attempt timed out",
323 ))
324 });
325 return Err(WebSocketSourceError::ConnectTimeout);
326 }
327 };
328
329 emit!(WebSocketConnectionEstablished {});
330 let (sink, stream) = ws_stream.split();
331
332 Ok((Box::pin(sink), Box::pin(stream)))
333 }
334
335 async fn send_initial_message(
336 &self,
337 ws_sink: &mut WebSocketSink,
338 ws_source: &mut WebSocketStream,
339 out: &mut SourceSender,
340 ) -> Result<(), WebSocketSourceError> {
341 let initial_message = self.config.initial_message.as_ref().unwrap();
342 ws_sink
343 .send(Message::Text(initial_message.clone()))
344 .await
345 .map_err(|error| {
346 emit!(WebSocketSendError { error: &error });
347 WebSocketSourceError::Tungstenite { source: error }
348 })?;
349
350 debug!("Sent initial message, awaiting response from server.");
351
352 let response =
353 match time::timeout(self.config.initial_message_timeout_secs, ws_source.next()).await {
354 Ok(Some(msg)) => msg,
355 Ok(None) => fail_with_event!(ConnectionClosedPrematurelySnafu),
356 Err(_) => fail_with_event!(InitialMessageTimeoutSnafu),
357 };
358
359 let message = response.map_err(|source| {
360 emit!(WebSocketReceiveError { error: &source });
361 WebSocketSourceError::Tungstenite { source }
362 })?;
363
364 match message {
365 Message::Text(txt) => {
366 self.process_message(&txt, WebSocketKind::Text, out).await;
367 Ok(())
368 }
369 Message::Binary(bin) => {
370 self.process_message(&bin, WebSocketKind::Binary, out).await;
371 Ok(())
372 }
373 Message::Close(frame) => self.handle_close_frame(frame),
374 _ => Ok(()),
375 }
376 }
377
378 fn handle_close_frame(
379 &self,
380 frame: Option<CloseFrame<'_>>,
381 ) -> Result<(), WebSocketSourceError> {
382 let (error_message, specific_error) = match frame {
383 Some(frame) => {
384 let msg = format!(
385 "Connection closed by server with code '{}' and reason: '{}'",
386 frame.code, frame.reason
387 );
388 let err = WebSocketSourceError::RemoteClosed {
389 frame: frame.into_owned(),
390 };
391 (msg, err)
392 }
393 None => (
394 "Connection closed by server without a close frame".to_string(),
395 WebSocketSourceError::RemoteClosedEmpty,
396 ),
397 };
398
399 let error = TungsteniteError::Io(std::io::Error::new(
400 std::io::ErrorKind::ConnectionAborted,
401 error_message,
402 ));
403 emit!(WebSocketReceiveError { error: &error });
404
405 Err(specific_error)
406 }
407
408 fn is_custom_pong(&self, msg_txt: &str) -> bool {
409 match self.config.pong_message.as_ref() {
410 Some(config) => config.matches(msg_txt),
411 None => false,
412 }
413 }
414}
415
/// Keep-alive state: paces outgoing pings and remembers whether the last one
/// has been answered.
struct PingManager {
    // Timer built from `common.ping_interval`; awaited before every ping.
    interval: PingInterval,
    // Set after a ping is sent, cleared by `record_pong`.
    waiting_for_pong: bool,
    // Message sent on each tick: the configured text ping or an empty Ping.
    message: Message,
}
421
422impl PingManager {
423 fn new(config: &WebSocketConfig) -> Self {
424 let ping_message = if let Some(ping_msg) = &config.ping_message {
425 Message::Text(ping_msg.clone())
426 } else {
427 Message::Ping(vec![])
428 };
429
430 Self {
431 interval: PingInterval::new(config.common.ping_interval.map(u64::from)),
432 waiting_for_pong: false,
433 message: ping_message,
434 }
435 }
436
437 const fn record_pong(&mut self) {
438 self.waiting_for_pong = false;
439 }
440
441 async fn tick(&mut self, ws_sink: &mut WebSocketSink) -> Result<(), WebSocketSourceError> {
442 self.interval.tick().await;
443
444 if self.waiting_for_pong {
445 return Err(WebSocketSourceError::PongTimeout);
446 }
447
448 ws_sink.send(self.message.clone()).await.map_err(|error| {
449 emit!(WebSocketSendError { error: &error });
450 WebSocketSourceError::Tungstenite { source: error }
451 })?;
452
453 self.waiting_for_pong = true;
454 Ok(())
455 }
456}
457
458#[cfg(test)]
459mod tests {
460 use std::{borrow::Cow, num::NonZeroU64};
461
462 use futures::{StreamExt, sink::SinkExt};
463 use tokio::{net::TcpListener, time::Duration};
464 use tokio_tungstenite::{
465 accept_async,
466 tungstenite::{
467 Message,
468 protocol::frame::{CloseFrame, coding::CloseCode},
469 },
470 };
471 use url::Url;
472 use vector_lib::codecs::decoding::DeserializerConfig;
473
474 use crate::{
475 common::websocket::WebSocketCommonConfig,
476 sources::websocket::config::{PongMessage, WebSocketConfig},
477 test_util::{
478 components::{
479 SOURCE_TAGS, run_and_assert_source_compliance, run_and_assert_source_error,
480 },
481 next_addr,
482 },
483 };
484
485 fn make_config(uri: &str) -> WebSocketConfig {
486 WebSocketConfig {
487 common: WebSocketCommonConfig {
488 uri: Url::parse(uri).unwrap().to_string(),
489 ..Default::default()
490 },
491 ..Default::default()
492 }
493 }
494
495 async fn start_binary_push_server() -> String {
497 let addr = next_addr();
498 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
499 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
500
501 tokio::spawn(async move {
502 let (stream, _) = listener.accept().await.unwrap();
503 let mut websocket = accept_async(stream).await.expect("Failed to accept");
504
505 let binary_payload = br#"{"message": "binary data"}"#.to_vec();
506 websocket
507 .send(Message::Binary(binary_payload))
508 .await
509 .unwrap();
510 });
511
512 server_addr
513 }
514
515 async fn start_push_server() -> String {
517 let addr = next_addr();
518 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
519 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
520
521 tokio::spawn(async move {
522 let (stream, _) = listener.accept().await.unwrap();
524 let mut websocket = accept_async(stream).await.expect("Failed to accept");
525
526 websocket
528 .send(Message::Text("message from server".to_string()))
529 .await
530 .unwrap();
531 });
532
533 server_addr
534 }
535
536 async fn start_subscribe_server(initial_message: String, response_message: String) -> String {
539 let addr = next_addr();
540 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
541 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
542
543 tokio::spawn(async move {
544 let (stream, _) = listener.accept().await.unwrap();
545 let mut websocket = accept_async(stream).await.expect("Failed to accept");
546
547 if let Some(Ok(Message::Text(msg))) = websocket.next().await
549 && msg == initial_message
550 {
551 websocket
553 .send(Message::Text(response_message))
554 .await
555 .unwrap();
556 }
557 });
558
559 server_addr
560 }
561
562 async fn start_reconnect_server() -> String {
563 let addr = next_addr();
564 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
565 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
566
567 tokio::spawn(async move {
568 let (stream, _) = listener.accept().await.unwrap();
570 let mut websocket = accept_async(stream).await.expect("Failed to accept");
571 websocket
572 .send(Message::Text("first message".to_string()))
573 .await
574 .unwrap();
575 websocket.close(None).await.unwrap();
577
578 let (stream, _) = listener.accept().await.unwrap();
580 let mut websocket = accept_async(stream).await.expect("Failed to accept");
581 websocket
582 .send(Message::Text("second message".to_string()))
583 .await
584 .unwrap();
585 });
586
587 server_addr
588 }
589
590 #[tokio::test(flavor = "multi_thread")]
591 async fn websocket_source_reconnects_after_disconnect() {
592 let server_addr = start_reconnect_server().await;
593 let config = make_config(&server_addr);
594
595 let events =
597 run_and_assert_source_compliance(config, Duration::from_secs(5), &SOURCE_TAGS).await;
598
599 assert_eq!(
600 events.len(),
601 2,
602 "Should have received messages from both connections"
603 );
604
605 let event = events[0].as_log();
606 assert_eq!(event["message"], "first message".into());
607
608 let event = events[1].as_log();
609 assert_eq!(event["message"], "second message".into());
610 }
611
612 #[tokio::test(flavor = "multi_thread")]
613 async fn websocket_source_consume_binary_event() {
614 let server_addr = start_binary_push_server().await;
615 let mut config = make_config(&server_addr);
616 let decoding = DeserializerConfig::Json(Default::default());
617 config.decoding = decoding;
618
619 let events =
620 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
621
622 assert!(!events.is_empty(), "No events received from source");
623 let event = events[0].as_log();
624 assert_eq!(event["message"], "binary data".into());
625 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
626 }
627
628 #[tokio::test(flavor = "multi_thread")]
629 async fn websocket_source_consume_event() {
630 let server_addr = start_push_server().await;
631 let config = make_config(&server_addr);
632
633 let events =
635 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
636
637 assert!(!events.is_empty(), "No events received from source");
639 let event = events[0].as_log();
640 assert_eq!(event["message"], "message from server".into());
641 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
642 }
643
644 #[tokio::test(flavor = "multi_thread")]
645 async fn websocket_source_sends_initial_message() {
646 let initial_msg = "{\"action\":\"subscribe\",\"topic\":\"test\"}".to_string();
647 let response_msg = "{\"status\":\"subscribed\",\"topic\":\"test\"}".to_string();
648 let server_addr = start_subscribe_server(initial_msg.clone(), response_msg.clone()).await;
649
650 let mut config = make_config(&server_addr);
651 config.initial_message = Some(initial_msg);
652
653 let events =
654 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
655
656 assert!(!events.is_empty(), "No events received from source");
657 let event = events[0].as_log();
658 assert_eq!(event["message"], response_msg.into());
659 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
660 }
661
662 async fn start_reject_initial_message_server() -> String {
663 let addr = next_addr();
664 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
665 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
666
667 tokio::spawn(async move {
668 let (stream, _) = listener.accept().await.unwrap();
669 let mut websocket = accept_async(stream).await.expect("Failed to accept");
670
671 if websocket.next().await.is_some() {
672 let close_frame = CloseFrame {
673 code: CloseCode::Error,
674 reason: Cow::from("Simulated Internal Server Error"),
675 };
676 let _ = websocket.close(Some(close_frame)).await;
677 }
678 });
679
680 server_addr
681 }
682
683 #[tokio::test(flavor = "multi_thread")]
684 async fn websocket_source_exits_on_rejected_intial_messsage() {
685 let server_addr = start_reject_initial_message_server().await;
686
687 let mut config = make_config(&server_addr);
688 config.initial_message = Some("hello, server!".to_string());
689 config.initial_message_timeout_secs = Duration::from_secs(1);
690
691 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
692 }
693
694 async fn start_unresponsive_server() -> String {
695 let addr = next_addr();
696 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
697 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
698
699 tokio::spawn(async move {
700 if let Ok((stream, _)) = listener.accept().await {
701 let mut websocket = accept_async(stream).await.expect("Failed to accept");
703 while websocket.next().await.is_some() {
705 }
707 }
708 });
709
710 server_addr
711 }
712
713 #[tokio::test(flavor = "multi_thread")]
714 async fn websocket_source_exits_on_pong_timeout() {
715 let server_addr = start_unresponsive_server().await;
716
717 let mut config = make_config(&server_addr);
718 config.common.ping_interval = NonZeroU64::new(3);
719 config.common.ping_timeout = NonZeroU64::new(1);
720 config.ping_message = Some("ping".to_string());
721 config.pong_message = Some(PongMessage::Simple("pong".to_string()));
722
723 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
725 }
726
727 async fn start_blackhole_server() -> String {
728 let addr = next_addr();
729 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
730 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
731
732 tokio::spawn(async move {
733 let (mut _socket, _) = listener.accept().await.unwrap();
734 tokio::time::sleep(Duration::from_secs(10)).await;
735 });
736
737 server_addr
738 }
739
740 #[tokio::test(flavor = "multi_thread")]
741 async fn websocket_source_exits_on_connection_timeout() {
742 let server_addr = start_blackhole_server().await;
743 let mut config = make_config(&server_addr);
744 config.connect_timeout_secs = Duration::from_secs(1);
745
746 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
747 }
748}