vector/sources/websocket/
source.rs

1use std::pin::Pin;
2
3use chrono::{DateTime, Utc};
4use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt};
5use snafu::Snafu;
6use tokio::time;
7use tokio_tungstenite::tungstenite::{
8    Message, error::Error as TungsteniteError, protocol::CloseFrame,
9};
10use vector_lib::{
11    EstimatedJsonEncodedSizeOf,
12    codecs::DecoderFramedRead,
13    config::LogNamespace,
14    event::{Event, LogEvent},
15    internal_event::{CountByteSize, EventsReceived, InternalEventHandle as _},
16};
17
18use crate::{
19    SourceSender,
20    codecs::Decoder,
21    common::websocket::{PingInterval, WebSocketConnector, is_closed},
22    config::SourceContext,
23    internal_events::{
24        ConnectionOpen, OpenGauge, PROTOCOL, WebSocketBytesReceived,
25        WebSocketConnectionFailedError, WebSocketConnectionShutdown, WebSocketKind,
26        WebSocketMessageReceived, WebSocketReceiveError, WebSocketSendError,
27    },
28    sources::websocket::config::WebSocketConfig,
29    vector_lib::codecs::StreamDecodingError,
30};
31
/// Reports a connection failure and bails out of the enclosing function.
///
/// Takes a Snafu context selector, emits a `WebSocketConnectionFailedError`
/// carrying the error built from that selector, and then `return`s the
/// selector's failure from the function in which the macro is expanded.
macro_rules! fail_with_event {
    ($selector:expr_2021) => {{
        emit!(WebSocketConnectionFailedError {
            error: Box::new($selector.build())
        });
        // Early-return from the *calling* function, not just from this block.
        return $selector.fail();
    }};
}
40
41type WebSocketSink = Pin<Box<dyn Sink<Message, Error = TungsteniteError> + Send>>;
42type WebSocketStream = Pin<Box<dyn Stream<Item = Result<Message, TungsteniteError>> + Send>>;
43
44pub(crate) struct WebSocketSourceParams {
45    pub connector: WebSocketConnector,
46    pub decoder: Decoder,
47    pub log_namespace: LogNamespace,
48}
49
50pub(crate) struct WebSocketSource {
51    config: WebSocketConfig,
52    params: WebSocketSourceParams,
53}
54
55#[derive(Debug, Snafu)]
56pub enum WebSocketSourceError {
57    #[snafu(display("Connection attempt timed out"))]
58    ConnectTimeout,
59
60    #[snafu(display("Server did not respond to the initial message in time"))]
61    InitialMessageTimeout,
62
63    #[snafu(display(
64        "The connection was closed after sending the initial message, but before a response."
65    ))]
66    ConnectionClosedPrematurely,
67
68    #[snafu(display("Connection closed by server with code '{}' and reason: '{}'", frame.code, frame.reason))]
69    RemoteClosed { frame: CloseFrame<'static> },
70
71    #[snafu(display("Connection closed by server without a close frame"))]
72    RemoteClosedEmpty,
73
74    #[snafu(display("Connection timed out while waiting for a pong response"))]
75    PongTimeout,
76
77    #[snafu(display("A WebSocket error occurred: {}", source))]
78    Tungstenite { source: TungsteniteError },
79}
80
81impl WebSocketSource {
82    pub const fn new(config: WebSocketConfig, params: WebSocketSourceParams) -> Self {
83        Self { config, params }
84    }
85
86    pub async fn run(self, cx: SourceContext) -> Result<(), WebSocketSourceError> {
87        let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
88        let mut ping_manager = PingManager::new(&self.config);
89
90        let mut out = cx.out;
91
92        let (ws_sink, ws_source) = self.connect(&mut out).await?;
93
94        pin_mut!(ws_sink, ws_source);
95
96        loop {
97            let result = tokio::select! {
98                _ = cx.shutdown.clone() => {
99                    info!("Received shutdown signal.");
100                    break;
101                },
102
103                res = ping_manager.tick(&mut ws_sink) => res,
104
105                Some(msg_result) = ws_source.next() => {
106                    match msg_result {
107                        Ok(msg) => self.handle_message(msg, &mut ping_manager, &mut out).await,
108                        Err(error) => {
109                            emit!(WebSocketReceiveError { error: &error });
110                            Err(WebSocketSourceError::Tungstenite { source: error })
111                        }
112                    }
113                }
114            };
115
116            if let Err(error) = result {
117                match error {
118                    WebSocketSourceError::RemoteClosed { frame } => {
119                        warn!(
120                            message = "Connection closed by server.",
121                            code = %frame.code,
122                            reason = %frame.reason
123                        );
124                        emit!(WebSocketConnectionShutdown);
125                    }
126                    WebSocketSourceError::RemoteClosedEmpty => {
127                        warn!("Connection closed by server without a close frame.");
128                        emit!(WebSocketConnectionShutdown);
129                    }
130                    WebSocketSourceError::PongTimeout => {
131                        error!("Disconnecting due to pong timeout.");
132                        emit!(WebSocketReceiveError {
133                            error: &TungsteniteError::Io(std::io::Error::new(
134                                std::io::ErrorKind::TimedOut,
135                                "Pong timeout"
136                            ))
137                        });
138                        emit!(WebSocketConnectionShutdown);
139                        return Err(error);
140                    }
141                    WebSocketSourceError::Tungstenite { source: ws_err } => {
142                        if is_closed(&ws_err) {
143                            emit!(WebSocketConnectionShutdown);
144                        }
145                        error!(message = "WebSocket connection error.", error = %ws_err);
146                    }
147                    // These errors should only happen during `connect` or `reconnect`,
148                    // not in the main loop's result.
149                    WebSocketSourceError::ConnectTimeout
150                    | WebSocketSourceError::InitialMessageTimeout
151                    | WebSocketSourceError::ConnectionClosedPrematurely => {
152                        unreachable!(
153                            "Encountered a connection-time error during runtime: {:?}",
154                            error
155                        );
156                    }
157                }
158                if self
159                    .reconnect(&mut out, &mut ws_sink, &mut ws_source)
160                    .await
161                    .is_err()
162                {
163                    break;
164                }
165            }
166        }
167        Ok(())
168    }
169
170    async fn handle_message(
171        &self,
172        msg: Message,
173        ping_manager: &mut PingManager,
174        out: &mut SourceSender,
175    ) -> Result<(), WebSocketSourceError> {
176        match msg {
177            Message::Pong(_) => {
178                ping_manager.record_pong();
179                Ok(())
180            }
181            Message::Text(msg_txt) => {
182                if self.is_custom_pong(&msg_txt) {
183                    ping_manager.record_pong();
184                    debug!("Received custom pong response.");
185                } else {
186                    self.process_message(&msg_txt, WebSocketKind::Text, out)
187                        .await;
188                }
189                Ok(())
190            }
191            Message::Binary(msg_bytes) => {
192                self.process_message(&msg_bytes, WebSocketKind::Binary, out)
193                    .await;
194                Ok(())
195            }
196            Message::Ping(_) => Ok(()),
197            Message::Close(frame) => self.handle_close_frame(frame),
198            Message::Frame(_) => {
199                warn!("Unsupported message type received: frame.");
200                Ok(())
201            }
202        }
203    }
204
205    async fn process_message<T>(&self, payload: &T, kind: WebSocketKind, out: &mut SourceSender)
206    where
207        T: AsRef<[u8]> + ?Sized,
208    {
209        let payload_bytes = payload.as_ref();
210
211        emit!(WebSocketBytesReceived {
212            byte_size: payload_bytes.len(),
213            url: &self.config.common.uri,
214            protocol: PROTOCOL,
215            kind,
216        });
217        let mut stream = DecoderFramedRead::new(payload_bytes, self.params.decoder.clone());
218
219        while let Some(result) = stream.next().await {
220            match result {
221                Ok((events, _)) => {
222                    if events.is_empty() {
223                        continue;
224                    }
225
226                    let event_count = events.len();
227                    let byte_size = events.estimated_json_encoded_size_of();
228
229                    register!(EventsReceived).emit(CountByteSize(event_count, byte_size));
230                    emit!(WebSocketMessageReceived {
231                        count: event_count,
232                        byte_size,
233                        url: &self.config.common.uri,
234                        protocol: PROTOCOL,
235                        kind,
236                    });
237
238                    let now = Utc::now();
239                    let events_with_meta = events.into_iter().map(|mut event| {
240                        if let Event::Log(event) = &mut event {
241                            self.add_metadata(event, now);
242                        }
243                        event
244                    });
245
246                    if let Err(error) = out.send_batch(events_with_meta).await {
247                        error!(message = "Error sending events.", %error);
248                    }
249                }
250                Err(error) => {
251                    if !error.can_continue() {
252                        break;
253                    }
254                }
255            }
256        }
257    }
258
259    fn add_metadata(&self, event: &mut LogEvent, now: DateTime<Utc>) {
260        self.params
261            .log_namespace
262            .insert_standard_vector_source_metadata(event, WebSocketConfig::NAME, now);
263    }
264
265    async fn reconnect(
266        &self,
267        out: &mut SourceSender,
268        ws_sink: &mut WebSocketSink,
269        ws_source: &mut WebSocketStream,
270    ) -> Result<(), WebSocketSourceError> {
271        info!("Reconnecting to WebSocket...");
272
273        let (new_sink, new_source) = self.connect(out).await?;
274
275        *ws_sink = new_sink;
276        *ws_source = new_source;
277
278        info!("Reconnected to Websocket.");
279
280        Ok(())
281    }
282
283    async fn connect(
284        &self,
285        out: &mut SourceSender,
286    ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
287        let (mut ws_sink, mut ws_source) = self.try_create_sink_and_stream().await?;
288
289        if self.config.initial_message.is_some() {
290            self.send_initial_message(&mut ws_sink, &mut ws_source, out)
291                .await?;
292        }
293
294        Ok((ws_sink, ws_source))
295    }
296
297    async fn try_create_sink_and_stream(
298        &self,
299    ) -> Result<(WebSocketSink, WebSocketStream), WebSocketSourceError> {
300        let ws_stream = self
301            .params
302            .connector
303            .connect_backoff_with_timeout(self.config.connect_timeout_secs)
304            .await;
305
306        let (sink, stream) = ws_stream.split();
307
308        Ok((Box::pin(sink), Box::pin(stream)))
309    }
310
311    async fn send_initial_message(
312        &self,
313        ws_sink: &mut WebSocketSink,
314        ws_source: &mut WebSocketStream,
315        out: &mut SourceSender,
316    ) -> Result<(), WebSocketSourceError> {
317        let initial_message = self.config.initial_message.as_ref().unwrap();
318        ws_sink
319            .send(Message::Text(initial_message.clone()))
320            .await
321            .map_err(|error| {
322                emit!(WebSocketSendError { error: &error });
323                WebSocketSourceError::Tungstenite { source: error }
324            })?;
325
326        debug!("Sent initial message, awaiting response from server.");
327
328        let response =
329            match time::timeout(self.config.initial_message_timeout_secs, ws_source.next()).await {
330                Ok(Some(msg)) => msg,
331                Ok(None) => fail_with_event!(ConnectionClosedPrematurelySnafu),
332                Err(_) => fail_with_event!(InitialMessageTimeoutSnafu),
333            };
334
335        let message = response.map_err(|source| {
336            emit!(WebSocketReceiveError { error: &source });
337            WebSocketSourceError::Tungstenite { source }
338        })?;
339
340        match message {
341            Message::Text(txt) => {
342                self.process_message(&txt, WebSocketKind::Text, out).await;
343                Ok(())
344            }
345            Message::Binary(bin) => {
346                self.process_message(&bin, WebSocketKind::Binary, out).await;
347                Ok(())
348            }
349            Message::Close(frame) => self.handle_close_frame(frame),
350            _ => Ok(()),
351        }
352    }
353
354    fn handle_close_frame(
355        &self,
356        frame: Option<CloseFrame<'_>>,
357    ) -> Result<(), WebSocketSourceError> {
358        let (error_message, specific_error) = match frame {
359            Some(frame) => {
360                let msg = format!(
361                    "Connection closed by server with code '{}' and reason: '{}'",
362                    frame.code, frame.reason
363                );
364                let err = WebSocketSourceError::RemoteClosed {
365                    frame: frame.into_owned(),
366                };
367                (msg, err)
368            }
369            None => (
370                "Connection closed by server without a close frame".to_string(),
371                WebSocketSourceError::RemoteClosedEmpty,
372            ),
373        };
374
375        let error = TungsteniteError::Io(std::io::Error::new(
376            std::io::ErrorKind::ConnectionAborted,
377            error_message,
378        ));
379        emit!(WebSocketReceiveError { error: &error });
380
381        Err(specific_error)
382    }
383
384    fn is_custom_pong(&self, msg_txt: &str) -> bool {
385        match self.config.pong_message.as_ref() {
386            Some(config) => config.matches(msg_txt),
387            None => false,
388        }
389    }
390}
391
392struct PingManager {
393    interval: PingInterval,
394    waiting_for_pong: bool,
395    message: Message,
396}
397
398impl PingManager {
399    fn new(config: &WebSocketConfig) -> Self {
400        let ping_message = if let Some(ping_msg) = &config.ping_message {
401            Message::Text(ping_msg.clone())
402        } else {
403            Message::Ping(vec![])
404        };
405
406        Self {
407            interval: PingInterval::new(config.common.ping_interval.map(u64::from)),
408            waiting_for_pong: false,
409            message: ping_message,
410        }
411    }
412
413    const fn record_pong(&mut self) {
414        self.waiting_for_pong = false;
415    }
416
417    async fn tick(&mut self, ws_sink: &mut WebSocketSink) -> Result<(), WebSocketSourceError> {
418        self.interval.tick().await;
419
420        if self.waiting_for_pong {
421            return Err(WebSocketSourceError::PongTimeout);
422        }
423
424        ws_sink.send(self.message.clone()).await.map_err(|error| {
425            emit!(WebSocketSendError { error: &error });
426            WebSocketSourceError::Tungstenite { source: error }
427        })?;
428
429        self.waiting_for_pong = true;
430        Ok(())
431    }
432}
433
434#[cfg(test)]
435mod tests {
436    use std::{borrow::Cow, num::NonZeroU64};
437
438    use futures::{StreamExt, sink::SinkExt};
439    use tokio::{net::TcpListener, time::Duration};
440    use tokio_tungstenite::{
441        accept_async,
442        tungstenite::{
443            Message,
444            protocol::frame::{CloseFrame, coding::CloseCode},
445        },
446    };
447    use url::Url;
448    use vector_lib::codecs::decoding::DeserializerConfig;
449
450    use crate::{
451        common::websocket::WebSocketCommonConfig,
452        sources::websocket::config::{PongMessage, WebSocketConfig},
453        test_util::{
454            addr::next_addr,
455            components::{
456                SOURCE_TAGS, run_and_assert_source_compliance, run_and_assert_source_error,
457            },
458        },
459    };
460
461    fn make_config(uri: &str) -> WebSocketConfig {
462        WebSocketConfig {
463            common: WebSocketCommonConfig {
464                uri: Url::parse(uri).unwrap().to_string(),
465                ..Default::default()
466            },
467            ..Default::default()
468        }
469    }
470
471    /// Starts a WebSocket server that pushes a binary message to the first client.
472    async fn start_binary_push_server() -> String {
473        let (_guard, addr) = next_addr();
474        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
475        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
476
477        tokio::spawn(async move {
478            let (stream, _) = listener.accept().await.unwrap();
479            let mut websocket = accept_async(stream).await.expect("Failed to accept");
480
481            let binary_payload = br#"{"message": "binary data"}"#.to_vec();
482            websocket
483                .send(Message::Binary(binary_payload))
484                .await
485                .unwrap();
486        });
487
488        server_addr
489    }
490
491    /// Starts a WebSocket server that pushes a message to the first client that connects.
492    async fn start_push_server() -> String {
493        let (_guard, addr) = next_addr();
494        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
495        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
496
497        tokio::spawn(async move {
498            // Accept one connection
499            let (stream, _) = listener.accept().await.unwrap();
500            let mut websocket = accept_async(stream).await.expect("Failed to accept");
501
502            // Immediately send a message to the connected client (which will be our source)
503            websocket
504                .send(Message::Text("message from server".to_string()))
505                .await
506                .unwrap();
507        });
508
509        server_addr
510    }
511
512    /// Starts a WebSocket server that waits for an initial message from the client,
513    /// and upon receiving it, sends a confirmation message back.
514    async fn start_subscribe_server(initial_message: String, response_message: String) -> String {
515        let (_guard, addr) = next_addr();
516        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
517        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
518
519        tokio::spawn(async move {
520            let (stream, _) = listener.accept().await.unwrap();
521            let mut websocket = accept_async(stream).await.expect("Failed to accept");
522
523            // Wait for the initial message from the client
524            if let Some(Ok(Message::Text(msg))) = websocket.next().await
525                && msg == initial_message
526            {
527                // Received correct initial message, send response
528                websocket
529                    .send(Message::Text(response_message))
530                    .await
531                    .unwrap();
532            }
533        });
534
535        server_addr
536    }
537
538    async fn start_reconnect_server() -> String {
539        let (_guard, addr) = next_addr();
540        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
541        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
542
543        tokio::spawn(async move {
544            // First connection
545            let (stream, _) = listener.accept().await.unwrap();
546            let mut websocket = accept_async(stream).await.expect("Failed to accept");
547            websocket
548                .send(Message::Text("first message".to_string()))
549                .await
550                .unwrap();
551            // Close the connection to force a reconnect from the client
552            websocket.close(None).await.unwrap();
553
554            // Second connection
555            let (stream, _) = listener.accept().await.unwrap();
556            let mut websocket = accept_async(stream).await.expect("Failed to accept");
557            websocket
558                .send(Message::Text("second message".to_string()))
559                .await
560                .unwrap();
561        });
562
563        server_addr
564    }
565
566    #[tokio::test(flavor = "multi_thread")]
567    async fn websocket_source_reconnects_after_disconnect() {
568        let server_addr = start_reconnect_server().await;
569        let config = make_config(&server_addr);
570
571        // Run for a longer duration to allow for reconnection
572        let events =
573            run_and_assert_source_compliance(config, Duration::from_secs(5), &SOURCE_TAGS).await;
574
575        assert_eq!(
576            events.len(),
577            2,
578            "Should have received messages from both connections"
579        );
580
581        let event = events[0].as_log();
582        assert_eq!(event["message"], "first message".into());
583
584        let event = events[1].as_log();
585        assert_eq!(event["message"], "second message".into());
586    }
587
588    #[tokio::test(flavor = "multi_thread")]
589    async fn websocket_source_consume_binary_event() {
590        let server_addr = start_binary_push_server().await;
591        let mut config = make_config(&server_addr);
592        let decoding = DeserializerConfig::Json(Default::default());
593        config.decoding = decoding;
594
595        let events =
596            run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
597
598        assert!(!events.is_empty(), "No events received from source");
599        let event = events[0].as_log();
600        assert_eq!(event["message"], "binary data".into());
601        assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
602    }
603
604    #[tokio::test(flavor = "multi_thread")]
605    async fn websocket_source_consume_event() {
606        let server_addr = start_push_server().await;
607        let config = make_config(&server_addr);
608
609        // Run the source, which will connect to the server and receive the pushed message.
610        let events =
611            run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
612
613        // Now assert that the event was received and is correct.
614        assert!(!events.is_empty(), "No events received from source");
615        let event = events[0].as_log();
616        assert_eq!(event["message"], "message from server".into());
617        assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
618    }
619
620    #[tokio::test(flavor = "multi_thread")]
621    async fn websocket_source_sends_initial_message() {
622        let initial_msg = "{\"action\":\"subscribe\",\"topic\":\"test\"}".to_string();
623        let response_msg = "{\"status\":\"subscribed\",\"topic\":\"test\"}".to_string();
624        let server_addr = start_subscribe_server(initial_msg.clone(), response_msg.clone()).await;
625
626        let mut config = make_config(&server_addr);
627        config.initial_message = Some(initial_msg);
628
629        let events =
630            run_and_assert_source_compliance(config, Duration::from_secs(2), &SOURCE_TAGS).await;
631
632        assert!(!events.is_empty(), "No events received from source");
633        let event = events[0].as_log();
634        assert_eq!(event["message"], response_msg.into());
635        assert_eq!(*event.get_source_type().unwrap(), "websocket".into());
636    }
637
638    async fn start_reject_initial_message_server() -> String {
639        let (_guard, addr) = next_addr();
640        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
641        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
642
643        tokio::spawn(async move {
644            let (stream, _) = listener.accept().await.unwrap();
645            let mut websocket = accept_async(stream).await.expect("Failed to accept");
646
647            if websocket.next().await.is_some() {
648                let close_frame = CloseFrame {
649                    code: CloseCode::Error,
650                    reason: Cow::from("Simulated Internal Server Error"),
651                };
652                let _ = websocket.close(Some(close_frame)).await;
653            }
654        });
655
656        server_addr
657    }
658
659    #[tokio::test(flavor = "multi_thread")]
660    async fn websocket_source_exits_on_rejected_intial_messsage() {
661        let server_addr = start_reject_initial_message_server().await;
662
663        let mut config = make_config(&server_addr);
664        config.initial_message = Some("hello, server!".to_string());
665        config.initial_message_timeout_secs = Duration::from_secs(1);
666
667        run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
668    }
669
670    async fn start_unresponsive_server() -> String {
671        let (_guard, addr) = next_addr();
672        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
673        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
674
675        tokio::spawn(async move {
676            if let Ok((stream, _)) = listener.accept().await {
677                // Accept the connection to establish the WebSocket.
678                let mut websocket = accept_async(stream).await.expect("Failed to accept");
679                // Simply wait forever without responding to pings.
680                while websocket.next().await.is_some() {
681                    // Do nothing
682                }
683            }
684        });
685
686        server_addr
687    }
688
689    #[tokio::test(flavor = "multi_thread")]
690    async fn websocket_source_exits_on_pong_timeout() {
691        let server_addr = start_unresponsive_server().await;
692
693        let mut config = make_config(&server_addr);
694        config.common.ping_interval = NonZeroU64::new(3);
695        config.common.ping_timeout = NonZeroU64::new(1);
696        config.ping_message = Some("ping".to_string());
697        config.pong_message = Some(PongMessage::Simple("pong".to_string()));
698
699        // The source should fail because the server never sends a pong.
700        run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
701    }
702
703    async fn start_blackhole_server() -> String {
704        let (_guard, addr) = next_addr();
705        let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
706        let server_addr = format!("ws://{}", listener.local_addr().unwrap());
707
708        tokio::spawn(async move {
709            let (mut _socket, _) = listener.accept().await.unwrap();
710            tokio::time::sleep(Duration::from_secs(10)).await;
711        });
712
713        server_addr
714    }
715
716    #[tokio::test(flavor = "multi_thread")]
717    async fn websocket_source_exits_on_connection_timeout() {
718        let server_addr = start_blackhole_server().await;
719        let mut config = make_config(&server_addr);
720        config.connect_timeout_secs = Duration::from_secs(1);
721
722        run_and_assert_source_error(config, Duration::from_secs(5), &SOURCE_TAGS).await;
723    }
724}