1use std::pin::Pin;
2
3use chrono::Utc;
4use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt};
5use snafu::Snafu;
6use tokio::time;
7use tokio_tungstenite::tungstenite::{
8 Message, error::Error as TungsteniteError, protocol::CloseFrame,
9};
10use tokio_util::codec::FramedRead;
11use vector_lib::{
12 EstimatedJsonEncodedSizeOf,
13 config::LogNamespace,
14 event::{Event, LogEvent},
15 internal_event::{CountByteSize, EventsReceived, InternalEventHandle as _},
16};
17
18use crate::{
19 SourceSender,
20 codecs::Decoder,
21 common::websocket::{PingInterval, WebSocketConnector, is_closed},
22 config::SourceContext,
23 internal_events::{
24 ConnectionOpen, OpenGauge, PROTOCOL, WebSocketBytesReceived,
25 WebSocketConnectionFailedError, WebSocketConnectionShutdown, WebSocketKind,
26 WebSocketMessageReceived, WebSocketReceiveError, WebSocketSendError,
27 },
28 sources::websocket::config::WebSocketConfig,
29 vector_lib::codecs::StreamDecodingError,
30};
31
/// Emits a `WebSocketConnectionFailedError` event built from the given snafu
/// context selector, then returns early from the enclosing function with the
/// error produced by that same selector.
///
/// `$context` is expanded twice (once for `.build()`, once for `.fail()`), so
/// pass a plain selector rather than an expression with side effects.
macro_rules! fail_with_event {
    ($context:expr_2021) => {{
        // Report the failure through the internal event before bailing out.
        emit!(WebSocketConnectionFailedError {
            error: Box::new($context.build())
        });
        return $context.fail();
    }};
}
40
/// Boxed, pinned write half of a WebSocket connection: a `Sink` accepting
/// `Message`s whose error type is `TungsteniteError`.
type WebSocketSink = Pin<Box<dyn Sink<Message, Error = TungsteniteError> + Send>>;
/// Boxed, pinned read half of a WebSocket connection: a `Stream` yielding
/// `Result<Message, TungsteniteError>` items.
type WebSocketStream = Pin<Box<dyn Stream<Item = Result<Message, TungsteniteError>> + Send>>;
43
/// Runtime dependencies handed to [`WebSocketSource`] alongside its configuration.
pub(crate) struct WebSocketSourceParams {
    /// Used to open (and re-open) the WebSocket connection.
    pub connector: WebSocketConnector,
    /// Cloned for every received payload to decode it into events.
    pub decoder: Decoder,
    /// Namespace used when inserting the standard source metadata into log events.
    pub log_namespace: LogNamespace,
}
49
/// WebSocket client source: connects to a server, decodes incoming text and
/// binary messages into events, and forwards them downstream.
pub(crate) struct WebSocketSource {
    /// User-facing configuration (URI, timeouts, ping/pong and initial messages).
    config: WebSocketConfig,
    /// Connector, decoder, and log namespace used at runtime.
    params: WebSocketSourceParams,
}
54
/// Errors produced while connecting to, or reading from, the WebSocket server.
#[derive(Debug, Snafu)]
pub enum WebSocketSourceError {
    /// The connection could not be established in time.
    #[snafu(display("Connection attempt timed out"))]
    ConnectTimeout,

    /// No message arrived within the configured timeout after the initial
    /// message was sent.
    #[snafu(display("Server did not respond to the initial message in time"))]
    InitialMessageTimeout,

