vector/sources/websocket/
source.rs

1use std::pin::Pin;
2
3use chrono::Utc;
4use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt};
5use snafu::Snafu;
6use tokio::time;
7use tokio_tungstenite::tungstenite::{
8    Message, error::Error as TungsteniteError, protocol::CloseFrame,
9};
10use tokio_util::codec::FramedRead;
11use vector_lib::{
12    EstimatedJsonEncodedSizeOf,
13    config::LogNamespace,
14    event::{Event, LogEvent},
15    internal_event::{CountByteSize, EventsReceived, InternalEventHandle as _},
16};
17
18use crate::{
19    SourceSender,
20    codecs::Decoder,
21    common::websocket::{PingInterval, WebSocketConnector, is_closed},
22    config::SourceContext,
23    internal_events::{
24        ConnectionOpen, OpenGauge, PROTOCOL, WebSocketBytesReceived, WebSocketConnectionError,
25        WebSocketConnectionEstablished, WebSocketConnectionFailedError,
26        WebSocketConnectionShutdown, WebSocketKind, WebSocketMessageReceived,
27        WebSocketReceiveError, WebSocketSendError,
28    },
29    sources::websocket::config::WebSocketConfig,
30    vector_lib::codecs::StreamDecodingError,
31};
32
/// Emits a `WebSocketConnectionFailedError` built from the given context
/// selector and then returns early from the enclosing function with the
/// error produced by that same selector.
///
/// The selector expression is expanded twice: once for `.build()` (the error
/// attached to the event) and once for `.fail()` (the returned `Err`).
macro_rules! fail_with_event {
    ($selector:expr_2021) => {{
        emit!(WebSocketConnectionFailedError {
            error: Box::new($selector.build())
        });
        return $selector.fail();
    }};
}
41
42type WebSocketSink = Pin<Box<dyn Sink<Message, Error = TungsteniteError> + Send>>;
43type WebSocketStream = Pin<Box<dyn Stream<Item = Result<Message, TungsteniteError>> + Send>>;
44
45pub(crate) struct WebSocketSourceParams {
46    pub connector: WebSocketConnector,
47    pub decoder: Decoder,
48    pub log_namespace: LogNamespace,
49}
50
51pub(crate) struct WebSocketSource {
52    config: WebSocketConfig,
53    params: WebSocketSourceParams,
54}
55
56#[derive(Debug, Snafu)]
57pub enum WebSocketSourceError {
58    #[snafu(display("Connection attempt timed out"))]
59    ConnectTimeout,
60
61    #[snafu(display("Server did not respond to the initial message in time"))]
62    InitialMessageTimeout,
63
64    #[snafu(display(
65        "The connection was closed after sending the initial message, but before a response."
66    ))]
67    ConnectionClosedPrematurely,
68
69    #[snafu(display("Connection closed by server with code '{}' and reason: '{}'", frame.code, frame.reason))]
70    RemoteClosed { frame: CloseFrame<'static> },
71
72    #[snafu(display("Connection closed by server without a close frame"))]
73    RemoteClosedEmpty,
74
75    #[snafu(display("Connection timed out while waiting for a pong response"))]
76    PongTimeout,
77
78    #[snafu(display("A WebSocket error occurred: {}", source))]
79    Tungstenite { source: TungsteniteError },
80}
81
82impl WebSocketSource {
83    pub const fn new(config: WebSocketConfig, params: WebSocketSourceParams) -> Self {
84        Self { config, params }
85    }
86
87    pub async fn run(self, cx: SourceContext) -> Result<(), WebSocketSourceError> {
88        let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
89        let mut ping_manager = PingManager::new(&self.config);
90
91        let mut out = cx.out;
92
93        let (ws_sink, ws_source) = self.connect(&mut out).await?;
94
95        pin_mut!(ws_sink, ws_source);
96
97        loop {
98            let result = tokio::select! {
99                _ = cx.shutdown.clone() => {
100                    info!("Received shutdown signal.");
101                    break;
102                },
103
104                res = ping_manager.tick(&mut ws_sink) => res,
105
106                Some(msg_result) = ws_source.next() => {
107                    match msg_result {
108                        Ok(msg) => self.handle_message(msg, &mut ping_manager, &mut out).await,
109                        Err(error) => {
110                            emit!(WebSocketReceiveError { error: &error });
111                            Err(WebSocketSourceError::Tungstenite { source: error })
112                        }
113                    }
114                }
115            };
116
117            if let Err(error) = result {
118                match error {
119                    WebSocketSourceError::RemoteClosed { frame } => {
120                        warn!(
121                            message = "Connection closed by server.",
122                            code = %frame.code,
123                            reason = %frame.reason
124                        );
125                        emit!(WebSocketConnectionShutdown);
126                    }
127                    WebSocketSourceError::RemoteClosedEmpty => {
128                        warn!("Connection closed by server without a close frame.");
129                        emit!(WebSocketConnectionShutdown);
130                    }
131                    WebSocketSourceError::PongTimeout => {
132                        error!("Disconnecting due to pong timeout.");
133                        emit!(WebSocketReceiveError {
134                            error: &TungsteniteError::Io(std::io::Error::new(
135                                std::io::ErrorKind::TimedOut,
136                                "Pong timeout"
137                            ))
138                        });
139                        emit!(WebSocketConnectionShutdown);
140                        return Err(error);
141                    }
142                    WebSocketSourceError::Tungstenite { source: ws_err } => {
143                        if is_closed(&ws_err) {
144                            emit!(WebSocketConnectionShutdown);
145                        }
146                        error!(message = "WebSocket connection error.", error = %ws_err);
147                    }
148                    // These errors should only happen during `connect` or `reconnect`,
149                    // not in the main loop's result.
150                    WebSocketSourceError::ConnectTimeout
151                    | WebSocketSourceError::InitialMessageTimeout
152                    | WebSocketSourceError::ConnectionClosedPrematurely => {
153                        unreachable!(
154                            "Encountered a connection-time error during runtime: {:?}",
155                            error
156                        );
157                    }
158                }
159                if self
160                    .reconnect(&mut out, &mut ws_sink, &mut ws_source)
161                    .await
162                    .is_err()
163                {
164                    break;
165                }
166            }
167        }
168        Ok(())
169    }
170
171    async fn handle_message(
172        &self,
173        msg: Message,
174        ping_manager: &mut PingManager,
175        out: &mut SourceSender,
176    ) -> Result<(), WebSocketSourceError> {
177        match msg {
178            Message::Pong(_) => {
179                ping_manager.record_pong();
180                Ok(())
181            }
182            Message::Text(msg_txt) => {
183                if self.is_custom_pong(&msg_txt) {
184                    ping_manager.record_pong();
185                    debug!("Received custom pong response.");
186                } else {
187                    self.process_message(&msg_txt, WebSocketKind::Text, out)
188                        .await;
189                }
190                Ok(())
191            }
192            Message::Binary(msg_bytes) => {
193                self.process_message(&msg_bytes, WebSocketKind::Binary, out)
194                    .await;
195                Ok(())
196            }
197            Message::Ping(_) => Ok(()),
198            Message::Close(frame) => self.handle_close_frame(frame),
199            Message::Frame(_) => {
200                warn!("Unsupported message type received: frame.");
201                Ok(())
202            }
203        }
204    }
205
206    async fn process_message<T>(&self, payload: &T, kind: WebSocketKind, out: &mut SourceSender)
207    where
208        T: AsRef<[u8]> + ?Sized,
209    {
210        let payload_bytes = payload.as_ref();
211
212        emit!(WebSocketBytesReceived {
213            byte_size: payload_bytes.len(),
214            url: &self.config.common.uri,
215            protocol: PROTOCOL,
216            kind,
217        });
218        let mut stream = FramedRead::new(payload_bytes, self.params.decoder.clone());
219
220        while let Some(result) = stream.next().await {
221            match result {
222                Ok((events, _)) => {
223                    if events.is_empty() {
224                        continue;
225                    }
226
227                    let event_count = events.len();
228                    let byte_size = events.estimated_json_encoded_size_of();
229
230                    register!(EventsReceived).emit(CountByteSize(event_count, byte_size));
231                    emit!(WebSocketMessageReceived {
232                        count: event_count,
233                        byte_size,
234                        url: &self.config.common.uri,
235                        protocol: PROTOCOL,
236                        kind,
237                    });
238
239                    let events_with_meta = events.into_iter().map(|mut event| {
240                        if let Event::Log(event) = &mut event {
241                            self.add_metadata(event);
242                        }
243                        event
244                    });
245
246                    if let Err(error) = out.send_batch(events_with_meta).await {
247                        error!(message = "Error sending events.", %error);
248                    }
249                }
250                Err(error) => {
251                    if !error.can_continue() {
252                        break;
253                    }
254                }
255            }
256        }
257    }
258
259    fn add_metadata(&self, event: &mut LogEvent) {
260        self.params
261            .log_namespace
262            .insert_standard_vector_source_metadata(event, WebSocketConfig::NAME, Utc::now());
263    }
264
265    async fn reconnect(
266        &self,
267        out: &mut SourceSender,
268        ws_sink: &mut WebSocketSink,
269        ws_source: &mut WebSocketStream,
270    ) -> Result<(), WebSocketSourceError> {
271        info!("Reconnecting to WebSocket...");
272
273        let (new_sink, new_source) = self.connect(out).await?;
274
275        *ws_sink = new_sink;
276        *ws_source = new_source;
277
278        info!("Reconnected to Websocket.");
279
280        Ok(())
281    }
282
283    async fn connect(
284        &self,
285        out: &mut SourceSender,
286    ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
287        let (mut ws_sink, mut ws_source) = self.try_create_sink_and_stream().await?;
288
289        if self.config.initial_message.is_some() {
290            self.send_initial_message(&mut ws_sink, &mut ws_source, out)
291                .await?;
292        }
293
294        Ok((ws_sink, ws_source))
295    }
296
297    async fn try_create_sink_and_stream(
298        &self,
299    ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
300        let connect_future = self.params.connector.connect_backoff();
301        let timeout = self.config.connect_timeout_secs;
302
303        let ws_stream = match time::timeout(timeout, connect_future).await {
304            Ok(ws) => ws,
305            Err(_) => {
306                emit!(WebSocketConnectionError {
307                    error: TungsteniteError::Io(std::io::Error::new(
308                        std::io::ErrorKind::TimedOut,
309                        "Connection attempt timed out",
310                    ))
311                });
312                return Err(WebSocketSourceError::ConnectTimeout);
313            }
314        };
315
316        emit!(WebSocketConnectionEstablished {});
317        let (sink, stream) = ws_stream.split();
318
319        Ok((Box::pin(sink), Box::pin(stream)))
320    }
321
322    async fn send_initial_message(
323        &self,
324        ws_sink: &mut WebSocketSink,
325        ws_source: &mut WebSocketStream,
326        out: &mut SourceSender,
327    ) -> Result<(), WebSocketSourceError> {
328        let initial_message = self.config.initial_message.as_ref().unwrap();
329        ws_sink
330            .send(Message::Text(initial_message.clone()))
331            .await
332            .map_err(|error| {
333                emit!(WebSocketSendError { error: &error });
334                WebSocketSourceError::Tungstenite { source: error }
335            })?;
336
337        debug!("Sent initial message, awaiting response from server.");
338
339        let response =
340            match time::timeout(self.config.initial_message_timeout_secs, ws_source.next()).await {
341                Ok(Some(msg)) => msg,
342                Ok(None) => fail_with_event!(ConnectionClosedPrematurelySnafu),
343                Err(_) => fail_with_event!(InitialMessageTimeoutSnafu),
344            };
345
346        let message = response.map_err(|source| {
347            emit!(WebSocketReceiveError { error: &source });
348            WebSocketSourceError::Tungstenite { source }
349        })?;
350
351        match message {
352            Message::Text(txt) => {
353                self.process_message(&txt, WebSocketKind::Text, out).await;
354                Ok(())
355            }
356            Message::Binary(bin) => {
357                self.process_message(&bin, WebSocketKind::Binary, out).await;
358                Ok(())
359            }
360            Message::Close(frame) => self.handle_close_frame(frame),
361            _ => Ok(()),
362        }
363    }
364
365    fn handle_close_frame(
366        &self,
367        frame: Option<CloseFrame<'_>>,
368    ) -> Result<(), WebSocketSourceError> {
369        let (error_message, specific_error) = match frame {
370            Some(frame) => {
371                let msg = format!(
372                    "Connection closed by server with code '{}' and reason: '{}'",
373                    frame.code, frame.reason
374                );
375                let err = WebSocketSourceError::RemoteClosed {
376                    frame: frame.into_owned(),
377                };
378                (msg, err)
379            }
380            None => (
381                "Connection closed by server without a close frame".to_string(),
382                WebSocketSourceError::RemoteClosedEmpty,
383            ),
384        };
385
386        let error = TungsteniteError::Io(std::io::Error::new(
387            std::io::ErrorKind::ConnectionAborted,
388            error_message,
389        ));
390        emit!(WebSocketReceiveError { error: &error });
391
392        Err(specific_error)
393    }
394
395    fn is_custom_pong(&self, msg_txt: &str) -> bool {
396        match self.config.pong_message.as_ref() {
397            Some(config) => config.matches(msg_txt),
398            None => false,
399        }
400    }
401}
402
403struct PingManager {
404    interval: PingInterval,
405    waiting_for_pong: bool,
406    message: Message,
407}
408
409impl PingManager {
410    fn new(config: &WebSocketConfig) -> Self {
411        let ping_message = if let Some(ping_msg) = &config.ping_message {
412            Message::Text(ping_msg.clone())
413        } else {
414            Message::Ping(vec![])
415        };
416
417        Self {
418            interval: PingInterval::new(config.common.ping_interval.map(u64::from)),
419            waiting_for_pong: false,
420            message: ping_message,
421        }
422    }
423
424    const fn record_pong(&mut self) {
425        self.waiting_for_pong = false;
426    }
427
428    async fn tick(&mut self, ws_sink: &mut WebSocketSink) -> Result<(), WebSocketSourceError> {
429        self.interval.tick().await;
430
431        if self.waiting_for_pong {
432            return Err(WebSocketSourceError::PongTimeout);
433        }
434
435        ws_sink.send(self.message.clone()).await.map_err(|error| {
436            emit!(WebSocketSendError { error: &error });
437            WebSocketSourceError::Tungstenite { source: error }
438        })?;
439
440        self.waiting_for_pong = true;
441        Ok(())
442    }
443}
444
445#[cfg(test)]
446mod tests {
447    use std::{borrow::Cow, num::NonZeroU64};
448
449    use futures::{StreamExt, sink::SinkExt};
450    use tokio::{net::TcpListener, time::Duration};
451    use tokio_tungstenite::{
452        accept_async,
453        tungstenite::{
454            Message,
455            protocol::frame::{CloseFrame, coding::CloseCode},
456        },
457    };
458    use url::Url;
459    use vector_lib::codecs::decoding::DeserializerConfig;
460
461    use crate::{
462        common::websocket::WebSocketCommonConfig,
463        sources::websocket::config::{PongMessage, WebSocketConfig},
464        test_util::{
465            components::{
466                SOURCE_TAGS, run_and_assert_source_compliance, run_and_assert_source_error,
467            },
468            next_addr,
469        },
470    };
471
472    fn make_config(uri: &str) -> WebSocketConfig {
473        WebSocketConfig {
474            common: WebSocketCommonConfig {
475                uri: Url::parse(uri).unwrap().to_string(),
476                ..Default::default()
477            },
478            ..Default::default()
479        }
480    }
481
482    /// Starts a WebSocket server that pushes a binary message to the first client.
483    async fn start_binary_push_server() -> String {
484        let addr = next_addr();
485        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
486        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
487
488        tokio::spawn(async move {
489            let (stream, _) = listener.accept().await.unwrap();
490            let mut websocket = accept_async(stream).await.expect("Failed to accept");
491
492            let binary_payload = br#"{"message": "binary data"}"#.to_vec();
493            websocket
494                .send(Message::Binary(binary_payload))
495                .await
496                .unwrap();
497        });
498
499        server_addr
500    }
501
502    /// Starts a WebSocket server that pushes a message to the first client that connects.
503    async fn start_push_server() -> String {
504        let addr = next_addr();
505        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
506        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
507
508        tokio::spawn(async move {
509            // Accept one connection
510            let (stream, _) = listener.accept().await.unwrap();
511            let mut websocket = accept_async(stream).await.expect("Failed to accept");
512
513            // Immediately send a message to the connected client (which will be our source)
514            websocket
515                .send(Message::Text("message from server".to_string()))
516                .await
517                .unwrap();
518        });
519
520        server_addr
521    }
522
523    /// Starts a WebSocket server that waits for an initial message from the client,
524    /// and upon receiving it, sends a confirmation message back.
525    async fn start_subscribe_server(initial_message: String, response_message: String) -> String {
526        let addr = next_addr();
527        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
528        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
529
530        tokio::spawn(async move {
531            let (stream, _) = listener.accept().await.unwrap();
532            let mut websocket = accept_async(stream).await.expect("Failed to accept");
533
534            // Wait for the initial message from the client
535            if let Some(Ok(Message::Text(msg))) = websocket.next().await
536                && msg == initial_message
537            {
538                // Received correct initial message, send response
539                websocket
540                    .send(Message::Text(response_message))
541                    .await
542                    .unwrap();
543            }
544        });
545
546        server_addr
547    }
548
549    async fn start_reconnect_server() -> String {
550        let addr = next_addr();
551        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
552        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
553
554        tokio::spawn(async move {
555            // First connection
556            let (stream, _) = listener.accept().await.unwrap();
557            let mut websocket = accept_async(stream).await.expect("Failed to accept");
558            websocket
559                .send(Message::Text("first message".to_string()))
560                .await
561                .unwrap();
562            // Close the connection to force a reconnect from the client
563            websocket.close(None).await.unwrap();
564
565            // Second connection
566            let (stream, _) = listener.accept().await.unwrap();
567            let mut websocket = accept_async(stream).await.expect("Failed to accept");
568            websocket
569                .send(Message::Text("second message".to_string()))
570                .await
571                .unwrap();
572        });
573
574        server_addr
575    }
576
577    #[tokio::test(flavor = "multi_thread")]
578    async fn websocket_source_reconnects_after_disconnect() {
579        let server_addr = start_reconnect_server().await;
580        let config = make_config(&server_addr);
581
582        // Run for a longer duration to allow for reconnection
583        let events =
584            run_and_assert_source_compliance(config, Duration::from_secs(5), &SOURCE_TAGS).await;
585
586        assert_eq!(
587            events.len(),
588            2,
589            "Should have received messages from both connections"
590        );
591
592        let event = events[0].as_log();
593        assert_eq!(event["message"], "first message".into());
594
595        let event = events[1].as_log();
596        assert_eq!(event["message"], "second message".into());
597    }
598
599    #[tokio::test(flavor = "multi_thread")]
600    async fn websocket_source_consume_binary_event() {
601        let server_addr = start_binary_push_server().await;
602        let mut config = make_config(&server_addr);
603        let decoding = DeserializerConfig::Json(Default::default());
604        config.decoding = decoding;
605
606        let events =
607            run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
608
609        assert!(!events.is_empty(), "No events received from source");
610        let event = events[0].as_log();
611        assert_eq!(event["message"], "binary data".into());
612        assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
613    }
614
615    #[tokio::test(flavor = "multi_thread")]
616    async fn websocket_source_consume_event() {
617        let server_addr = start_push_server().await;
618        let config = make_config(&server_addr);
619
620        // Run the source, which will connect to the server and receive the pushed message.
621        let events =
622            run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
623
624        // Now assert that the event was received and is correct.
625        assert!(!events.is_empty(), "No events received from source");
626        let event = events[0].as_log();
627        assert_eq!(event["message"], "message from server".into());
628        assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
629    }
630
631    #[tokio::test(flavor = "multi_thread")]
632    async fn websocket_source_sends_initial_message() {
633        let initial_msg = "{\"action\":\"subscribe\",\"topic\":\"test\"}".to_string();
634        let response_msg = "{\"status\":\"subscribed\",\"topic\":\"test\"}".to_string();
635        let server_addr = start_subscribe_server(initial_msg.clone(), response_msg.clone()).await;
636
637        let mut config = make_config(&server_addr);
638        config.initial_message = Some(initial_msg);
639
640        let events =
641            run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
642
643        assert!(!events.is_empty(), "No events received from source");
644        let event = events[0].as_log();
645        assert_eq!(event["message"], response_msg.into());
646        assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
647    }
648
649    async fn start_reject_initial_message_server() -> String {
650        let addr = next_addr();
651        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
652        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
653
654        tokio::spawn(async move {
655            let (stream, _) = listener.accept().await.unwrap();
656            let mut websocket = accept_async(stream).await.expect("Failed to accept");
657
658            if websocket.next().await.is_some() {
659                let close_frame = CloseFrame {
660                    code: CloseCode::Error,
661                    reason: Cow::from("Simulated Internal Server Error"),
662                };
663                let _ = websocket.close(Some(close_frame)).await;
664            }
665        });
666
667        server_addr
668    }
669
670    #[tokio::test(flavor = "multi_thread")]
671    async fn websocket_source_exits_on_rejected_intial_messsage() {
672        let server_addr = start_reject_initial_message_server().await;
673
674        let mut config = make_config(&server_addr);
675        config.initial_message = Some("hello, server!".to_string());
676        config.initial_message_timeout_secs = Duration::from_secs(1);
677
678        run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
679    }
680
681    async fn start_unresponsive_server() -> String {
682        let addr = next_addr();
683        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
684        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
685
686        tokio::spawn(async move {
687            if let Ok((stream, _)) = listener.accept().await {
688                // Accept the connection to establish the WebSocket.
689                let mut websocket = accept_async(stream).await.expect("Failed to accept");
690                // Simply wait forever without responding to pings.
691                while websocket.next().await.is_some() {
692                    // Do nothing
693                }
694            }
695        });
696
697        server_addr
698    }
699
700    #[tokio::test(flavor = "multi_thread")]
701    async fn websocket_source_exits_on_pong_timeout() {
702        let server_addr = start_unresponsive_server().await;
703
704        let mut config = make_config(&server_addr);
705        config.common.ping_interval = NonZeroU64::new(3);
706        config.common.ping_timeout = NonZeroU64::new(1);
707        config.ping_message = Some("ping".to_string());
708        config.pong_message = Some(PongMessage::Simple("pong".to_string()));
709
710        // The source should fail because the server never sends a pong.
711        run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
712    }
713
714    async fn start_blackhole_server() -> String {
715        let addr = next_addr();
716        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
717        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
718
719        tokio::spawn(async move {
720            let (mut _socket, _) = listener.accept().await.unwrap();
721            tokio::time::sleep(Duration::from_secs(10)).await;
722        });
723
724        server_addr
725    }
726
727    #[tokio::test(flavor = "multi_thread")]
728    async fn websocket_source_exits_on_connection_timeout() {
729        let server_addr = start_blackhole_server().await;
730        let mut config = make_config(&server_addr);
731        config.connect_timeout_secs = Duration::from_secs(1);
732
733        run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
734    }
735}