// vector/sinks/websocket/sink.rs

1use std::{
2    io,
3    num::NonZeroU64,
4    time::{Duration, Instant},
5};
6
7use crate::{
8    codecs::{Encoder, Transformer},
9    common::websocket::{PingInterval, WebSocketConnector, is_closed},
10    event::{Event, EventStatus, Finalizable},
11    internal_events::{
12        ConnectionOpen, OpenGauge, WebSocketConnectionError, WebSocketConnectionShutdown,
13    },
14    sinks::{util::StreamSink, websocket::config::WebSocketSinkConfig},
15};
16use async_trait::async_trait;
17use bytes::BytesMut;
18use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt, stream::BoxStream};
19use tokio_tungstenite::tungstenite::{error::Error as TungsteniteError, protocol::Message};
20use tokio_util::codec::Encoder as _;
21use vector_lib::{
22    EstimatedJsonEncodedSizeOf, emit,
23    internal_event::{
24        ByteSize, BytesSent, CountByteSize, EventsSent, InternalEventHandle as _, Output, Protocol,
25    },
26};
27
28pub struct WebSocketSink {
29    transformer: Transformer,
30    encoder: Encoder<()>,
31    connector: WebSocketConnector,
32    ping_interval: Option<NonZeroU64>,
33    ping_timeout: Option<NonZeroU64>,
34}
35
36impl WebSocketSink {
37    pub(crate) fn new(
38        config: &WebSocketSinkConfig,
39        connector: WebSocketConnector,
40    ) -> crate::Result<Self> {
41        let transformer = config.encoding.transformer();
42        let serializer = config.encoding.build()?;
43        let encoder = Encoder::<()>::new(serializer);
44
45        Ok(Self {
46            transformer,
47            encoder,
48            connector,
49            ping_interval: config.common.ping_interval,
50            ping_timeout: config.common.ping_timeout,
51        })
52    }
53
54    async fn create_sink_and_stream(
55        &self,
56    ) -> (
57        impl Sink<Message, Error = TungsteniteError> + use<>,
58        impl Stream<Item = Result<Message, TungsteniteError>> + use<>,
59    ) {
60        let ws_stream = self.connector.connect_backoff().await;
61        ws_stream.split()
62    }
63
64    fn check_received_pong_time(&self, last_pong: Instant) -> Result<(), TungsteniteError> {
65        if let Some(ping_timeout) = self.ping_timeout
66            && last_pong.elapsed() > Duration::from_secs(ping_timeout.into())
67        {
68            return Err(TungsteniteError::Io(io::Error::new(
69                io::ErrorKind::TimedOut,
70                "Pong not received in time",
71            )));
72        }
73
74        Ok(())
75    }
76
77    async fn handle_events<I, WS, O>(
78        &mut self,
79        input: &mut I,
80        ws_stream: &mut WS,
81        ws_sink: &mut O,
82    ) -> Result<(), ()>
83    where
84        I: Stream<Item = Event> + Unpin,
85        WS: Stream<Item = Result<Message, TungsteniteError>> + Unpin,
86        O: Sink<Message, Error = TungsteniteError> + Unpin,
87    {
88        const PING: &[u8] = b"PING";
89
90        // tokio::time::Interval panics if the period arg is zero. Since the struct members are
91        // using NonZeroU64 that is not something we need to account for.
92        let mut ping_interval = PingInterval::new(self.ping_interval.map(u64::from));
93
94        if let Err(error) = ws_sink.send(Message::Ping(PING.to_vec())).await {
95            emit!(WebSocketConnectionError { error });
96            return Err(());
97        }
98        let mut last_pong = Instant::now();
99
100        let bytes_sent = register!(BytesSent::from(Protocol("websocket".into())));
101        let events_sent = register!(EventsSent::from(Output(None)));
102        let encode_as_binary = self.encoder.serializer().is_binary();
103
104        loop {
105            let result = tokio::select! {
106                _ = ping_interval.tick() => {
107                    match self.check_received_pong_time(last_pong) {
108                        Ok(()) => ws_sink.send(Message::Ping(PING.to_vec())).await.map(|_| ()),
109                        Err(e) => Err(e)
110                    }
111                },
112
113                Some(msg) = ws_stream.next() => {
114                    // Pongs are sent automatically by tungstenite during reading from the stream.
115                    match msg {
116                        Ok(Message::Pong(_)) => {
117                            last_pong = Instant::now();
118                            Ok(())
119                        },
120                        Ok(_) => Ok(()),
121                        Err(e) => Err(e)
122                    }
123                },
124
125                event = input.next() => {
126                    let mut event = if let Some(event) = event {
127                        event
128                    } else {
129                        break;
130                    };
131
132                    let finalizers = event.take_finalizers();
133
134                    self.transformer.transform(&mut event);
135
136                    let event_byte_size = event.estimated_json_encoded_size_of();
137
138                    let mut bytes = BytesMut::new();
139                    match self.encoder.encode(event, &mut bytes) {
140                        Ok(()) => {
141                            finalizers.update_status(EventStatus::Delivered);
142
143                            let message = if encode_as_binary {
144                                Message::binary(bytes)
145                            }
146                            else {
147                                Message::text(String::from_utf8_lossy(&bytes))
148                            };
149                            let message_len = message.len();
150
151                            ws_sink.send(message).await.map(|_| {
152                                events_sent.emit(CountByteSize(1, event_byte_size));
153                                bytes_sent.emit(ByteSize(message_len));
154                            })
155                        },
156                        Err(_) => {
157                            // Error is handled by `Encoder`.
158                            finalizers.update_status(EventStatus::Errored);
159                            Ok(())
160                        }
161                    }
162                },
163                else => break,
164            };
165
166            if let Err(error) = result {
167                if is_closed(&error) {
168                    emit!(WebSocketConnectionShutdown);
169                } else {
170                    emit!(WebSocketConnectionError { error });
171                }
172                return Err(());
173            }
174        }
175
176        Ok(())
177    }
178}
179
180#[async_trait]
181impl StreamSink<Event> for WebSocketSink {
182    async fn run(mut self: Box<Self>, input: BoxStream<'_, Event>) -> Result<(), ()> {
183        let input = input.fuse().peekable();
184        pin_mut!(input);
185
186        while input.as_mut().peek().await.is_some() {
187            let (ws_sink, ws_stream) = self.create_sink_and_stream().await;
188            pin_mut!(ws_sink);
189            pin_mut!(ws_stream);
190
191            let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
192
193            if self
194                .handle_events(&mut input, &mut ws_stream, &mut ws_sink)
195                .await
196                .is_ok()
197            {
198                _ = ws_sink.close().await;
199            }
200        }
201
202        Ok(())
203    }
204}
205
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use futures::{FutureExt, StreamExt, future};
    use serde_json::Value as JsonValue;
    use tokio::{time, time::timeout};
    use tokio_tungstenite::{
        accept_async, accept_hdr_async,
        tungstenite::{
            error::ProtocolError,
            handshake::server::{Request, Response},
        },
    };
    use vector_lib::codecs::JsonSerializerConfig;

    use super::*;
    use crate::{
        common::websocket::WebSocketCommonConfig,
        config::{SinkConfig, SinkContext},
        http::Auth,
        test_util::{
            CountReceiver,
            components::{SINK_TAGS, run_and_assert_sink_compliance},
            next_addr, random_lines_with_stream, trace_init,
        },
        tls::{self, MaybeTlsSettings, TlsConfig, TlsEnableableConfig},
    };

    #[tokio::test(flavor = "multi_thread")]
    async fn test_websocket() {
        trace_init();

        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, None).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_websocket() {
        trace_init();

        let auth = Some(Auth::Bearer {
            token: "OiJIUzI1NiIsInR5cCI6IkpXVCJ".to_string().into(),
        });
        // The test server validates incoming handshakes against this copy.
        let auth_clone = auth.clone();
        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                // The sink must actually be configured with the credentials; with
                // `auth: None` this test never exercised authentication.
                auth,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, auth_clone).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_tls_websocket() {
        trace_init();

        let addr = next_addr();
        let tls_config = Some(TlsEnableableConfig::test_config());
        let tls = MaybeTlsSettings::from_config(tls_config.as_ref(), true).unwrap();

        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("wss://{addr}"),
                tls: Some(TlsEnableableConfig {
                    enabled: Some(true),
                    options: TlsConfig {
                        verify_certificate: Some(false),
                        verify_hostname: Some(true),
                        ca_file: Some(tls::TEST_PEM_CRT_PATH.into()),
                        ..Default::default()
                    },
                }),
                ping_timeout: None,
                ping_interval: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };

        send_events_and_assert(addr, config, tls, None).await;
    }

    #[tokio::test]
    async fn test_websocket_reconnect() {
        trace_init();

        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        let mut receiver = create_count_receiver(addr, tls.clone(), true, None);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (_lines, events) = random_lines_with_stream(10, 100, None);
        let events = events.then(|event| async move {
            time::sleep(Duration::from_millis(10)).await;
            event
        });
        drop(tokio::spawn(sink.run(events)));

        receiver.connected().await;
        time::sleep(Duration::from_millis(500)).await;
        assert!(!receiver.await.is_empty());

        let mut receiver = create_count_receiver(addr, tls, false, None);
        assert!(
            timeout(Duration::from_secs(10), receiver.connected())
                .await
                .is_ok()
        );
    }

    /// Runs `config` against a test server on `addr` and checks that every generated
    /// line arrives as a JSON message whose message field equals the original line.
    async fn send_events_and_assert(
        addr: SocketAddr,
        config: WebSocketSinkConfig,
        tls: MaybeTlsSettings,
        auth: Option<Auth>,
    ) {
        let mut receiver = create_count_receiver(addr, tls, false, auth);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (lines, events) = random_lines_with_stream(10, 100, None);
        run_and_assert_sink_compliance(sink, events, &SINK_TAGS).await;

        receiver.connected().await;

        let output = receiver.await;
        assert_eq!(lines.len(), output.len());
        let message_key = crate::config::log_schema()
            .message_key()
            .expect("global log_schema.message_key to be valid path")
            .to_string();
        for (source, received) in lines.iter().zip(output) {
            let json = serde_json::from_str::<JsonValue>(&received).expect("Invalid JSON");
            let received = json.get(message_key.as_str()).unwrap().as_str().unwrap();
            assert_eq!(source, received);
        }
    }

    /// Starts a WebSocket test server on `addr` that collects the text frames it receives.
    ///
    /// When `auth` is a bearer token, handshakes without exactly that `Authorization`
    /// header are rejected. When `interrupt_stream` is true, the collected item stream is
    /// additionally cut off by the receiver's tripwire.
    fn create_count_receiver(
        addr: SocketAddr,
        tls: MaybeTlsSettings,
        interrupt_stream: bool,
        auth: Option<Auth>,
    ) -> CountReceiver<String> {
        CountReceiver::receive_items_stream(move |tripwire, connected| async move {
            let listener = tls.bind(&addr).await.unwrap();
            let stream = listener.accept_stream();

            let tripwire = tripwire.map(|_| ()).shared();
            let stream_tripwire = tripwire.clone();
            let mut connected = Some(connected);

            let stream = stream
                .take_until(tripwire)
                .filter_map(move |maybe_tls_stream| {
                    let au = auth.clone();
                    async move {
                        let maybe_tls_stream = maybe_tls_stream.unwrap();
                        let ws_stream = match au {
                            Some(a) => {
                                let auth_callback = |req: &Request, res: Response| {
                                    let header = req
                                        .headers()
                                        .get("Authorization")
                                        .and_then(|h| h.to_str().ok());
                                    match a {
                                        Auth::Bearer { token } => {
                                            // Reject a missing header as well as a wrong one, so
                                            // the handshake only succeeds when the sink really
                                            // sent the expected credentials.
                                            let expected = format!("Bearer {}", token.inner());
                                            if header != Some(expected.as_str()) {
                                                return Err(
                                                    http::Response::<Option<String>>::new(None),
                                                );
                                            }
                                        }
                                        Auth::Basic {
                                            user: _user,
                                            password: _password,
                                        } => { /* Not needed for tests at the moment */ }
                                        #[cfg(feature = "aws-core")]
                                        _ => {}
                                    }
                                    Ok(res)
                                };
                                accept_hdr_async(maybe_tls_stream, auth_callback)
                                    .await
                                    .unwrap()
                            }
                            None => accept_async(maybe_tls_stream).await.unwrap(),
                        };

                        Some(
                            ws_stream
                                .filter_map(|msg| {
                                    future::ready(match msg {
                                        Ok(msg) if msg.is_text() => {
                                            Some(Ok(msg.into_text().unwrap()))
                                        }
                                        Err(TungsteniteError::Protocol(
                                            ProtocolError::ResetWithoutClosingHandshake,
                                        )) => None,
                                        Err(e) => Some(Err(e)),
                                        _ => None,
                                    })
                                })
                                .take_while(|msg| future::ready(msg.is_ok()))
                                .filter_map(|msg| future::ready(msg.ok())),
                        )
                    }
                })
                .map(move |ws_stream| {
                    connected.take().map(|trigger| trigger.send(()));
                    ws_stream
                })
                .flatten();

            match interrupt_stream {
                false => stream.boxed(),
                true => stream.take_until(stream_tripwire).boxed(),
            }
        })
    }
}