1use crate::vector_lib::codecs::StreamDecodingError;
2use crate::{
3 codecs::Decoder,
4 common::websocket::{is_closed, PingInterval, WebSocketConnector},
5 config::SourceContext,
6 internal_events::{
7 ConnectionOpen, OpenGauge, WebSocketBytesReceived, WebSocketConnectionError,
8 WebSocketConnectionEstablished, WebSocketConnectionFailedError,
9 WebSocketConnectionShutdown, WebSocketKind, WebSocketMessageReceived,
10 WebSocketReceiveError, WebSocketSendError, PROTOCOL,
11 },
12 sources::websocket::config::WebSocketConfig,
13 SourceSender,
14};
15use chrono::Utc;
16use futures::{pin_mut, sink::SinkExt, Sink, Stream, StreamExt};
17use snafu::Snafu;
18use std::pin::Pin;
19use tokio::time;
20use tokio_tungstenite::tungstenite::protocol::CloseFrame;
21use tokio_tungstenite::tungstenite::{error::Error as TungsteniteError, Message};
22use tokio_util::codec::FramedRead;
23use vector_lib::internal_event::{CountByteSize, EventsReceived, InternalEventHandle as _};
24use vector_lib::{
25 config::LogNamespace,
26 event::{Event, LogEvent},
27 EstimatedJsonEncodedSizeOf,
28};
29
/// Emits a `WebSocketConnectionFailedError` internal event carrying the error
/// built from the given snafu context selector, then returns early from the
/// enclosing function with the `Err` produced by that selector's `fail()`.
///
/// `$context` is expanded twice (once for `build()`, once for `fail()`), so it
/// should be a side-effect-free expression such as a unit context selector.
macro_rules! fail_with_event {
    ($context:expr) => {{
        emit!(WebSocketConnectionFailedError {
            error: Box::new($context.build())
        });
        // `return` exits the function in which the macro is invoked.
        return $context.fail();
    }};
}
38
/// Boxed, pinned, type-erased sink used to send WebSocket messages.
type WebSocketSink = Pin<Box<dyn Sink<Message, Error = TungsteniteError> + Send>>;
/// Boxed, pinned, type-erased stream yielding received WebSocket messages or
/// the errors encountered while receiving them.
type WebSocketStream = Pin<Box<dyn Stream<Item = Result<Message, TungsteniteError>> + Send>>;
41
/// Runtime collaborators handed to [`WebSocketSource`] alongside its config.
pub(crate) struct WebSocketSourceParams {
    /// Connector used to establish (and re-establish) the WebSocket connection.
    pub connector: WebSocketConnector,
    /// Decoder cloned for every received payload to turn bytes into events.
    pub decoder: Decoder,
    /// Log namespace used when attaching standard source metadata to log events.
    pub log_namespace: LogNamespace,
}
47
/// A WebSocket client source: connects to a server, optionally sends an
/// initial message, and forwards decoded incoming messages downstream.
pub(crate) struct WebSocketSource {
    /// User-facing configuration (URI, timeouts, ping/pong and initial messages).
    config: WebSocketConfig,
    /// Connector, decoder, and log namespace used while running.
    params: WebSocketSourceParams,
}
52
/// Errors produced while connecting to, or running against, a WebSocket server.
#[derive(Debug, Snafu)]
pub enum WebSocketSourceError {
    /// The connection attempt did not finish within the configured connect timeout.
    #[snafu(display("Connection attempt timed out"))]
    ConnectTimeout,

    /// No message arrived within the initial-message timeout after the
    /// initial message was sent.
    #[snafu(display("Server did not respond to the initial message in time"))]
    InitialMessageTimeout,

