1use std::pin::Pin;
2
3use chrono::Utc;
4use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt};
5use snafu::Snafu;
6use tokio::time;
7use tokio_tungstenite::tungstenite::{
8 Message, error::Error as TungsteniteError, protocol::CloseFrame,
9};
10use tokio_util::codec::FramedRead;
11use vector_lib::{
12 EstimatedJsonEncodedSizeOf,
13 config::LogNamespace,
14 event::{Event, LogEvent},
15 internal_event::{CountByteSize, EventsReceived, InternalEventHandle as _},
16};
17
18use crate::{
19 SourceSender,
20 codecs::Decoder,
21 common::websocket::{PingInterval, WebSocketConnector, is_closed},
22 config::SourceContext,
23 internal_events::{
24 ConnectionOpen, OpenGauge, PROTOCOL, WebSocketBytesReceived, WebSocketConnectionError,
25 WebSocketConnectionEstablished, WebSocketConnectionFailedError,
26 WebSocketConnectionShutdown, WebSocketKind, WebSocketMessageReceived,
27 WebSocketReceiveError, WebSocketSendError,
28 },
29 sources::websocket::config::WebSocketConfig,
30 vector_lib::codecs::StreamDecodingError,
31};
32
/// Emits a `WebSocketConnectionFailedError` built from the given Snafu context
/// selector, then returns early from the enclosing function with that selector's
/// error (`selector.fail()`).
///
/// The selector expression is expanded twice: once to build the boxed error carried
/// by the internal event, and once to produce the returned `Err`.
macro_rules! fail_with_event {
    ($selector:expr_2021) => {{
        emit!(WebSocketConnectionFailedError {
            error: Box::new($selector.build()),
        });
        return $selector.fail();
    }};
}
41
42type WebSocketSink = Pin<Box<dyn Sink<Message, Error = TungsteniteError> + Send>>;
43type WebSocketStream = Pin<Box<dyn Stream<Item = Result<Message, TungsteniteError>> + Send>>;
44
45pub(crate) struct WebSocketSourceParams {
46 pub connector: WebSocketConnector,
47 pub decoder: Decoder,
48 pub log_namespace: LogNamespace,
49}
50
51pub(crate) struct WebSocketSource {
52 config: WebSocketConfig,
53 params: WebSocketSourceParams,
54}
55
56#[derive(Debug, Snafu)]
57pub enum WebSocketSourceError {
58 #[snafu(display("Connection attempt timed out"))]
59 ConnectTimeout,
60
61 #[snafu(display("Server did not respond to the initial message in time"))]
62 InitialMessageTimeout,
63
64 #[snafu(display(
65 "The connection was closed after sending the initial message, but before a response."
66 ))]
67 ConnectionClosedPrematurely,
68
69 #[snafu(display("Connection closed by server with code '{}' and reason: '{}'", frame.code, frame.reason))]
70 RemoteClosed { frame: CloseFrame<'static> },
71
72 #[snafu(display("Connection closed by server without a close frame"))]
73 RemoteClosedEmpty,
74
75 #[snafu(display("Connection timed out while waiting for a pong response"))]
76 PongTimeout,
77
78 #[snafu(display("A WebSocket error occurred: {}", source))]
79 Tungstenite { source: TungsteniteError },
80}
81
82impl WebSocketSource {
83 pub const fn new(config: WebSocketConfig, params: WebSocketSourceParams) -> Self {
84 Self { config, params }
85 }
86
87 pub async fn run(self, cx: SourceContext) -> Result<(), WebSocketSourceError> {
88 let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
89 let mut ping_manager = PingManager::new(&self.config);
90
91 let mut out = cx.out;
92
93 let (ws_sink, ws_source) = self.connect(&mut out).await?;
94
95 pin_mut!(ws_sink, ws_source);
96
97 loop {
98 let result = tokio::select! {
99 _ = cx.shutdown.clone() => {
100 info!("Received shutdown signal.");
101 break;
102 },
103
104 res = ping_manager.tick(&mut ws_sink) => res,
105
106 Some(msg_result) = ws_source.next() => {
107 match msg_result {
108 Ok(msg) => self.handle_message(msg, &mut ping_manager, &mut out).await,
109 Err(error) => {
110 emit!(WebSocketReceiveError { error: &error });
111 Err(WebSocketSourceError::Tungstenite { source: error })
112 }
113 }
114 }
115 };
116
117 if let Err(error) = result {
118 match error {
119 WebSocketSourceError::RemoteClosed { frame } => {
120 warn!(
121 message = "Connection closed by server.",
122 code = %frame.code,
123 reason = %frame.reason
124 );
125 emit!(WebSocketConnectionShutdown);
126 }
127 WebSocketSourceError::RemoteClosedEmpty => {
128 warn!("Connection closed by server without a close frame.");
129 emit!(WebSocketConnectionShutdown);
130 }
131 WebSocketSourceError::PongTimeout => {
132 error!("Disconnecting due to pong timeout.");
133 emit!(WebSocketReceiveError {
134 error: &TungsteniteError::Io(std::io::Error::new(
135 std::io::ErrorKind::TimedOut,
136 "Pong timeout"
137 ))
138 });
139 emit!(WebSocketConnectionShutdown);
140 return Err(error);
141 }
142 WebSocketSourceError::Tungstenite { source: ws_err } => {
143 if is_closed(&ws_err) {
144 emit!(WebSocketConnectionShutdown);
145 }
146 error!(message = "WebSocket connection error.", error = %ws_err);
147 }
148 WebSocketSourceError::ConnectTimeout
151 | WebSocketSourceError::InitialMessageTimeout
152 | WebSocketSourceError::ConnectionClosedPrematurely => {
153 unreachable!(
154 "Encountered a connection-time error during runtime: {:?}",
155 error
156 );
157 }
158 }
159 if self
160 .reconnect(&mut out, &mut ws_sink, &mut ws_source)
161 .await
162 .is_err()
163 {
164 break;
165 }
166 }
167 }
168 Ok(())
169 }
170
171 async fn handle_message(
172 &self,
173 msg: Message,
174 ping_manager: &mut PingManager,
175 out: &mut SourceSender,
176 ) -> Result<(), WebSocketSourceError> {
177 match msg {
178 Message::Pong(_) => {
179 ping_manager.record_pong();
180 Ok(())
181 }
182 Message::Text(msg_txt) => {
183 if self.is_custom_pong(&msg_txt) {
184 ping_manager.record_pong();
185 debug!("Received custom pong response.");
186 } else {
187 self.process_message(&msg_txt, WebSocketKind::Text, out)
188 .await;
189 }
190 Ok(())
191 }
192 Message::Binary(msg_bytes) => {
193 self.process_message(&msg_bytes, WebSocketKind::Binary, out)
194 .await;
195 Ok(())
196 }
197 Message::Ping(_) => Ok(()),
198 Message::Close(frame) => self.handle_close_frame(frame),
199 Message::Frame(_) => {
200 warn!("Unsupported message type received: frame.");
201 Ok(())
202 }
203 }
204 }
205
206 async fn process_message<T>(&self, payload: &T, kind: WebSocketKind, out: &mut SourceSender)
207 where
208 T: AsRef<[u8]> + ?Sized,
209 {
210 let payload_bytes = payload.as_ref();
211
212 emit!(WebSocketBytesReceived {
213 byte_size: payload_bytes.len(),
214 url: &self.config.common.uri,
215 protocol: PROTOCOL,
216 kind,
217 });
218 let mut stream = FramedRead::new(payload_bytes, self.params.decoder.clone());
219
220 while let Some(result) = stream.next().await {
221 match result {
222 Ok((events, _)) => {
223 if events.is_empty() {
224 continue;
225 }
226
227 let event_count = events.len();
228 let byte_size = events.estimated_json_encoded_size_of();
229
230 register!(EventsReceived).emit(CountByteSize(event_count, byte_size));
231 emit!(WebSocketMessageReceived {
232 count: event_count,
233 byte_size,
234 url: &self.config.common.uri,
235 protocol: PROTOCOL,
236 kind,
237 });
238
239 let events_with_meta = events.into_iter().map(|mut event| {
240 if let Event::Log(event) = &mut event {
241 self.add_metadata(event);
242 }
243 event
244 });
245
246 if let Err(error) = out.send_batch(events_with_meta).await {
247 error!(message = "Error sending events.", %error);
248 }
249 }
250 Err(error) => {
251 if !error.can_continue() {
252 break;
253 }
254 }
255 }
256 }
257 }
258
259 fn add_metadata(&self, event: &mut LogEvent) {
260 self.params
261 .log_namespace
262 .insert_standard_vector_source_metadata(event, WebSocketConfig::NAME, Utc::now());
263 }
264
265 async fn reconnect(
266 &self,
267 out: &mut SourceSender,
268 ws_sink: &mut WebSocketSink,
269 ws_source: &mut WebSocketStream,
270 ) -> Result<(), WebSocketSourceError> {
271 info!("Reconnecting to WebSocket...");
272
273 let (new_sink, new_source) = self.connect(out).await?;
274
275 *ws_sink = new_sink;
276 *ws_source = new_source;
277
278 info!("Reconnected to Websocket.");
279
280 Ok(())
281 }
282
283 async fn connect(
284 &self,
285 out: &mut SourceSender,
286 ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
287 let (mut ws_sink, mut ws_source) = self.try_create_sink_and_stream().await?;
288
289 if self.config.initial_message.is_some() {
290 self.send_initial_message(&mut ws_sink, &mut ws_source, out)
291 .await?;
292 }
293
294 Ok((ws_sink, ws_source))
295 }
296
297 async fn try_create_sink_and_stream(
298 &self,
299 ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
300 let connect_future = self.params.connector.connect_backoff();
301 let timeout = self.config.connect_timeout_secs;
302
303 let ws_stream = match time::timeout(timeout, connect_future).await {
304 Ok(ws) => ws,
305 Err(_) => {
306 emit!(WebSocketConnectionError {
307 error: TungsteniteError::Io(std::io::Error::new(
308 std::io::ErrorKind::TimedOut,
309 "Connection attempt timed out",
310 ))
311 });
312 return Err(WebSocketSourceError::ConnectTimeout);
313 }
314 };
315
316 emit!(WebSocketConnectionEstablished {});
317 let (sink, stream) = ws_stream.split();
318
319 Ok((Box::pin(sink), Box::pin(stream)))
320 }
321
322 async fn send_initial_message(
323 &self,
324 ws_sink: &mut WebSocketSink,
325 ws_source: &mut WebSocketStream,
326 out: &mut SourceSender,
327 ) -> Result<(), WebSocketSourceError> {
328 let initial_message = self.config.initial_message.as_ref().unwrap();
329 ws_sink
330 .send(Message::Text(initial_message.clone()))
331 .await
332 .map_err(|error| {
333 emit!(WebSocketSendError { error: &error });
334 WebSocketSourceError::Tungstenite { source: error }
335 })?;
336
337 debug!("Sent initial message, awaiting response from server.");
338
339 let response =
340 match time::timeout(self.config.initial_message_timeout_secs, ws_source.next()).await {
341 Ok(Some(msg)) => msg,
342 Ok(None) => fail_with_event!(ConnectionClosedPrematurelySnafu),
343 Err(_) => fail_with_event!(InitialMessageTimeoutSnafu),
344 };
345
346 let message = response.map_err(|source| {
347 emit!(WebSocketReceiveError { error: &source });
348 WebSocketSourceError::Tungstenite { source }
349 })?;
350
351 match message {
352 Message::Text(txt) => {
353 self.process_message(&txt, WebSocketKind::Text, out).await;
354 Ok(())
355 }
356 Message::Binary(bin) => {
357 self.process_message(&bin, WebSocketKind::Binary, out).await;
358 Ok(())
359 }
360 Message::Close(frame) => self.handle_close_frame(frame),
361 _ => Ok(()),
362 }
363 }
364
365 fn handle_close_frame(
366 &self,
367 frame: Option<CloseFrame<'_>>,
368 ) -> Result<(), WebSocketSourceError> {
369 let (error_message, specific_error) = match frame {
370 Some(frame) => {
371 let msg = format!(
372 "Connection closed by server with code '{}' and reason: '{}'",
373 frame.code, frame.reason
374 );
375 let err = WebSocketSourceError::RemoteClosed {
376 frame: frame.into_owned(),
377 };
378 (msg, err)
379 }
380 None => (
381 "Connection closed by server without a close frame".to_string(),
382 WebSocketSourceError::RemoteClosedEmpty,
383 ),
384 };
385
386 let error = TungsteniteError::Io(std::io::Error::new(
387 std::io::ErrorKind::ConnectionAborted,
388 error_message,
389 ));
390 emit!(WebSocketReceiveError { error: &error });
391
392 Err(specific_error)
393 }
394
395 fn is_custom_pong(&self, msg_txt: &str) -> bool {
396 match self.config.pong_message.as_ref() {
397 Some(config) => config.matches(msg_txt),
398 None => false,
399 }
400 }
401}
402
403struct PingManager {
404 interval: PingInterval,
405 waiting_for_pong: bool,
406 message: Message,
407}
408
409impl PingManager {
410 fn new(config: &WebSocketConfig) -> Self {
411 let ping_message = if let Some(ping_msg) = &config.ping_message {
412 Message::Text(ping_msg.clone())
413 } else {
414 Message::Ping(vec![])
415 };
416
417 Self {
418 interval: PingInterval::new(config.common.ping_interval.map(u64::from)),
419 waiting_for_pong: false,
420 message: ping_message,
421 }
422 }
423
424 const fn record_pong(&mut self) {
425 self.waiting_for_pong = false;
426 }
427
428 async fn tick(&mut self, ws_sink: &mut WebSocketSink) -> Result<(), WebSocketSourceError> {
429 self.interval.tick().await;
430
431 if self.waiting_for_pong {
432 return Err(WebSocketSourceError::PongTimeout);
433 }
434
435 ws_sink.send(self.message.clone()).await.map_err(|error| {
436 emit!(WebSocketSendError { error: &error });
437 WebSocketSourceError::Tungstenite { source: error }
438 })?;
439
440 self.waiting_for_pong = true;
441 Ok(())
442 }
443}
444
445#[cfg(test)]
446mod tests {
447 use std::{borrow::Cow, num::NonZeroU64};
448
449 use futures::{StreamExt, sink::SinkExt};
450 use tokio::{net::TcpListener, time::Duration};
451 use tokio_tungstenite::{
452 accept_async,
453 tungstenite::{
454 Message,
455 protocol::frame::{CloseFrame, coding::CloseCode},
456 },
457 };
458 use url::Url;
459 use vector_lib::codecs::decoding::DeserializerConfig;
460
461 use crate::{
462 common::websocket::WebSocketCommonConfig,
463 sources::websocket::config::{PongMessage, WebSocketConfig},
464 test_util::{
465 components::{
466 SOURCE_TAGS, run_and_assert_source_compliance, run_and_assert_source_error,
467 },
468 next_addr,
469 },
470 };
471
472 fn make_config(uri: &str) -> WebSocketConfig {
473 WebSocketConfig {
474 common: WebSocketCommonConfig {
475 uri: Url::parse(uri).unwrap().to_string(),
476 ..Default::default()
477 },
478 ..Default::default()
479 }
480 }
481
482 async fn start_binary_push_server() -> String {
484 let addr = next_addr();
485 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
486 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
487
488 tokio::spawn(async move {
489 let (stream, _) = listener.accept().await.unwrap();
490 let mut websocket = accept_async(stream).await.expect("Failed to accept");
491
492 let binary_payload = br#"{"message": "binary data"}"#.to_vec();
493 websocket
494 .send(Message::Binary(binary_payload))
495 .await
496 .unwrap();
497 });
498
499 server_addr
500 }
501
502 async fn start_push_server() -> String {
504 let addr = next_addr();
505 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
506 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
507
508 tokio::spawn(async move {
509 let (stream, _) = listener.accept().await.unwrap();
511 let mut websocket = accept_async(stream).await.expect("Failed to accept");
512
513 websocket
515 .send(Message::Text("message from server".to_string()))
516 .await
517 .unwrap();
518 });
519
520 server_addr
521 }
522
523 async fn start_subscribe_server(initial_message: String, response_message: String) -> String {
526 let addr = next_addr();
527 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
528 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
529
530 tokio::spawn(async move {
531 let (stream, _) = listener.accept().await.unwrap();
532 let mut websocket = accept_async(stream).await.expect("Failed to accept");
533
534 if let Some(Ok(Message::Text(msg))) = websocket.next().await
536 && msg == initial_message
537 {
538 websocket
540 .send(Message::Text(response_message))
541 .await
542 .unwrap();
543 }
544 });
545
546 server_addr
547 }
548
549 async fn start_reconnect_server() -> String {
550 let addr = next_addr();
551 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
552 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
553
554 tokio::spawn(async move {
555 let (stream, _) = listener.accept().await.unwrap();
557 let mut websocket = accept_async(stream).await.expect("Failed to accept");
558 websocket
559 .send(Message::Text("first message".to_string()))
560 .await
561 .unwrap();
562 websocket.close(None).await.unwrap();
564
565 let (stream, _) = listener.accept().await.unwrap();
567 let mut websocket = accept_async(stream).await.expect("Failed to accept");
568 websocket
569 .send(Message::Text("second message".to_string()))
570 .await
571 .unwrap();
572 });
573
574 server_addr
575 }
576
577 #[tokio::test(flavor = "multi_thread")]
578 async fn websocket_source_reconnects_after_disconnect() {
579 let server_addr = start_reconnect_server().await;
580 let config = make_config(&server_addr);
581
582 let events =
584 run_and_assert_source_compliance(config, Duration::from_secs(5), &SOURCE_TAGS).await;
585
586 assert_eq!(
587 events.len(),
588 2,
589 "Should have received messages from both connections"
590 );
591
592 let event = events[0].as_log();
593 assert_eq!(event["message"], "first message".into());
594
595 let event = events[1].as_log();
596 assert_eq!(event["message"], "second message".into());
597 }
598
599 #[tokio::test(flavor = "multi_thread")]
600 async fn websocket_source_consume_binary_event() {
601 let server_addr = start_binary_push_server().await;
602 let mut config = make_config(&server_addr);
603 let decoding = DeserializerConfig::Json(Default::default());
604 config.decoding = decoding;
605
606 let events =
607 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
608
609 assert!(!events.is_empty(), "No events received from source");
610 let event = events[0].as_log();
611 assert_eq!(event["message"], "binary data".into());
612 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
613 }
614
615 #[tokio::test(flavor = "multi_thread")]
616 async fn websocket_source_consume_event() {
617 let server_addr = start_push_server().await;
618 let config = make_config(&server_addr);
619
620 let events =
622 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
623
624 assert!(!events.is_empty(), "No events received from source");
626 let event = events[0].as_log();
627 assert_eq!(event["message"], "message from server".into());
628 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
629 }
630
631 #[tokio::test(flavor = "multi_thread")]
632 async fn websocket_source_sends_initial_message() {
633 let initial_msg = "{\"action\":\"subscribe\",\"topic\":\"test\"}".to_string();
634 let response_msg = "{\"status\":\"subscribed\",\"topic\":\"test\"}".to_string();
635 let server_addr = start_subscribe_server(initial_msg.clone(), response_msg.clone()).await;
636
637 let mut config = make_config(&server_addr);
638 config.initial_message = Some(initial_msg);
639
640 let events =
641 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
642
643 assert!(!events.is_empty(), "No events received from source");
644 let event = events[0].as_log();
645 assert_eq!(event["message"], response_msg.into());
646 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
647 }
648
649 async fn start_reject_initial_message_server() -> String {
650 let addr = next_addr();
651 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
652 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
653
654 tokio::spawn(async move {
655 let (stream, _) = listener.accept().await.unwrap();
656 let mut websocket = accept_async(stream).await.expect("Failed to accept");
657
658 if websocket.next().await.is_some() {
659 let close_frame = CloseFrame {
660 code: CloseCode::Error,
661 reason: Cow::from("Simulated Internal Server Error"),
662 };
663 let _ = websocket.close(Some(close_frame)).await;
664 }
665 });
666
667 server_addr
668 }
669
670 #[tokio::test(flavor = "multi_thread")]
671 async fn websocket_source_exits_on_rejected_intial_messsage() {
672 let server_addr = start_reject_initial_message_server().await;
673
674 let mut config = make_config(&server_addr);
675 config.initial_message = Some("hello, server!".to_string());
676 config.initial_message_timeout_secs = Duration::from_secs(1);
677
678 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
679 }
680
681 async fn start_unresponsive_server() -> String {
682 let addr = next_addr();
683 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
684 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
685
686 tokio::spawn(async move {
687 if let Ok((stream, _)) = listener.accept().await {
688 let mut websocket = accept_async(stream).await.expect("Failed to accept");
690 while websocket.next().await.is_some() {
692 }
694 }
695 });
696
697 server_addr
698 }
699
700 #[tokio::test(flavor = "multi_thread")]
701 async fn websocket_source_exits_on_pong_timeout() {
702 let server_addr = start_unresponsive_server().await;
703
704 let mut config = make_config(&server_addr);
705 config.common.ping_interval = NonZeroU64::new(3);
706 config.common.ping_timeout = NonZeroU64::new(1);
707 config.ping_message = Some("ping".to_string());
708 config.pong_message = Some(PongMessage::Simple("pong".to_string()));
709
710 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
712 }
713
714 async fn start_blackhole_server() -> String {
715 let addr = next_addr();
716 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
717 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
718
719 tokio::spawn(async move {
720 let (mut _socket, _) = listener.accept().await.unwrap();
721 tokio::time::sleep(Duration::from_secs(10)).await;
722 });
723
724 server_addr
725 }
726
727 #[tokio::test(flavor = "multi_thread")]
728 async fn websocket_source_exits_on_connection_timeout() {
729 let server_addr = start_blackhole_server().await;
730 let mut config = make_config(&server_addr);
731 config.connect_timeout_secs = Duration::from_secs(1);
732
733 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
734 }
735}