1use std::pin::Pin;
2
3use chrono::{DateTime, Utc};
4use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt};
5use snafu::Snafu;
6use tokio::time;
7use tokio_tungstenite::tungstenite::{
8 Message, error::Error as TungsteniteError, protocol::CloseFrame,
9};
10use vector_lib::{
11 EstimatedJsonEncodedSizeOf,
12 codecs::DecoderFramedRead,
13 config::LogNamespace,
14 event::{Event, LogEvent},
15 internal_event::{CountByteSize, EventsReceived, InternalEventHandle as _},
16};
17
18use crate::{
19 SourceSender,
20 codecs::Decoder,
21 common::websocket::{PingInterval, WebSocketConnector, is_closed},
22 config::SourceContext,
23 internal_events::{
24 ConnectionOpen, OpenGauge, PROTOCOL, WebSocketBytesReceived,
25 WebSocketConnectionFailedError, WebSocketConnectionShutdown, WebSocketKind,
26 WebSocketMessageReceived, WebSocketReceiveError, WebSocketSendError,
27 },
28 sources::websocket::config::WebSocketConfig,
29 vector_lib::codecs::StreamDecodingError,
30};
31
/// Reports a connection-establishment failure and returns it from the enclosing function.
///
/// `$context` is a snafu context selector (e.g. `InitialMessageTimeoutSnafu`). The macro
/// builds the error once to emit `WebSocketConnectionFailedError`, then returns
/// `$context.fail()` (an `Err`) from the *calling* function.
///
/// Note: `$context` is expanded twice, so pass a plain selector path, not an expression
/// with side effects.
macro_rules! fail_with_event {
    ($context:expr_2021) => {{
        emit!(WebSocketConnectionFailedError {
            error: Box::new($context.build())
        });
        // Early return from the function that invoked the macro.
        return $context.fail();
    }};
}
40
41type WebSocketSink = Pin<Box<dyn Sink<Message, Error = TungsteniteError> + Send>>;
42type WebSocketStream = Pin<Box<dyn Stream<Item = Result<Message, TungsteniteError>> + Send>>;
43
/// Runtime dependencies handed to [`WebSocketSource::new`] alongside the config.
pub(crate) struct WebSocketSourceParams {
    /// Opens (and re-opens) the WebSocket connection.
    pub connector: WebSocketConnector,
    /// Decoder cloned for every received text/binary payload.
    pub decoder: Decoder,
    /// Namespace used when inserting standard Vector source metadata into log events.
    pub log_namespace: LogNamespace,
}
49
/// WebSocket client source. It connects to `config.common.uri`, decodes incoming text and
/// binary messages into events, and reconnects when the connection is closed or errors.
pub(crate) struct WebSocketSource {
    /// User configuration: URI, optional initial message, ping/pong settings, timeouts.
    config: WebSocketConfig,
    /// Connector, decoder and log namespace supplied at construction.
    params: WebSocketSourceParams,
}
54
55#[derive(Debug, Snafu)]
56pub enum WebSocketSourceError {
57 #[snafu(display("Connection attempt timed out"))]
58 ConnectTimeout,
59
60 #[snafu(display("Server did not respond to the initial message in time"))]
61 InitialMessageTimeout,
62
63 #[snafu(display(
64 "The connection was closed after sending the initial message, but before a response."
65 ))]
66 ConnectionClosedPrematurely,
67
68 #[snafu(display("Connection closed by server with code '{}' and reason: '{}'", frame.code, frame.reason))]
69 RemoteClosed { frame: CloseFrame<'static> },
70
71 #[snafu(display("Connection closed by server without a close frame"))]
72 RemoteClosedEmpty,
73
74 #[snafu(display("Connection timed out while waiting for a pong response"))]
75 PongTimeout,
76
77 #[snafu(display("A WebSocket error occurred: {}", source))]
78 Tungstenite { source: TungsteniteError },
79}
80
81impl WebSocketSource {
82 pub const fn new(config: WebSocketConfig, params: WebSocketSourceParams) -> Self {
83 Self { config, params }
84 }
85
86 pub async fn run(self, cx: SourceContext) -> Result<(), WebSocketSourceError> {
87 let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
88 let mut ping_manager = PingManager::new(&self.config);
89
90 let mut out = cx.out;
91
92 let (ws_sink, ws_source) = self.connect(&mut out).await?;
93
94 pin_mut!(ws_sink, ws_source);
95
96 loop {
97 let result = tokio::select! {
98 _ = cx.shutdown.clone() => {
99 info!("Received shutdown signal.");
100 break;
101 },
102
103 res = ping_manager.tick(&mut ws_sink) => res,
104
105 Some(msg_result) = ws_source.next() => {
106 match msg_result {
107 Ok(msg) => self.handle_message(msg, &mut ping_manager, &mut out).await,
108 Err(error) => {
109 emit!(WebSocketReceiveError { error: &error });
110 Err(WebSocketSourceError::Tungstenite { source: error })
111 }
112 }
113 }
114 };
115
116 if let Err(error) = result {
117 match error {
118 WebSocketSourceError::RemoteClosed { frame } => {
119 warn!(
120 message = "Connection closed by server.",
121 code = %frame.code,
122 reason = %frame.reason
123 );
124 emit!(WebSocketConnectionShutdown);
125 }
126 WebSocketSourceError::RemoteClosedEmpty => {
127 warn!("Connection closed by server without a close frame.");
128 emit!(WebSocketConnectionShutdown);
129 }
130 WebSocketSourceError::PongTimeout => {
131 error!("Disconnecting due to pong timeout.");
132 emit!(WebSocketReceiveError {
133 error: &TungsteniteError::Io(std::io::Error::new(
134 std::io::ErrorKind::TimedOut,
135 "Pong timeout"
136 ))
137 });
138 emit!(WebSocketConnectionShutdown);
139 return Err(error);
140 }
141 WebSocketSourceError::Tungstenite { source: ws_err } => {
142 if is_closed(&ws_err) {
143 emit!(WebSocketConnectionShutdown);
144 }
145 error!(message = "WebSocket connection error.", error = %ws_err);
146 }
147 WebSocketSourceError::ConnectTimeout
150 | WebSocketSourceError::InitialMessageTimeout
151 | WebSocketSourceError::ConnectionClosedPrematurely => {
152 unreachable!(
153 "Encountered a connection-time error during runtime: {:?}",
154 error
155 );
156 }
157 }
158 if self
159 .reconnect(&mut out, &mut ws_sink, &mut ws_source)
160 .await
161 .is_err()
162 {
163 break;
164 }
165 }
166 }
167 Ok(())
168 }
169
170 async fn handle_message(
171 &self,
172 msg: Message,
173 ping_manager: &mut PingManager,
174 out: &mut SourceSender,
175 ) -> Result<(), WebSocketSourceError> {
176 match msg {
177 Message::Pong(_) => {
178 ping_manager.record_pong();
179 Ok(())
180 }
181 Message::Text(msg_txt) => {
182 if self.is_custom_pong(&msg_txt) {
183 ping_manager.record_pong();
184 debug!("Received custom pong response.");
185 } else {
186 self.process_message(&msg_txt, WebSocketKind::Text, out)
187 .await;
188 }
189 Ok(())
190 }
191 Message::Binary(msg_bytes) => {
192 self.process_message(&msg_bytes, WebSocketKind::Binary, out)
193 .await;
194 Ok(())
195 }
196 Message::Ping(_) => Ok(()),
197 Message::Close(frame) => self.handle_close_frame(frame),
198 Message::Frame(_) => {
199 warn!("Unsupported message type received: frame.");
200 Ok(())
201 }
202 }
203 }
204
205 async fn process_message<T>(&self, payload: &T, kind: WebSocketKind, out: &mut SourceSender)
206 where
207 T: AsRef<[u8]> + ?Sized,
208 {
209 let payload_bytes = payload.as_ref();
210
211 emit!(WebSocketBytesReceived {
212 byte_size: payload_bytes.len(),
213 url: &self.config.common.uri,
214 protocol: PROTOCOL,
215 kind,
216 });
217 let mut stream = DecoderFramedRead::new(payload_bytes, self.params.decoder.clone());
218
219 while let Some(result) = stream.next().await {
220 match result {
221 Ok((events, _)) => {
222 if events.is_empty() {
223 continue;
224 }
225
226 let event_count = events.len();
227 let byte_size = events.estimated_json_encoded_size_of();
228
229 register!(EventsReceived).emit(CountByteSize(event_count, byte_size));
230 emit!(WebSocketMessageReceived {
231 count: event_count,
232 byte_size,
233 url: &self.config.common.uri,
234 protocol: PROTOCOL,
235 kind,
236 });
237
238 let now = Utc::now();
239 let events_with_meta = events.into_iter().map(|mut event| {
240 if let Event::Log(event) = &mut event {
241 self.add_metadata(event, now);
242 }
243 event
244 });
245
246 if let Err(error) = out.send_batch(events_with_meta).await {
247 error!(message = "Error sending events.", %error);
248 }
249 }
250 Err(error) => {
251 if !error.can_continue() {
252 break;
253 }
254 }
255 }
256 }
257 }
258
259 fn add_metadata(&self, event: &mut LogEvent, now: DateTime<Utc>) {
260 self.params
261 .log_namespace
262 .insert_standard_vector_source_metadata(event, WebSocketConfig::NAME, now);
263 }
264
265 async fn reconnect(
266 &self,
267 out: &mut SourceSender,
268 ws_sink: &mut WebSocketSink,
269 ws_source: &mut WebSocketStream,
270 ) -> Result<(), WebSocketSourceError> {
271 info!("Reconnecting to WebSocket...");
272
273 let (new_sink, new_source) = self.connect(out).await?;
274
275 *ws_sink = new_sink;
276 *ws_source = new_source;
277
278 info!("Reconnected to Websocket.");
279
280 Ok(())
281 }
282
283 async fn connect(
284 &self,
285 out: &mut SourceSender,
286 ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
287 let (mut ws_sink, mut ws_source) = self.try_create_sink_and_stream().await?;
288
289 if self.config.initial_message.is_some() {
290 self.send_initial_message(&mut ws_sink, &mut ws_source, out)
291 .await?;
292 }
293
294 Ok((ws_sink, ws_source))
295 }
296
297 async fn try_create_sink_and_stream(
298 &self,
299 ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
300 let ws_stream = self
301 .params
302 .connector
303 .connect_backoff_with_timeout(self.config.connect_timeout_secs)
304 .await;
305
306 let (sink, stream) = ws_stream.split();
307
308 Ok((Box::pin(sink), Box::pin(stream)))
309 }
310
311 async fn send_initial_message(
312 &self,
313 ws_sink: &mut WebSocketSink,
314 ws_source: &mut WebSocketStream,
315 out: &mut SourceSender,
316 ) -> Result<(), WebSocketSourceError> {
317 let initial_message = self.config.initial_message.as_ref().unwrap();
318 ws_sink
319 .send(Message::Text(initial_message.clone()))
320 .await
321 .map_err(|error| {
322 emit!(WebSocketSendError { error: &error });
323 WebSocketSourceError::Tungstenite { source: error }
324 })?;
325
326 debug!("Sent initial message, awaiting response from server.");
327
328 let response =
329 match time::timeout(self.config.initial_message_timeout_secs, ws_source.next()).await {
330 Ok(Some(msg)) => msg,
331 Ok(None) => fail_with_event!(ConnectionClosedPrematurelySnafu),
332 Err(_) => fail_with_event!(InitialMessageTimeoutSnafu),
333 };
334
335 let message = response.map_err(|source| {
336 emit!(WebSocketReceiveError { error: &source });
337 WebSocketSourceError::Tungstenite { source }
338 })?;
339
340 match message {
341 Message::Text(txt) => {
342 self.process_message(&txt, WebSocketKind::Text, out).await;
343 Ok(())
344 }
345 Message::Binary(bin) => {
346 self.process_message(&bin, WebSocketKind::Binary, out).await;
347 Ok(())
348 }
349 Message::Close(frame) => self.handle_close_frame(frame),
350 _ => Ok(()),
351 }
352 }
353
354 fn handle_close_frame(
355 &self,
356 frame: Option<CloseFrame<'_>>,
357 ) -> Result<(), WebSocketSourceError> {
358 let (error_message, specific_error) = match frame {
359 Some(frame) => {
360 let msg = format!(
361 "Connection closed by server with code '{}' and reason: '{}'",
362 frame.code, frame.reason
363 );
364 let err = WebSocketSourceError::RemoteClosed {
365 frame: frame.into_owned(),
366 };
367 (msg, err)
368 }
369 None => (
370 "Connection closed by server without a close frame".to_string(),
371 WebSocketSourceError::RemoteClosedEmpty,
372 ),
373 };
374
375 let error = TungsteniteError::Io(std::io::Error::new(
376 std::io::ErrorKind::ConnectionAborted,
377 error_message,
378 ));
379 emit!(WebSocketReceiveError { error: &error });
380
381 Err(specific_error)
382 }
383
384 fn is_custom_pong(&self, msg_txt: &str) -> bool {
385 match self.config.pong_message.as_ref() {
386 Some(config) => config.matches(msg_txt),
387 None => false,
388 }
389 }
390}
391
392struct PingManager {
393 interval: PingInterval,
394 waiting_for_pong: bool,
395 message: Message,
396}
397
398impl PingManager {
399 fn new(config: &WebSocketConfig) -> Self {
400 let ping_message = if let Some(ping_msg) = &config.ping_message {
401 Message::Text(ping_msg.clone())
402 } else {
403 Message::Ping(vec![])
404 };
405
406 Self {
407 interval: PingInterval::new(config.common.ping_interval.map(u64::from)),
408 waiting_for_pong: false,
409 message: ping_message,
410 }
411 }
412
413 const fn record_pong(&mut self) {
414 self.waiting_for_pong = false;
415 }
416
417 async fn tick(&mut self, ws_sink: &mut WebSocketSink) -> Result<(), WebSocketSourceError> {
418 self.interval.tick().await;
419
420 if self.waiting_for_pong {
421 return Err(WebSocketSourceError::PongTimeout);
422 }
423
424 ws_sink.send(self.message.clone()).await.map_err(|error| {
425 emit!(WebSocketSendError { error: &error });
426 WebSocketSourceError::Tungstenite { source: error }
427 })?;
428
429 self.waiting_for_pong = true;
430 Ok(())
431 }
432}
433
434#[cfg(test)]
435mod tests {
436 use std::{borrow::Cow, num::NonZeroU64};
437
438 use futures::{StreamExt, sink::SinkExt};
439 use tokio::{net::TcpListener, time::Duration};
440 use tokio_tungstenite::{
441 accept_async,
442 tungstenite::{
443 Message,
444 protocol::frame::{CloseFrame, coding::CloseCode},
445 },
446 };
447 use url::Url;
448 use vector_lib::codecs::decoding::DeserializerConfig;
449
450 use crate::{
451 common::websocket::WebSocketCommonConfig,
452 sources::websocket::config::{PongMessage, WebSocketConfig},
453 test_util::{
454 addr::next_addr,
455 components::{
456 SOURCE_TAGS, run_and_assert_source_compliance, run_and_assert_source_error,
457 },
458 },
459 };
460
461 fn make_config(uri: &str) -> WebSocketConfig {
462 WebSocketConfig {
463 common: WebSocketCommonConfig {
464 uri: Url::parse(uri).unwrap().to_string(),
465 ..Default::default()
466 },
467 ..Default::default()
468 }
469 }
470
471 async fn start_binary_push_server() -> String {
473 let (_guard, addr) = next_addr();
474 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
475 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
476
477 tokio::spawn(async move {
478 let (stream, _) = listener.accept().await.unwrap();
479 let mut websocket = accept_async(stream).await.expect("Failed to accept");
480
481 let binary_payload = br#"{"message": "binary data"}"#.to_vec();
482 websocket
483 .send(Message::Binary(binary_payload))
484 .await
485 .unwrap();
486 });
487
488 server_addr
489 }
490
491 async fn start_push_server() -> String {
493 let (_guard, addr) = next_addr();
494 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
495 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
496
497 tokio::spawn(async move {
498 let (stream, _) = listener.accept().await.unwrap();
500 let mut websocket = accept_async(stream).await.expect("Failed to accept");
501
502 websocket
504 .send(Message::Text("message from server".to_string()))
505 .await
506 .unwrap();
507 });
508
509 server_addr
510 }
511
512 async fn start_subscribe_server(initial_message: String, response_message: String) -> String {
515 let (_guard, addr) = next_addr();
516 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
517 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
518
519 tokio::spawn(async move {
520 let (stream, _) = listener.accept().await.unwrap();
521 let mut websocket = accept_async(stream).await.expect("Failed to accept");
522
523 if let Some(Ok(Message::Text(msg))) = websocket.next().await
525 && msg == initial_message
526 {
527 websocket
529 .send(Message::Text(response_message))
530 .await
531 .unwrap();
532 }
533 });
534
535 server_addr
536 }
537
538 async fn start_reconnect_server() -> String {
539 let (_guard, addr) = next_addr();
540 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
541 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
542
543 tokio::spawn(async move {
544 let (stream, _) = listener.accept().await.unwrap();
546 let mut websocket = accept_async(stream).await.expect("Failed to accept");
547 websocket
548 .send(Message::Text("first message".to_string()))
549 .await
550 .unwrap();
551 websocket.close(None).await.unwrap();
553
554 let (stream, _) = listener.accept().await.unwrap();
556 let mut websocket = accept_async(stream).await.expect("Failed to accept");
557 websocket
558 .send(Message::Text("second message".to_string()))
559 .await
560 .unwrap();
561 });
562
563 server_addr
564 }
565
566 #[tokio::test(flavor = "multi_thread")]
567 async fn websocket_source_reconnects_after_disconnect() {
568 let server_addr = start_reconnect_server().await;
569 let config = make_config(&server_addr);
570
571 let events =
573 run_and_assert_source_compliance(config, Duration::from_secs(5), &SOURCE_TAGS).await;
574
575 assert_eq!(
576 events.len(),
577 2,
578 "Should have received messages from both connections"
579 );
580
581 let event = events[0].as_log();
582 assert_eq!(event["message"], "first message".into());
583
584 let event = events[1].as_log();
585 assert_eq!(event["message"], "second message".into());
586 }
587
588 #[tokio::test(flavor = "multi_thread")]
589 async fn websocket_source_consume_binary_event() {
590 let server_addr = start_binary_push_server().await;
591 let mut config = make_config(&server_addr);
592 let decoding = DeserializerConfig::Json(Default::default());
593 config.decoding = decoding;
594
595 let events =
596 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
597
598 assert!(!events.is_empty(), "No events received from source");
599 let event = events[0].as_log();
600 assert_eq!(event["message"], "binary data".into());
601 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
602 }
603
604 #[tokio::test(flavor = "multi_thread")]
605 async fn websocket_source_consume_event() {
606 let server_addr = start_push_server().await;
607 let config = make_config(&server_addr);
608
609 let events =
611 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
612
613 assert!(!events.is_empty(), "No events received from source");
615 let event = events[0].as_log();
616 assert_eq!(event["message"], "message from server".into());
617 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
618 }
619
620 #[tokio::test(flavor = "multi_thread")]
621 async fn websocket_source_sends_initial_message() {
622 let initial_msg = "{\"action\":\"subscribe\",\"topic\":\"test\"}".to_string();
623 let response_msg = "{\"status\":\"subscribed\",\"topic\":\"test\"}".to_string();
624 let server_addr = start_subscribe_server(initial_msg.clone(), response_msg.clone()).await;
625
626 let mut config = make_config(&server_addr);
627 config.initial_message = Some(initial_msg);
628
629 let events =
630 run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
631
632 assert!(!events.is_empty(), "No events received from source");
633 let event = events[0].as_log();
634 assert_eq!(event["message"], response_msg.into());
635 assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
636 }
637
638 async fn start_reject_initial_message_server() -> String {
639 let (_guard, addr) = next_addr();
640 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
641 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
642
643 tokio::spawn(async move {
644 let (stream, _) = listener.accept().await.unwrap();
645 let mut websocket = accept_async(stream).await.expect("Failed to accept");
646
647 if websocket.next().await.is_some() {
648 let close_frame = CloseFrame {
649 code: CloseCode::Error,
650 reason: Cow::from("Simulated Internal Server Error"),
651 };
652 let _ = websocket.close(Some(close_frame)).await;
653 }
654 });
655
656 server_addr
657 }
658
659 #[tokio::test(flavor = "multi_thread")]
660 async fn websocket_source_exits_on_rejected_intial_messsage() {
661 let server_addr = start_reject_initial_message_server().await;
662
663 let mut config = make_config(&server_addr);
664 config.initial_message = Some("hello, server!".to_string());
665 config.initial_message_timeout_secs = Duration::from_secs(1);
666
667 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
668 }
669
670 async fn start_unresponsive_server() -> String {
671 let (_guard, addr) = next_addr();
672 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
673 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
674
675 tokio::spawn(async move {
676 if let Ok((stream, _)) = listener.accept().await {
677 let mut websocket = accept_async(stream).await.expect("Failed to accept");
679 while websocket.next().await.is_some() {
681 }
683 }
684 });
685
686 server_addr
687 }
688
689 #[tokio::test(flavor = "multi_thread")]
690 async fn websocket_source_exits_on_pong_timeout() {
691 let server_addr = start_unresponsive_server().await;
692
693 let mut config = make_config(&server_addr);
694 config.common.ping_interval = NonZeroU64::new(3);
695 config.common.ping_timeout = NonZeroU64::new(1);
696 config.ping_message = Some("ping".to_string());
697 config.pong_message = Some(PongMessage::Simple("pong".to_string()));
698
699 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
701 }
702
703 async fn start_blackhole_server() -> String {
704 let (_guard, addr) = next_addr();
705 let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
706 let server_addr = format!("ws://{}", listener.local_addr().unwrap());
707
708 tokio::spawn(async move {
709 let (mut _socket, _) = listener.accept().await.unwrap();
710 tokio::time::sleep(Duration::from_secs(10)).await;
711 });
712
713 server_addr
714 }
715
716 #[tokio::test(flavor = "multi_thread")]
717 async fn websocket_source_exits_on_connection_timeout() {
718 let server_addr = start_blackhole_server().await;
719 let mut config = make_config(&server_addr);
720 config.connect_timeout_secs = Duration::from_secs(1);
721
722 run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
723 }
724}