// vector/sources/websocket/source.rs

1use std::pin::Pin;
2
3use chrono::Utc;
4use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt};
5use snafu::Snafu;
6use tokio::time;
7use tokio_tungstenite::tungstenite::{
8    Message, error::Error as TungsteniteError, protocol::CloseFrame,
9};
10use tokio_util::codec::FramedRead;
11use vector_lib::{
12    EstimatedJsonEncodedSizeOf,
13    config::LogNamespace,
14    event::{Event, LogEvent},
15    internal_event::{CountByteSize, EventsReceived, InternalEventHandle as _},
16};
17
18use crate::{
19    SourceSender,
20    codecs::Decoder,
21    common::websocket::{PingInterval, WebSocketConnector, is_closed},
22    config::SourceContext,
23    internal_events::{
24        ConnectionOpen, OpenGauge, PROTOCOL, WebSocketBytesReceived,
25        WebSocketConnectionFailedError, WebSocketConnectionShutdown, WebSocketKind,
26        WebSocketMessageReceived, WebSocketReceiveError, WebSocketSendError,
27    },
28    sources::websocket::config::WebSocketConfig,
29    vector_lib::codecs::StreamDecodingError,
30};
31
/// Emits a `WebSocketConnectionFailedError` built from the given snafu context
/// selector and then returns that selector's error from the enclosing function.
///
/// The selector expression is expanded twice: once to build the error that is
/// reported through the internal event, once to produce the returned `Err`.
macro_rules! fail_with_event {
    ($context:expr_2021) => {{
        let error = Box::new($context.build());
        emit!(WebSocketConnectionFailedError { error });
        return $context.fail();
    }};
}
40
41type WebSocketSink = Pin<Box<dyn Sink<Message, Error = TungsteniteError> + Send>>;
42type WebSocketStream = Pin<Box<dyn Stream<Item = Result<Message, TungsteniteError>> + Send>>;
43
44pub(crate) struct WebSocketSourceParams {
45    pub connector: WebSocketConnector,
46    pub decoder: Decoder,
47    pub log_namespace: LogNamespace,
48}
49
50pub(crate) struct WebSocketSource {
51    config: WebSocketConfig,
52    params: WebSocketSourceParams,
53}
54
55#[derive(Debug, Snafu)]
56pub enum WebSocketSourceError {
57    #[snafu(display("Connection attempt timed out"))]
58    ConnectTimeout,
59
60    #[snafu(display("Server did not respond to the initial message in time"))]
61    InitialMessageTimeout,
62
63    #[snafu(display(
64        "The connection was closed after sending the initial message, but before a response."
65    ))]
66    ConnectionClosedPrematurely,
67
68    #[snafu(display("Connection closed by server with code '{}' and reason: '{}'", frame.code, frame.reason))]
69    RemoteClosed { frame: CloseFrame<'static> },
70
71    #[snafu(display("Connection closed by server without a close frame"))]
72    RemoteClosedEmpty,
73
74    #[snafu(display("Connection timed out while waiting for a pong response"))]
75    PongTimeout,
76
77    #[snafu(display("A WebSocket error occurred: {}", source))]
78    Tungstenite { source: TungsteniteError },
79}
80
81impl WebSocketSource {
82    pub const fn new(config: WebSocketConfig, params: WebSocketSourceParams) -> Self {
83        Self { config, params }
84    }
85
86    pub async fn run(self, cx: SourceContext) -> Result<(), WebSocketSourceError> {
87        let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
88        let mut ping_manager = PingManager::new(&self.config);
89
90        let mut out = cx.out;
91
92        let (ws_sink, ws_source) = self.connect(&mut out).await?;
93
94        pin_mut!(ws_sink, ws_source);
95
96        loop {
97            let result = tokio::select! {
98                _ = cx.shutdown.clone() => {
99                    info!("Received shutdown signal.");
100                    break;
101                },
102
103                res = ping_manager.tick(&mut ws_sink) => res,
104
105                Some(msg_result) = ws_source.next() => {
106                    match msg_result {
107                        Ok(msg) => self.handle_message(msg, &mut ping_manager, &mut out).await,
108                        Err(error) => {
109                            emit!(WebSocketReceiveError { error: &error });
110                            Err(WebSocketSourceError::Tungstenite { source: error })
111                        }
112                    }
113                }
114            };
115
116            if let Err(error) = result {
117                match error {
118                    WebSocketSourceError::RemoteClosed { frame } => {
119                        warn!(
120                            message = "Connection closed by server.",
121                            code = %frame.code,
122                            reason = %frame.reason
123                        );
124                        emit!(WebSocketConnectionShutdown);
125                    }
126                    WebSocketSourceError::RemoteClosedEmpty => {
127                        warn!("Connection closed by server without a close frame.");
128                        emit!(WebSocketConnectionShutdown);
129                    }
130                    WebSocketSourceError::PongTimeout => {
131                        error!("Disconnecting due to pong timeout.");
132                        emit!(WebSocketReceiveError {
133                            error: &TungsteniteError::Io(std::io::Error::new(
134                                std::io::ErrorKind::TimedOut,
135                                "Pong timeout"
136                            ))
137                        });
138                        emit!(WebSocketConnectionShutdown);
139                        return Err(error);
140                    }
141                    WebSocketSourceError::Tungstenite { source: ws_err } => {
142                        if is_closed(&ws_err) {
143                            emit!(WebSocketConnectionShutdown);
144                        }
145                        error!(message = "WebSocket connection error.", error = %ws_err);
146                    }
147                    // These errors should only happen during `connect` or `reconnect`,
148                    // not in the main loop's result.
149                    WebSocketSourceError::ConnectTimeout
150                    | WebSocketSourceError::InitialMessageTimeout
151                    | WebSocketSourceError::ConnectionClosedPrematurely => {
152                        unreachable!(
153                            "Encountered a connection-time error during runtime: {:?}",
154                            error
155                        );
156                    }
157                }
158                if self
159                    .reconnect(&mut out, &mut ws_sink, &mut ws_source)
160                    .await
161                    .is_err()
162                {
163                    break;
164                }
165            }
166        }
167        Ok(())
168    }
169
170    async fn handle_message(
171        &self,
172        msg: Message,
173        ping_manager: &mut PingManager,
174        out: &mut SourceSender,
175    ) -> Result<(), WebSocketSourceError> {
176        match msg {
177            Message::Pong(_) => {
178                ping_manager.record_pong();
179                Ok(())
180            }
181            Message::Text(msg_txt) => {
182                if self.is_custom_pong(&msg_txt) {
183                    ping_manager.record_pong();
184                    debug!("Received custom pong response.");
185                } else {
186                    self.process_message(&msg_txt, WebSocketKind::Text, out)
187                        .await;
188                }
189                Ok(())
190            }
191            Message::Binary(msg_bytes) => {
192                self.process_message(&msg_bytes, WebSocketKind::Binary, out)
193                    .await;
194                Ok(())
195            }
196            Message::Ping(_) => Ok(()),
197            Message::Close(frame) => self.handle_close_frame(frame),
198            Message::Frame(_) => {
199                warn!("Unsupported message type received: frame.");
200                Ok(())
201            }
202        }
203    }
204
205    async fn process_message<T>(&self, payload: &T, kind: WebSocketKind, out: &mut SourceSender)
206    where
207        T: AsRef<[u8]> + ?Sized,
208    {
209        let payload_bytes = payload.as_ref();
210
211        emit!(WebSocketBytesReceived {
212            byte_size: payload_bytes.len(),
213            url: &self.config.common.uri,
214            protocol: PROTOCOL,
215            kind,
216        });
217        let mut stream = FramedRead::new(payload_bytes, self.params.decoder.clone());
218
219        while let Some(result) = stream.next().await {
220            match result {
221                Ok((events, _)) => {
222                    if events.is_empty() {
223                        continue;
224                    }
225
226                    let event_count = events.len();
227                    let byte_size = events.estimated_json_encoded_size_of();
228
229                    register!(EventsReceived).emit(CountByteSize(event_count, byte_size));
230                    emit!(WebSocketMessageReceived {
231                        count: event_count,
232                        byte_size,
233                        url: &self.config.common.uri,
234                        protocol: PROTOCOL,
235                        kind,
236                    });
237
238                    let events_with_meta = events.into_iter().map(|mut event| {
239                        if let Event::Log(event) = &mut event {
240                            self.add_metadata(event);
241                        }
242                        event
243                    });
244
245                    if let Err(error) = out.send_batch(events_with_meta).await {
246                        error!(message = "Error sending events.", %error);
247                    }
248                }
249                Err(error) => {
250                    if !error.can_continue() {
251                        break;
252                    }
253                }
254            }
255        }
256    }
257
258    fn add_metadata(&self, event: &mut LogEvent) {
259        self.params
260            .log_namespace
261            .insert_standard_vector_source_metadata(event, WebSocketConfig::NAME, Utc::now());
262    }
263
264    async fn reconnect(
265        &self,
266        out: &mut SourceSender,
267        ws_sink: &mut WebSocketSink,
268        ws_source: &mut WebSocketStream,
269    ) -> Result<(), WebSocketSourceError> {
270        info!("Reconnecting to WebSocket...");
271
272        let (new_sink, new_source) = self.connect(out).await?;
273
274        *ws_sink = new_sink;
275        *ws_source = new_source;
276
277        info!("Reconnected to Websocket.");
278
279        Ok(())
280    }
281
282    async fn connect(
283        &self,
284        out: &mut SourceSender,
285    ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
286        let (mut ws_sink, mut ws_source) = self.try_create_sink_and_stream().await?;
287
288        if self.config.initial_message.is_some() {
289            self.send_initial_message(&mut ws_sink, &mut ws_source, out)
290                .await?;
291        }
292
293        Ok((ws_sink, ws_source))
294    }
295
296    async fn try_create_sink_and_stream(
297        &self,
298    ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
299        let ws_stream = self
300            .params
301            .connector
302            .connect_backoff_with_timeout(self.config.connect_timeout_secs)
303            .await;
304
305        let (sink, stream) = ws_stream.split();
306
307        Ok((Box::pin(sink), Box::pin(stream)))
308    }
309
310    async fn send_initial_message(
311        &self,
312        ws_sink: &mut WebSocketSink,
313        ws_source: &mut WebSocketStream,
314        out: &mut SourceSender,
315    ) -> Result<(), WebSocketSourceError> {
316        let initial_message = self.config.initial_message.as_ref().unwrap();
317        ws_sink
318            .send(Message::Text(initial_message.clone()))
319            .await
320            .map_err(|error| {
321                emit!(WebSocketSendError { error: &error });
322                WebSocketSourceError::Tungstenite { source: error }
323            })?;
324
325        debug!("Sent initial message, awaiting response from server.");
326
327        let response =
328            match time::timeout(self.config.initial_message_timeout_secs, ws_source.next()).await {
329                Ok(Some(msg)) => msg,
330                Ok(None) => fail_with_event!(ConnectionClosedPrematurelySnafu),
331                Err(_) => fail_with_event!(InitialMessageTimeoutSnafu),
332            };
333
334        let message = response.map_err(|source| {
335            emit!(WebSocketReceiveError { error: &source });
336            WebSocketSourceError::Tungstenite { source }
337        })?;
338
339        match message {
340            Message::Text(txt) => {
341                self.process_message(&txt, WebSocketKind::Text, out).await;
342                Ok(())
343            }
344            Message::Binary(bin) => {
345                self.process_message(&bin, WebSocketKind::Binary, out).await;
346                Ok(())
347            }
348            Message::Close(frame) => self.handle_close_frame(frame),
349            _ => Ok(()),
350        }
351    }
352
353    fn handle_close_frame(
354        &self,
355        frame: Option<CloseFrame<'_>>,
356    ) -> Result<(), WebSocketSourceError> {
357        let (error_message, specific_error) = match frame {
358            Some(frame) => {
359                let msg = format!(
360                    "Connection closed by server with code '{}' and reason: '{}'",
361                    frame.code, frame.reason
362                );
363                let err = WebSocketSourceError::RemoteClosed {
364                    frame: frame.into_owned(),
365                };
366                (msg, err)
367            }
368            None => (
369                "Connection closed by server without a close frame".to_string(),
370                WebSocketSourceError::RemoteClosedEmpty,
371            ),
372        };
373
374        let error = TungsteniteError::Io(std::io::Error::new(
375            std::io::ErrorKind::ConnectionAborted,
376            error_message,
377        ));
378        emit!(WebSocketReceiveError { error: &error });
379
380        Err(specific_error)
381    }
382
383    fn is_custom_pong(&self, msg_txt: &str) -> bool {
384        match self.config.pong_message.as_ref() {
385            Some(config) => config.matches(msg_txt),
386            None => false,
387        }
388    }
389}
390
391struct PingManager {
392    interval: PingInterval,
393    waiting_for_pong: bool,
394    message: Message,
395}
396
397impl PingManager {
398    fn new(config: &WebSocketConfig) -> Self {
399        let ping_message = if let Some(ping_msg) = &config.ping_message {
400            Message::Text(ping_msg.clone())
401        } else {
402            Message::Ping(vec![])
403        };
404
405        Self {
406            interval: PingInterval::new(config.common.ping_interval.map(u64::from)),
407            waiting_for_pong: false,
408            message: ping_message,
409        }
410    }
411
412    const fn record_pong(&mut self) {
413        self.waiting_for_pong = false;
414    }
415
416    async fn tick(&mut self, ws_sink: &mut WebSocketSink) -> Result<(), WebSocketSourceError> {
417        self.interval.tick().await;
418
419        if self.waiting_for_pong {
420            return Err(WebSocketSourceError::PongTimeout);
421        }
422
423        ws_sink.send(self.message.clone()).await.map_err(|error| {
424            emit!(WebSocketSendError { error: &error });
425            WebSocketSourceError::Tungstenite { source: error }
426        })?;
427
428        self.waiting_for_pong = true;
429        Ok(())
430    }
431}
432
433#[cfg(test)]
434mod tests {
435    use std::{borrow::Cow, num::NonZeroU64};
436
437    use futures::{StreamExt, sink::SinkExt};
438    use tokio::{net::TcpListener, time::Duration};
439    use tokio_tungstenite::{
440        accept_async,
441        tungstenite::{
442            Message,
443            protocol::frame::{CloseFrame, coding::CloseCode},
444        },
445    };
446    use url::Url;
447    use vector_lib::codecs::decoding::DeserializerConfig;
448
449    use crate::{
450        common::websocket::WebSocketCommonConfig,
451        sources::websocket::config::{PongMessage, WebSocketConfig},
452        test_util::{
453            addr::next_addr,
454            components::{
455                SOURCE_TAGS, run_and_assert_source_compliance, run_and_assert_source_error,
456            },
457        },
458    };
459
460    fn make_config(uri: &str) -> WebSocketConfig {
461        WebSocketConfig {
462            common: WebSocketCommonConfig {
463                uri: Url::parse(uri).unwrap().to_string(),
464                ..Default::default()
465            },
466            ..Default::default()
467        }
468    }
469
470    /// Starts a WebSocket server that pushes a binary message to the first client.
471    async fn start_binary_push_server() -> String {
472        let (_guard, addr) = next_addr();
473        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
474        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
475
476        tokio::spawn(async move {
477            let (stream, _) = listener.accept().await.unwrap();
478            let mut websocket = accept_async(stream).await.expect("Failed to accept");
479
480            let binary_payload = br#"{"message": "binary data"}"#.to_vec();
481            websocket
482                .send(Message::Binary(binary_payload))
483                .await
484                .unwrap();
485        });
486
487        server_addr
488    }
489
490    /// Starts a WebSocket server that pushes a message to the first client that connects.
491    async fn start_push_server() -> String {
492        let (_guard, addr) = next_addr();
493        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
494        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
495
496        tokio::spawn(async move {
497            // Accept one connection
498            let (stream, _) = listener.accept().await.unwrap();
499            let mut websocket = accept_async(stream).await.expect("Failed to accept");
500
501            // Immediately send a message to the connected client (which will be our source)
502            websocket
503                .send(Message::Text("message from server".to_string()))
504                .await
505                .unwrap();
506        });
507
508        server_addr
509    }
510
511    /// Starts a WebSocket server that waits for an initial message from the client,
512    /// and upon receiving it, sends a confirmation message back.
513    async fn start_subscribe_server(initial_message: String, response_message: String) -> String {
514        let (_guard, addr) = next_addr();
515        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
516        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
517
518        tokio::spawn(async move {
519            let (stream, _) = listener.accept().await.unwrap();
520            let mut websocket = accept_async(stream).await.expect("Failed to accept");
521
522            // Wait for the initial message from the client
523            if let Some(Ok(Message::Text(msg))) = websocket.next().await
524                && msg == initial_message
525            {
526                // Received correct initial message, send response
527                websocket
528                    .send(Message::Text(response_message))
529                    .await
530                    .unwrap();
531            }
532        });
533
534        server_addr
535    }
536
537    async fn start_reconnect_server() -> String {
538        let (_guard, addr) = next_addr();
539        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
540        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
541
542        tokio::spawn(async move {
543            // First connection
544            let (stream, _) = listener.accept().await.unwrap();
545            let mut websocket = accept_async(stream).await.expect("Failed to accept");
546            websocket
547                .send(Message::Text("first message".to_string()))
548                .await
549                .unwrap();
550            // Close the connection to force a reconnect from the client
551            websocket.close(None).await.unwrap();
552
553            // Second connection
554            let (stream, _) = listener.accept().await.unwrap();
555            let mut websocket = accept_async(stream).await.expect("Failed to accept");
556            websocket
557                .send(Message::Text("second message".to_string()))
558                .await
559                .unwrap();
560        });
561
562        server_addr
563    }
564
565    #[tokio::test(flavor = "multi_thread")]
566    async fn websocket_source_reconnects_after_disconnect() {
567        let server_addr = start_reconnect_server().await;
568        let config = make_config(&server_addr);
569
570        // Run for a longer duration to allow for reconnection
571        let events =
572            run_and_assert_source_compliance(config, Duration::from_secs(5), &SOURCE_TAGS).await;
573
574        assert_eq!(
575            events.len(),
576            2,
577            "Should have received messages from both connections"
578        );
579
580        let event = events[0].as_log();
581        assert_eq!(event["message"], "first message".into());
582
583        let event = events[1].as_log();
584        assert_eq!(event["message"], "second message".into());
585    }
586
587    #[tokio::test(flavor = "multi_thread")]
588    async fn websocket_source_consume_binary_event() {
589        let server_addr = start_binary_push_server().await;
590        let mut config = make_config(&server_addr);
591        let decoding = DeserializerConfig::Json(Default::default());
592        config.decoding = decoding;
593
594        let events =
595            run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
596
597        assert!(!events.is_empty(), "No events received from source");
598        let event = events[0].as_log();
599        assert_eq!(event["message"], "binary data".into());
600        assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
601    }
602
603    #[tokio::test(flavor = "multi_thread")]
604    async fn websocket_source_consume_event() {
605        let server_addr = start_push_server().await;
606        let config = make_config(&server_addr);
607
608        // Run the source, which will connect to the server and receive the pushed message.
609        let events =
610            run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
611
612        // Now assert that the event was received and is correct.
613        assert!(!events.is_empty(), "No events received from source");
614        let event = events[0].as_log();
615        assert_eq!(event["message"], "message from server".into());
616        assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
617    }
618
619    #[tokio::test(flavor = "multi_thread")]
620    async fn websocket_source_sends_initial_message() {
621        let initial_msg = "{\"action\":\"subscribe\",\"topic\":\"test\"}".to_string();
622        let response_msg = "{\"status\":\"subscribed\",\"topic\":\"test\"}".to_string();
623        let server_addr = start_subscribe_server(initial_msg.clone(), response_msg.clone()).await;
624
625        let mut config = make_config(&server_addr);
626        config.initial_message = Some(initial_msg);
627
628        let events =
629            run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
630
631        assert!(!events.is_empty(), "No events received from source");
632        let event = events[0].as_log();
633        assert_eq!(event["message"], response_msg.into());
634        assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
635    }
636
637    async fn start_reject_initial_message_server() -> String {
638        let (_guard, addr) = next_addr();
639        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
640        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
641
642        tokio::spawn(async move {
643            let (stream, _) = listener.accept().await.unwrap();
644            let mut websocket = accept_async(stream).await.expect("Failed to accept");
645
646            if websocket.next().await.is_some() {
647                let close_frame = CloseFrame {
648                    code: CloseCode::Error,
649                    reason: Cow::from("Simulated Internal Server Error"),
650                };
651                let _ = websocket.close(Some(close_frame)).await;
652            }
653        });
654
655        server_addr
656    }
657
658    #[tokio::test(flavor = "multi_thread")]
659    async fn websocket_source_exits_on_rejected_intial_messsage() {
660        let server_addr = start_reject_initial_message_server().await;
661
662        let mut config = make_config(&server_addr);
663        config.initial_message = Some("hello, server!".to_string());
664        config.initial_message_timeout_secs = Duration::from_secs(1);
665
666        run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
667    }
668
669    async fn start_unresponsive_server() -> String {
670        let (_guard, addr) = next_addr();
671        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
672        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
673
674        tokio::spawn(async move {
675            if let Ok((stream, _)) = listener.accept().await {
676                // Accept the connection to establish the WebSocket.
677                let mut websocket = accept_async(stream).await.expect("Failed to accept");
678                // Simply wait forever without responding to pings.
679                while websocket.next().await.is_some() {
680                    // Do nothing
681                }
682            }
683        });
684
685        server_addr
686    }
687
688    #[tokio::test(flavor = "multi_thread")]
689    async fn websocket_source_exits_on_pong_timeout() {
690        let server_addr = start_unresponsive_server().await;
691
692        let mut config = make_config(&server_addr);
693        config.common.ping_interval = NonZeroU64::new(3);
694        config.common.ping_timeout = NonZeroU64::new(1);
695        config.ping_message = Some("ping".to_string());
696        config.pong_message = Some(PongMessage::Simple("pong".to_string()));
697
698        // The source should fail because the server never sends a pong.
699        run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
700    }
701
702    async fn start_blackhole_server() -> String {
703        let (_guard, addr) = next_addr();
704        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
705        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
706
707        tokio::spawn(async move {
708            let (mut _socket, _) = listener.accept().await.unwrap();
709            tokio::time::sleep(Duration::from_secs(10)).await;
710        });
711
712        server_addr
713    }
714
715    #[tokio::test(flavor = "multi_thread")]
716    async fn websocket_source_exits_on_connection_timeout() {
717        let server_addr = start_blackhole_server().await;
718        let mut config = make_config(&server_addr);
719        config.connect_timeout_secs = Duration::from_secs(1);
720
721        run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
722    }
723}