    /// The message stream ended while waiting for the response to the initial
    /// message.
    #[snafu(display(
        "The connection was closed after sending the initial message, but before a response."
    ))]
    ConnectionClosedPrematurely,

    /// The server sent a close message carrying a close frame.
    #[snafu(display("Connection closed by server with code '{}' and reason: '{}'", frame.code, frame.reason))]
    RemoteClosed { frame: CloseFrame<'static> },

    /// The server sent a close message without a close frame.
    #[snafu(display("Connection closed by server without a close frame"))]
    RemoteClosedEmpty,

    /// A ping tick fired while the previous ping was still unanswered.
    #[snafu(display("Connection timed out while waiting for a pong response"))]
    PongTimeout,

    /// An error reported by the underlying WebSocket library.
    #[snafu(display("A WebSocket error occurred: {}", source))]
    Tungstenite { source: TungsteniteError },
}
80
81impl WebSocketSource {
82 pub const fn new(config: WebSocketConfig, params: WebSocketSourceParams) -> Self {
83 Self { config, params }
84 }
85
86 pub async fn run(self, cx: SourceContext) -> Result<(), WebSocketSourceError> {
87 let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
88 let mut ping_manager = PingManager::new(&self.config);
89
90 let mut out = cx.out;
91
92 let (ws_sink, ws_source) = self.connect(&mut out).await?;
93
94 pin_mut!(ws_sink, ws_source);
95
96 loop {
97 let result = tokio::select! {
98 _ = cx.shutdown.clone() => {
99 info!("Received shutdown signal.");
100 break;
101 },
102
103 res = ping_manager.tick(&mut ws_sink) => res,
104
105 Some(msg_result) = ws_source.next() => {
106 match msg_result {
107 Ok(msg) => self.handle_message(msg, &mut ping_manager, &mut out).await,
108 Err(error) => {
109 emit!(WebSocketReceiveError { error: &error });
110 Err(WebSocketSourceError::Tungstenite { source: error })
111 }
112 }
113 }
114 };
115
116 if let Err(error) = result {
117 match error {
118 WebSocketSourceError::RemoteClosed { frame } => {
119 warn!(
120 message = "Connection closed by server.",
121 code = %frame.code,
122 reason = %frame.reason
123 );
124 emit!(WebSocketConnectionShutdown);
125 }
126 WebSocketSourceError::RemoteClosedEmpty => {
127 warn!("Connection closed by server without a close frame.");
128 emit!(WebSocketConnectionShutdown);
129 }
130 WebSocketSourceError::PongTimeout => {
131 error!("Disconnecting due to pong timeout.");
132 emit!(WebSocketReceiveError {
133 error: &TungsteniteError::Io(std::io::Error::new(
134 std::io::ErrorKind::TimedOut,
135 "Pong timeout"
136 ))
137 });
138 emit!(WebSocketConnectionShutdown);
139 return Err(error);
140 }
141 WebSocketSourceError::Tungstenite { source: ws_err } => {
142 if is_closed(&ws_err) {
143 emit!(WebSocketConnectionShutdown);
144 }
145 error!(message = "WebSocket connection error.", error = %ws_err);
146 }
147 WebSocketSourceError::ConnectTimeout
150 | WebSocketSourceError::InitialMessageTimeout
151 | WebSocketSourceError::ConnectionClosedPrematurely => {
152 unreachable!(
153 "Encountered a connection-time error during runtime: {:?}",
154 error
155 );
156 }
157 }
158 if self
159 .reconnect(&mut out, &mut ws_sink, &mut ws_source)
160 .await
161 .is_err()
162 {
163 break;
164 }
165 }
166 }
167 Ok(())
168 }
169
170 async fn handle_message(
171 &self,
172 msg: Message,
173 ping_manager: &mut PingManager,
174 out: &mut SourceSender,
175 ) -> Result<(), WebSocketSourceError> {
176 match msg {
177 Message::Pong(_) => {
178 ping_manager.record_pong();
179 Ok(())
180 }
181 Message::Text(msg_txt) => {
182 if self.is_custom_pong(&msg_txt) {
183 ping_manager.record_pong();
184 debug!("Received custom pong response.");
185 } else {
186 self.process_message(&msg_txt, WebSocketKind::Text, out)
187 .await;
188 }
189 Ok(())
190 }
191 Message::Binary(msg_bytes) => {
192 self.process_message(&msg_bytes, WebSocketKind::Binary, out)
193 .await;
194 Ok(())
195 }
196 Message::Ping(_) => Ok(()),
197 Message::Close(frame) => self.handle_close_frame(frame),
198 Message::Frame(_) => {
199 warn!("Unsupported message type received: frame.");
200 Ok(())
201 }
202 }
203 }
204
205 async fn process_message<T>(&self, payload: &T, kind: WebSocketKind, out: &mut SourceSender)
206 where
207 T: AsRef<[u8]> + ?Sized,
208 {
209 let payload_bytes = payload.as_ref();
210
211 emit!(WebSocketBytesReceived {
212 byte_size: payload_bytes.len(),
213 url: &self.config.common.uri,
214 protocol: PROTOCOL,
215 kind,
216 });
217 let mut stream = FramedRead::new(payload_bytes, self.params.decoder.clone());
218
219 while let Some(result) = stream.next().await {
220 match result {
221 Ok((events, _)) => {
222 if events.is_empty() {
223 continue;
224 }
225
226 let event_count = events.len();
227 let byte_size = events.estimated_json_encoded_size_of();
228
229 register!(EventsReceived).emit(CountByteSize(event_count, byte_size));
230 emit!(WebSocketMessageReceived {
231 count: event_count,
232 byte_size,
233 url: &self.config.common.uri,
234 protocol: PROTOCOL,
235 kind,
236 });
237
238 let events_with_meta = events.into_iter().map(|mut event| {
239 if let Event::Log(event) = &mut event {
240 self.add_metadata(event);
241 }
242 event
243 });
244
245 if let Err(error) = out.send_batch(events_with_meta).await {
246 error!(message = "Error sending events.", %error);
247 }
248 }
249 Err(error) => {
250 if !error.can_continue() {
251 break;
252 }
253 }
254 }
255 }
256 }
257
258 fn add_metadata(&self, event: &mut LogEvent) {
259 self.params
260 .log_namespace
261 .insert_standard_vector_source_metadata(event, WebSocketConfig::NAME, Utc::now());
262 }
263
264 async fn reconnect(
265 &self,
266 out: &mut SourceSender,
267 ws_sink: &mut WebSocketSink,
268 ws_source: &mut WebSocketStream,
269 ) -> Result<(), WebSocketSourceError> {
270 info!("Reconnecting to WebSocket...");
271
272 let (new_sink, new_source) = self.connect(out).await?;
273
274 *ws_sink = new_sink;
275 *ws_source = new_source;
276
277 info!("Reconnected to Websocket.");
278
279 Ok(())
280 }
281
282 async fn connect(
283 &self,
284 out: &mut SourceSender,
285 ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
286 let (mut ws_sink, mut ws_source) = self.try_create_sink_and_stream().await?;
287
288 if self.config.initial_message.is_some() {
289 self.send_initial_message(&mut ws_sink, &mut ws_source, out)
290 .await?;
291 }
292
293 Ok((ws_sink, ws_source))
294 }
295
296 async fn try_create_sink_and_stream(
297 &self,
298 ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
299 let ws_stream = self
300 .params
301 .connector
302 .connect_backoff_with_timeout(self.config.connect_timeout_secs)
303 .await;
304
305 let (sink, stream) = ws_stream.split();
306
307 Ok((Box::pin(sink), Box::pin(stream)))
308 }
309
310 async fn send_initial_message(
311 &self,
312 ws_sink: &mut WebSocketSink,
313 ws_source: &mut WebSocketStream,
314 out: &mut SourceSender,
315 ) -> Result<(), WebSocketSourceError> {
316 let initial_message = self.config.initial_message.as_ref().unwrap();
317 ws_sink
318 .send(Message::Text(initial_message.clone()))
319 .await
320 .map_err(|error| {
321 emit!(WebSocketSendError { error: &error });
322 WebSocketSourceError::Tungstenite { source: error }
323 })?;
324
325 debug!("Sent initial message, awaiting response from server.");
326
327 let response =
328 match time::timeout(self.config.initial_message_timeout_secs, ws_source.next()).await {
329 Ok(Some(msg)) => msg,
330 Ok(None) => fail_with_event!(ConnectionClosedPrematurelySnafu),
331 Err(_) => fail_with_event!(InitialMessageTimeoutSnafu),
332 };
333
334 let message = response.map_err(|source| {
335 emit!(WebSocketReceiveError { error: &source });
336 WebSocketSourceError::Tungstenite { source }
337 })?;
338
339 match message {
340 Message::Text(txt) => {
341 self.process_message(&txt, WebSocketKind::Text, out).await;
342 Ok(())
343 }
344 Message::Binary(bin) => {
345 self.process_message(&bin, WebSocketKind::Binary, out).await;
346 Ok(())
347 }
348 Message::Close(frame) => self.handle_close_frame(frame),
349 _ => Ok(()),
350 }
351 }
352
353 fn handle_close_frame(
354 &self,
355 frame: Option<CloseFrame<'_>>,
356 ) -> Result<(), WebSocketSourceError> {
357 let (error_message, specific_error) = match frame {
358 Some(frame) => {
359 let msg = format!(
360 "Connection closed by server with code '{}' and reason: '{}'",
361 frame.code, frame.reason
362 );
363 let err = WebSocketSourceError::RemoteClosed {
364 frame: frame.into_owned(),
365 };
366 (msg, err)
367 }
368 None => (
369 "Connection closed by server without a close frame".to_string(),
370 WebSocketSourceError::RemoteClosedEmpty,
371 ),
372 };
373
374 let error = TungsteniteError::Io(std::io::Error::new(
375 std::io::ErrorKind::ConnectionAborted,
376 error_message,
377 ));
378 emit!(WebSocketReceiveError { error: &error });
379
380 Err(specific_error)
381 }
382
383 fn is_custom_pong(&self, msg_txt: &str) -> bool {
384 match self.config.pong_message.as_ref() {
385 Some(config) => config.matches(msg_txt),
386 None => false,
387 }
388 }
389}
390
/// Tracks the ping schedule and whether a sent ping is still awaiting its pong.
struct PingManager {
    /// Awaited at the start of every `tick`.
    interval: PingInterval,
    /// Set after a ping is sent; cleared by `record_pong`.
    waiting_for_pong: bool,
    /// The message sent on each tick: the configured text ping, or an empty
    /// protocol-level ping when none is configured.
    message: Message,
}
396
397impl PingManager {
398 fn new(config: &WebSocketConfig) -> Self {
399 let ping_message = if let Some(ping_msg) = &config.ping_message {
400 Message::Text(ping_msg.clone())
401 } else {
402 Message::Ping(vec![])
403 };
404
405 Self {
406 interval: PingInterval::new(config.common.ping_interval.map(u64::from)),
407 waiting_for_pong: false,
408 message: ping_message,
409 }
410 }
411
412 const fn record_pong(&mut self) {
413 self.waiting_for_pong = false;
414 }
415
416 async fn tick(&mut self, ws_sink: &mut WebSocketSink) -> Result<(), WebSocketSourceError> {
417 self.interval.tick().await;
418
419 if self.waiting_for_pong {
420 return Err(WebSocketSourceError::PongTimeout);
421 }
422
423 ws_sink.send(self.message.clone()).await.map_err(|error| {
424 emit!(WebSocketSendError { error: &error });
425 WebSocketSourceError::Tungstenite { source: error }
426 })?;
427
428 self.waiting_for_pong = true;
429 Ok(())
430 }
431}
432
433#[cfg(test)]
434mod tests {
435 use std::{borrow::Cow, num::NonZeroU64};
436
437 use futures::{StreamExt, sink::SinkExt};
438 use tokio::{net::TcpListener, time::Duration};
439 use tokio_tungstenite::{
440 accept_async,
441 tungstenite::{
442 Message,
443 protocol::frame::{CloseFrame, coding::CloseCode},
444 },
445 };
446 use url::Url;
447 use vector_lib::codecs::decoding::DeserializerConfig;
448
449 use crate::{
450 common::websocket::WebSocketCommonConfig,
451 sources::websocket::config::{PongMessage, WebSocketConfig},
452 test_util::{
453 addr::next_addr,
454 components::{
455 SOURCE_TAGS, run_and_assert_source_compliance, run_and_assert_source_error,
456 },
457 },
458 };
459
460 fn make_config(uri: &str) -> WebSocketConfig {
461 WebSocketConfig {
462 common: WebSocketCommonConfig {
463 uri: Url::parse(uri).unwrap().to_string(),
464 ..Default::default()
465 },
466 ..Default::default()
467 }
468 }
469
470 async fn start_binary_push_server() -> String {
472 let (_guard, addr) = next_addr();
473 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
474 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
475
476 tokio::spawn(async move {
477 let (stream, _) = listener.accept().await.unwrap();
478 let mut websocket = accept_async(stream).await.expect("Failed to accept");
479
480 let binary_payload = br#"{"message": "binary data"}"#.to_vec();
481 websocket
482 .send(Message::Binary(binary_payload))
483 .await
484 .unwrap();
485 });
486
487 server_addr
488 }
489
490 async fn start_push_server() -> String {
492 let (_guard, addr) = next_addr();
493 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
494 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
495
496 tokio::spawn(async move {
497 let (stream, _) = listener.accept().await.unwrap();
499 let mut websocket = accept_async(stream).await.expect("Failed to accept");
500
501 websocket
503 .send(Message::Text("message from server".to_string()))
504 .await
505 .unwrap();
506 });
507
508 server_addr
509 }
510
511 async fn start_subscribe_server(initial_message: String, response_message: String) -> String {
514 let (_guard, addr) = next_addr();
515 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
516 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
517
518 tokio::spawn(async move {
519 let (stream, _) = listener.accept().await.unwrap();
520 let mut websocket = accept_async(stream).await.expect("Failed to accept");
521
522 if let Some(Ok(Message::Text(msg))) = websocket.next().await
524 && msg == initial_message
525 {
526 websocket
528 .send(Message::Text(response_message))
529 .await
530 .unwrap();
531 }
532 });
533
534 server_addr
535 }
536
537 async fn start_reconnect_server() -> String {
538 let (_guard, addr) = next_addr();
539 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
540 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
541
542 tokio::spawn(async move {
543 let (stream, _) = listener.accept().await.unwrap();
545 let mut websocket = accept_async(stream).await.expect("Failed to accept");
546 websocket
547 .send(Message::Text("first message".to_string()))
548 .await
549 .unwrap();
550 websocket.close(None).await.unwrap();
552
553 let (stream, _) = listener.accept().await.unwrap();
555 let mut websocket = accept_async(stream).await.expect("Failed to accept");
556 websocket
557 .send(Message::Text("second message".to_string()))
558 .await
559 .unwrap();
560 });
561
562 server_addr
563 }
564
565 #[tokio::test(flavor = "multi_thread")]
566 async fn websocket_source_reconnects_after_disconnect() {
567 let server_addr = start_reconnect_server().await;
568 let config = make_config(&server_addr);
569
570 let events =
572 run_and_assert_source_compliance(config, Duration::from_secs(5), &SOURCE_TAGS).await;
573
574 assert_eq!(
575 events.len(),
576 2,
577 "Should have received messages from both connections"
578 );
579
580 let event = events[0].as_log();
581 assert_eq!(event["message"], "first message".into());
582
583 let event = events[1].as_log();
584 assert_eq!(event["message"], "second message".into());
585 }
586
587 #[tokio::test(flavor = "multi_thread")]
588 async fn websocket_source_consume_binary_event() {
589 let server_addr = start_binary_push_server().await;
590 let mut config = make_config(&server_addr);
591 let decoding = DeserializerConfig::Json(Default::default());
592 config.decoding = decoding;
593
594 let events =
595 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
596
597 assert!(!events.is_empty(), "No events received from source");
598 let event = events[0].as_log();
599 assert_eq!(event["message"], "binary data".into());
600 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
601 }
602
603 #[tokio::test(flavor = "multi_thread")]
604 async fn websocket_source_consume_event() {
605 let server_addr = start_push_server().await;
606 let config = make_config(&server_addr);
607
608 let events =
610 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
611
612 assert!(!events.is_empty(), "No events received from source");
614 let event = events[0].as_log();
615 assert_eq!(event["message"], "message from server".into());
616 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
617 }
618
619 #[tokio::test(flavor = "multi_thread")]
620 async fn websocket_source_sends_initial_message() {
621 let initial_msg = "{\"action\":\"subscribe\",\"topic\":\"test\"}".to_string();
622 let response_msg = "{\"status\":\"subscribed\",\"topic\":\"test\"}".to_string();
623 let server_addr = start_subscribe_server(initial_msg.clone(), response_msg.clone()).await;
624
625 let mut config = make_config(&server_addr);
626 config.initial_message = Some(initial_msg);
627
628 let events =
629 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
630
631 assert!(!events.is_empty(), "No events received from source");
632 let event = events[0].as_log();
633 assert_eq!(event["message"], response_msg.into());
634 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
635 }
636
637 async fn start_reject_initial_message_server() -> String {
638 let (_guard, addr) = next_addr();
639 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
640 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
641
642 tokio::spawn(async move {
643 let (stream, _) = listener.accept().await.unwrap();
644 let mut websocket = accept_async(stream).await.expect("Failed to accept");
645
646 if websocket.next().await.is_some() {
647 let close_frame = CloseFrame {
648 code: CloseCode::Error,
649 reason: Cow::from("Simulated Internal Server Error"),
650 };
651 let _ = websocket.close(Some(close_frame)).await;
652 }
653 });
654
655 server_addr
656 }
657
658 #[tokio::test(flavor = "multi_thread")]
659 async fn websocket_source_exits_on_rejected_intial_messsage() {
660 let server_addr = start_reject_initial_message_server().await;
661
662 let mut config = make_config(&server_addr);
663 config.initial_message = Some("hello, server!".to_string());
664 config.initial_message_timeout_secs = Duration::from_secs(1);
665
666 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
667 }
668
669 async fn start_unresponsive_server() -> String {
670 let (_guard, addr) = next_addr();
671 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
672 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
673
674 tokio::spawn(async move {
675 if let Ok((stream, _)) = listener.accept().await {
676 let mut websocket = accept_async(stream).await.expect("Failed to accept");
678 while websocket.next().await.is_some() {
680 }
682 }
683 });
684
685 server_addr
686 }
687
688 #[tokio::test(flavor = "multi_thread")]
689 async fn websocket_source_exits_on_pong_timeout() {
690 let server_addr = start_unresponsive_server().await;
691
692 let mut config = make_config(&server_addr);
693 config.common.ping_interval = NonZeroU64::new(3);
694 config.common.ping_timeout = NonZeroU64::new(1);
695 config.ping_message = Some("ping".to_string());
696 config.pong_message = Some(PongMessage::Simple("pong".to_string()));
697
698 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
700 }
701
702 async fn start_blackhole_server() -> String {
703 let (_guard, addr) = next_addr();
704 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
705 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
706
707 tokio::spawn(async move {
708 let (mut _socket, _) = listener.accept().await.unwrap();
709 tokio::time::sleep(Duration::from_secs(10)).await;
710 });
711
712 server_addr
713 }
714
715 #[tokio::test(flavor = "multi_thread")]
716 async fn websocket_source_exits_on_connection_timeout() {
717 let server_addr = start_blackhole_server().await;
718 let mut config = make_config(&server_addr);
719 config.connect_timeout_secs = Duration::from_secs(1);
720
721 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
722 }
723}