    /// The incoming stream ended after the initial message was sent but
    /// before any response was received.
    #[snafu(display(
        "The connection was closed after sending the initial message, but before a response."
    ))]
    ConnectionClosedPrematurely,

    /// The server sent a close message that included a close frame.
    #[snafu(display("Connection closed by server with code '{}' and reason: '{}'", frame.code, frame.reason))]
    RemoteClosed { frame: CloseFrame<'static> },

    /// The server sent a close message that carried no close frame.
    #[snafu(display("Connection closed by server without a close frame"))]
    RemoteClosedEmpty,

    /// A ping tick fired while the previous ping was still unanswered.
    #[snafu(display("Connection timed out while waiting for a pong response"))]
    PongTimeout,

    /// An error reported by the underlying WebSocket library while sending
    /// or receiving.
    #[snafu(display("A WebSocket error occurred: {}", source))]
    Tungstenite { source: TungsteniteError },
}
78
79impl WebSocketSource {
    /// Creates a source from its configuration and runtime parameters.
    /// No connection is attempted here; that happens in `run`.
    pub const fn new(config: WebSocketConfig, params: WebSocketSourceParams) -> Self {
        Self { config, params }
    }
83
84 pub async fn run(self, cx: SourceContext) -> Result<(), WebSocketSourceError> {
85 let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
86 let mut ping_manager = PingManager::new(&self.config);
87
88 let mut out = cx.out;
89
90 let (ws_sink, ws_source) = self.connect(&mut out).await?;
91
92 pin_mut!(ws_sink, ws_source);
93
94 loop {
95 let result = tokio::select! {
96 _ = cx.shutdown.clone() => {
97 info!(internal_log_rate_limit = true, "Received shutdown signal.");
98 break;
99 },
100
101 res = ping_manager.tick(&mut ws_sink) => res,
102
103 Some(msg_result) = ws_source.next() => {
104 match msg_result {
105 Ok(msg) => self.handle_message(msg, &mut ping_manager, &mut out).await,
106 Err(error) => {
107 emit!(WebSocketReceiveError { error: &error });
108 Err(WebSocketSourceError::Tungstenite { source: error })
109 }
110 }
111 }
112 };
113
114 if let Err(error) = result {
115 match error {
116 WebSocketSourceError::RemoteClosed { frame } => {
117 warn!(
118 message = "Connection closed by server.",
119 code = %frame.code,
120 reason = %frame.reason,
121 internal_log_rate_limit = true
122 );
123 emit!(WebSocketConnectionShutdown);
124 }
125 WebSocketSourceError::RemoteClosedEmpty => {
126 warn!(
127 internal_log_rate_limit = true,
128 "Connection closed by server without a close frame."
129 );
130 emit!(WebSocketConnectionShutdown);
131 }
132 WebSocketSourceError::PongTimeout => {
133 error!(
134 internal_log_rate_limit = true,
135 "Disconnecting due to pong timeout."
136 );
137 emit!(WebSocketReceiveError {
138 error: &TungsteniteError::Io(std::io::Error::new(
139 std::io::ErrorKind::TimedOut,
140 "Pong timeout"
141 ))
142 });
143 emit!(WebSocketConnectionShutdown);
144 return Err(error);
145 }
146 WebSocketSourceError::Tungstenite { source: ws_err } => {
147 if is_closed(&ws_err) {
148 emit!(WebSocketConnectionShutdown);
149 }
150 error!(message = "WebSocket connection error.", error = %ws_err, internal_log_rate_limit = true);
151 }
152 WebSocketSourceError::ConnectTimeout
155 | WebSocketSourceError::InitialMessageTimeout
156 | WebSocketSourceError::ConnectionClosedPrematurely => {
157 unreachable!(
158 "Encountered a connection-time error during runtime: {:?}",
159 error
160 );
161 }
162 }
163 if self
164 .reconnect(&mut out, &mut ws_sink, &mut ws_source)
165 .await
166 .is_err()
167 {
168 break;
169 }
170 }
171 }
172 Ok(())
173 }
174
175 async fn handle_message(
176 &self,
177 msg: Message,
178 ping_manager: &mut PingManager,
179 out: &mut SourceSender,
180 ) -> Result<(), WebSocketSourceError> {
181 match msg {
182 Message::Pong(_) => {
183 ping_manager.record_pong();
184 Ok(())
185 }
186 Message::Text(msg_txt) => {
187 if self.is_custom_pong(&msg_txt) {
188 ping_manager.record_pong();
189 debug!("Received custom pong response.");
190 } else {
191 self.process_message(&msg_txt, WebSocketKind::Text, out)
192 .await;
193 }
194 Ok(())
195 }
196 Message::Binary(msg_bytes) => {
197 self.process_message(&msg_bytes, WebSocketKind::Binary, out)
198 .await;
199 Ok(())
200 }
201 Message::Ping(_) => Ok(()),
202 Message::Close(frame) => self.handle_close_frame(frame),
203 Message::Frame(_) => {
204 warn!(
205 internal_log_rate_limit = true,
206 "Unsupported message type received: frame."
207 );
208 Ok(())
209 }
210 }
211 }
212
213 async fn process_message<T>(&self, payload: &T, kind: WebSocketKind, out: &mut SourceSender)
214 where
215 T: AsRef<[u8]> + ?Sized,
216 {
217 let payload_bytes = payload.as_ref();
218
219 emit!(WebSocketBytesReceived {
220 byte_size: payload_bytes.len(),
221 url: &self.config.common.uri,
222 protocol: PROTOCOL,
223 kind,
224 });
225 let mut stream = FramedRead::new(payload_bytes, self.params.decoder.clone());
226
227 while let Some(result) = stream.next().await {
228 match result {
229 Ok((events, _)) => {
230 if events.is_empty() {
231 continue;
232 }
233
234 let event_count = events.len();
235 let byte_size = events.estimated_json_encoded_size_of();
236
237 register!(EventsReceived).emit(CountByteSize(event_count, byte_size));
238 emit!(WebSocketMessageReceived {
239 count: event_count,
240 byte_size,
241 url: &self.config.common.uri,
242 protocol: PROTOCOL,
243 kind,
244 });
245
246 let events_with_meta = events.into_iter().map(|mut event| {
247 if let Event::Log(event) = &mut event {
248 self.add_metadata(event);
249 }
250 event
251 });
252
253 if let Err(error) = out.send_batch(events_with_meta).await {
254 error!(message = "Error sending events.", %error, internal_log_rate_limit = true);
255 }
256 }
257 Err(error) => {
258 if !error.can_continue() {
259 break;
260 }
261 }
262 }
263 }
264 }
265
266 fn add_metadata(&self, event: &mut LogEvent) {
267 self.params
268 .log_namespace
269 .insert_standard_vector_source_metadata(event, WebSocketConfig::NAME, Utc::now());
270 }
271
272 async fn reconnect(
273 &self,
274 out: &mut SourceSender,
275 ws_sink: &mut WebSocketSink,
276 ws_source: &mut WebSocketStream,
277 ) -> Result<(), WebSocketSourceError> {
278 info!(
279 internal_log_rate_limit = true,
280 "Reconnecting to WebSocket..."
281 );
282
283 let (new_sink, new_source) = self.connect(out).await?;
284
285 *ws_sink = new_sink;
286 *ws_source = new_source;
287
288 info!(internal_log_rate_limit = true, "Reconnected to Websocket.");
289
290 Ok(())
291 }
292
293 async fn connect(
294 &self,
295 out: &mut SourceSender,
296 ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
297 let (mut ws_sink, mut ws_source) = self.try_create_sink_and_stream().await?;
298
299 if self.config.initial_message.is_some() {
300 self.send_initial_message(&mut ws_sink, &mut ws_source, out)
301 .await?;
302 }
303
304 Ok((ws_sink, ws_source))
305 }
306
307 async fn try_create_sink_and_stream(
308 &self,
309 ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
310 let connect_future = self.params.connector.connect_backoff();
311 let timeout = self.config.connect_timeout_secs;
312
313 let ws_stream = match time::timeout(timeout, connect_future).await {
314 Ok(ws) => ws,
315 Err(_) => {
316 emit!(WebSocketConnectionError {
317 error: TungsteniteError::Io(std::io::Error::new(
318 std::io::ErrorKind::TimedOut,
319 "Connection attempt timed out",
320 ))
321 });
322 return Err(WebSocketSourceError::ConnectTimeout);
323 }
324 };
325
326 emit!(WebSocketConnectionEstablished {});
327 let (sink, stream) = ws_stream.split();
328
329 Ok((Box::pin(sink), Box::pin(stream)))
330 }
331
332 async fn send_initial_message(
333 &self,
334 ws_sink: &mut WebSocketSink,
335 ws_source: &mut WebSocketStream,
336 out: &mut SourceSender,
337 ) -> Result<(), WebSocketSourceError> {
338 let initial_message = self.config.initial_message.as_ref().unwrap();
339 ws_sink
340 .send(Message::Text(initial_message.clone()))
341 .await
342 .map_err(|error| {
343 emit!(WebSocketSendError { error: &error });
344 WebSocketSourceError::Tungstenite { source: error }
345 })?;
346
347 debug!("Sent initial message, awaiting response from server.");
348
349 let response =
350 match time::timeout(self.config.initial_message_timeout_secs, ws_source.next()).await {
351 Ok(Some(msg)) => msg,
352 Ok(None) => fail_with_event!(ConnectionClosedPrematurelySnafu),
353 Err(_) => fail_with_event!(InitialMessageTimeoutSnafu),
354 };
355
356 let message = response.map_err(|source| {
357 emit!(WebSocketReceiveError { error: &source });
358 WebSocketSourceError::Tungstenite { source }
359 })?;
360
361 match message {
362 Message::Text(txt) => {
363 self.process_message(&txt, WebSocketKind::Text, out).await;
364 Ok(())
365 }
366 Message::Binary(bin) => {
367 self.process_message(&bin, WebSocketKind::Binary, out).await;
368 Ok(())
369 }
370 Message::Close(frame) => self.handle_close_frame(frame),
371 _ => Ok(()),
372 }
373 }
374
375 fn handle_close_frame(
376 &self,
377 frame: Option<CloseFrame<'_>>,
378 ) -> Result<(), WebSocketSourceError> {
379 let (error_message, specific_error) = match frame {
380 Some(frame) => {
381 let msg = format!(
382 "Connection closed by server with code '{}' and reason: '{}'",
383 frame.code, frame.reason
384 );
385 let err = WebSocketSourceError::RemoteClosed {
386 frame: frame.into_owned(),
387 };
388 (msg, err)
389 }
390 None => (
391 "Connection closed by server without a close frame".to_string(),
392 WebSocketSourceError::RemoteClosedEmpty,
393 ),
394 };
395
396 let error = TungsteniteError::Io(std::io::Error::new(
397 std::io::ErrorKind::ConnectionAborted,
398 error_message,
399 ));
400 emit!(WebSocketReceiveError { error: &error });
401
402 Err(specific_error)
403 }
404
405 fn is_custom_pong(&self, msg_txt: &str) -> bool {
406 match self.config.pong_message.as_ref() {
407 Some(config) => config.matches(msg_txt),
408 None => false,
409 }
410 }
411}
412
/// Tracks the periodic ping and whether its reply is still outstanding.
struct PingManager {
    /// Timer awaited before each ping is sent.
    interval: PingInterval,
    /// `true` from the moment a ping is sent until a pong is recorded.
    waiting_for_pong: bool,
    /// Message sent as the ping: the configured text, or a protocol-level ping.
    message: Message,
}
418
419impl PingManager {
420 fn new(config: &WebSocketConfig) -> Self {
421 let ping_message = if let Some(ping_msg) = &config.ping_message {
422 Message::Text(ping_msg.clone())
423 } else {
424 Message::Ping(vec![])
425 };
426
427 Self {
428 interval: PingInterval::new(config.common.ping_interval.map(u64::from)),
429 waiting_for_pong: false,
430 message: ping_message,
431 }
432 }
433
434 const fn record_pong(&mut self) {
435 self.waiting_for_pong = false;
436 }
437
438 async fn tick(&mut self, ws_sink: &mut WebSocketSink) -> Result<(), WebSocketSourceError> {
439 self.interval.tick().await;
440
441 if self.waiting_for_pong {
442 return Err(WebSocketSourceError::PongTimeout);
443 }
444
445 ws_sink.send(self.message.clone()).await.map_err(|error| {
446 emit!(WebSocketSendError { error: &error });
447 WebSocketSourceError::Tungstenite { source: error }
448 })?;
449
450 self.waiting_for_pong = true;
451 Ok(())
452 }
453}
454
455#[cfg(test)]
456mod tests {
457 use crate::test_util::components::run_and_assert_source_error;
458 use crate::{
459 common::websocket::WebSocketCommonConfig,
460 sources::websocket::config::PongMessage,
461 sources::websocket::config::WebSocketConfig,
462 test_util::{
463 components::{run_and_assert_source_compliance, SOURCE_TAGS},
464 next_addr,
465 },
466 };
467 use futures::{sink::SinkExt, StreamExt};
468 use std::borrow::Cow;
469 use std::num::NonZeroU64;
470 use tokio::{net::TcpListener, time::Duration};
471 use tokio_tungstenite::tungstenite::{
472 protocol::frame::coding::CloseCode, protocol::frame::CloseFrame,
473 };
474 use tokio_tungstenite::{accept_async, tungstenite::Message};
475 use url::Url;
476 use vector_lib::codecs::decoding::DeserializerConfig;
477
478 fn make_config(uri: &str) -> WebSocketConfig {
479 WebSocketConfig {
480 common: WebSocketCommonConfig {
481 uri: Url::parse(uri).unwrap().to_string(),
482 ..Default::default()
483 },
484 ..Default::default()
485 }
486 }
487
488 async fn start_binary_push_server() -> String {
490 let addr = next_addr();
491 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
492 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
493
494 tokio::spawn(async move {
495 let (stream, _) = listener.accept().await.unwrap();
496 let mut websocket = accept_async(stream).await.expect("Failed to accept");
497
498 let binary_payload = br#"{"message": "binary data"}"#.to_vec();
499 websocket
500 .send(Message::Binary(binary_payload))
501 .await
502 .unwrap();
503 });
504
505 server_addr
506 }
507
508 async fn start_push_server() -> String {
510 let addr = next_addr();
511 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
512 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
513
514 tokio::spawn(async move {
515 let (stream, _) = listener.accept().await.unwrap();
517 let mut websocket = accept_async(stream).await.expect("Failed to accept");
518
519 websocket
521 .send(Message::Text("message from server".to_string()))
522 .await
523 .unwrap();
524 });
525
526 server_addr
527 }
528
529 async fn start_subscribe_server(initial_message: String, response_message: String) -> String {
532 let addr = next_addr();
533 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
534 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
535
536 tokio::spawn(async move {
537 let (stream, _) = listener.accept().await.unwrap();
538 let mut websocket = accept_async(stream).await.expect("Failed to accept");
539
540 if let Some(Ok(Message::Text(msg))) = websocket.next().await {
542 if msg == initial_message {
543 websocket
545 .send(Message::Text(response_message))
546 .await
547 .unwrap();
548 }
549 }
550 });
551
552 server_addr
553 }
554
555 async fn start_reconnect_server() -> String {
556 let addr = next_addr();
557 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
558 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
559
560 tokio::spawn(async move {
561 let (stream, _) = listener.accept().await.unwrap();
563 let mut websocket = accept_async(stream).await.expect("Failed to accept");
564 websocket
565 .send(Message::Text("first message".to_string()))
566 .await
567 .unwrap();
568 websocket.close(None).await.unwrap();
570
571 let (stream, _) = listener.accept().await.unwrap();
573 let mut websocket = accept_async(stream).await.expect("Failed to accept");
574 websocket
575 .send(Message::Text("second message".to_string()))
576 .await
577 .unwrap();
578 });
579
580 server_addr
581 }
582
583 #[tokio::test(flavor = "multi_thread")]
584 async fn websocket_source_reconnects_after_disconnect() {
585 let server_addr = start_reconnect_server().await;
586 let config = make_config(&server_addr);
587
588 let events =
590 run_and_assert_source_compliance(config, Duration::from_secs(5), &SOURCE_TAGS).await;
591
592 assert_eq!(
593 events.len(),
594 2,
595 "Should have received messages from both connections"
596 );
597
598 let event = events[0].as_log();
599 assert_eq!(event["message"], "first message".into());
600
601 let event = events[1].as_log();
602 assert_eq!(event["message"], "second message".into());
603 }
604
605 #[tokio::test(flavor = "multi_thread")]
606 async fn websocket_source_consume_binary_event() {
607 let server_addr = start_binary_push_server().await;
608 let mut config = make_config(&server_addr);
609 let decoding = DeserializerConfig::Json(Default::default());
610 config.decoding = decoding;
611
612 let events =
613 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
614
615 assert!(!events.is_empty(), "No events received from source");
616 let event = events[0].as_log();
617 assert_eq!(event["message"], "binary data".into());
618 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
619 }
620
621 #[tokio::test(flavor = "multi_thread")]
622 async fn websocket_source_consume_event() {
623 let server_addr = start_push_server().await;
624 let config = make_config(&server_addr);
625
626 let events =
628 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
629
630 assert!(!events.is_empty(), "No events received from source");
632 let event = events[0].as_log();
633 assert_eq!(event["message"], "message from server".into());
634 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
635 }
636
637 #[tokio::test(flavor = "multi_thread")]
638 async fn websocket_source_sends_initial_message() {
639 let initial_msg = "{\"action\":\"subscribe\",\"topic\":\"test\"}".to_string();
640 let response_msg = "{\"status\":\"subscribed\",\"topic\":\"test\"}".to_string();
641 let server_addr = start_subscribe_server(initial_msg.clone(), response_msg.clone()).await;
642
643 let mut config = make_config(&server_addr);
644 config.initial_message = Some(initial_msg);
645
646 let events =
647 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
648
649 assert!(!events.is_empty(), "No events received from source");
650 let event = events[0].as_log();
651 assert_eq!(event["message"], response_msg.into());
652 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
653 }
654
655 async fn start_reject_initial_message_server() -> String {
656 let addr = next_addr();
657 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
658 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
659
660 tokio::spawn(async move {
661 let (stream, _) = listener.accept().await.unwrap();
662 let mut websocket = accept_async(stream).await.expect("Failed to accept");
663
664 if websocket.next().await.is_some() {
665 let close_frame = CloseFrame {
666 code: CloseCode::Error,
667 reason: Cow::from("Simulated Internal Server Error"),
668 };
669 let _ = websocket.close(Some(close_frame)).await;
670 }
671 });
672
673 server_addr
674 }
675
676 #[tokio::test(flavor = "multi_thread")]
677 async fn websocket_source_exits_on_rejected_intial_messsage() {
678 let server_addr = start_reject_initial_message_server().await;
679
680 let mut config = make_config(&server_addr);
681 config.initial_message = Some("hello, server!".to_string());
682 config.initial_message_timeout_secs = Duration::from_secs(1);
683
684 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
685 }
686
687 async fn start_unresponsive_server() -> String {
688 let addr = next_addr();
689 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
690 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
691
692 tokio::spawn(async move {
693 if let Ok((stream, _)) = listener.accept().await {
694 let mut websocket = accept_async(stream).await.expect("Failed to accept");
696 while websocket.next().await.is_some() {
698 }
700 }
701 });
702
703 server_addr
704 }
705
706 #[tokio::test(flavor = "multi_thread")]
707 async fn websocket_source_exits_on_pong_timeout() {
708 let server_addr = start_unresponsive_server().await;
709
710 let mut config = make_config(&server_addr);
711 config.common.ping_interval = NonZeroU64::new(3);
712 config.common.ping_timeout = NonZeroU64::new(1);
713 config.ping_message = Some("ping".to_string());
714 config.pong_message = Some(PongMessage::Simple("pong".to_string()));
715
716 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
718 }
719
720 async fn start_blackhole_server() -> String {
721 let addr = next_addr();
722 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
723 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
724
725 tokio::spawn(async move {
726 let (mut _socket, _) = listener.accept().await.unwrap();
727 tokio::time::sleep(Duration::from_secs(10)).await;
728 });
729
730 server_addr
731 }
732
733 #[tokio::test(flavor = "multi_thread")]
734 async fn websocket_source_exits_on_connection_timeout() {
735 let server_addr = start_blackhole_server().await;
736 let mut config = make_config(&server_addr);
737 config.connect_timeout_secs = Duration::from_secs(1);
738
739 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
740 }
741}