vector/sinks/websocket/
sink.rs

1use std::{
2    io,
3    num::NonZeroU64,
4    time::{Duration, Instant},
5};
6
7use async_trait::async_trait;
8use bytes::BytesMut;
9use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt, stream::BoxStream};
10use tokio_tungstenite::tungstenite::{error::Error as TungsteniteError, protocol::Message};
11use tokio_util::codec::Encoder as _;
12use vector_lib::{
13    EstimatedJsonEncodedSizeOf, emit,
14    internal_event::{
15        ByteSize, BytesSent, CountByteSize, EventsSent, InternalEventHandle as _, Output, Protocol,
16    },
17};
18
19use crate::{
20    codecs::{Encoder, Transformer},
21    common::websocket::{PingInterval, WebSocketConnector, is_closed},
22    event::{Event, EventStatus, Finalizable},
23    internal_events::{
24        ConnectionOpen, OpenGauge, WebSocketConnectionError, WebSocketConnectionShutdown,
25    },
26    sinks::{util::StreamSink, websocket::config::WebSocketSinkConfig},
27};
28
29pub struct WebSocketSink {
30    transformer: Transformer,
31    encoder: Encoder<()>,
32    connector: WebSocketConnector,
33    ping_interval: Option<NonZeroU64>,
34    ping_timeout: Option<NonZeroU64>,
35}
36
37impl WebSocketSink {
38    pub(crate) fn new(
39        config: &WebSocketSinkConfig,
40        connector: WebSocketConnector,
41    ) -> crate::Result<Self> {
42        let transformer = config.encoding.transformer();
43        let serializer = config.encoding.build()?;
44        let encoder = Encoder::<()>::new(serializer);
45
46        Ok(Self {
47            transformer,
48            encoder,
49            connector,
50            ping_interval: config.common.ping_interval,
51            ping_timeout: config.common.ping_timeout,
52        })
53    }
54
55    async fn create_sink_and_stream(
56        &self,
57    ) -> (
58        impl Sink<Message, Error = TungsteniteError> + use<>,
59        impl Stream<Item = Result<Message, TungsteniteError>> + use<>,
60    ) {
61        let ws_stream = self.connector.connect_backoff().await;
62        ws_stream.split()
63    }
64
65    fn check_received_pong_time(&self, last_pong: Instant) -> Result<(), TungsteniteError> {
66        if let Some(ping_timeout) = self.ping_timeout
67            && last_pong.elapsed() > Duration::from_secs(ping_timeout.into())
68        {
69            return Err(TungsteniteError::Io(io::Error::new(
70                io::ErrorKind::TimedOut,
71                "Pong not received in time",
72            )));
73        }
74
75        Ok(())
76    }
77
78    const fn should_encode_as_binary(&self) -> bool {
79        use vector_lib::codecs::encoding::Serializer::{
80            Avro, Cef, Csv, Gelf, Json, Logfmt, Native, NativeJson, Protobuf, RawMessage, Text,
81        };
82
83        match self.encoder.serializer() {
84            RawMessage(_) | Avro(_) | Native(_) | Protobuf(_) => true,
85            Cef(_) | Csv(_) | Logfmt(_) | Gelf(_) | Json(_) | Text(_) | NativeJson(_) => false,
86        }
87    }
88
89    async fn handle_events<I, WS, O>(
90        &mut self,
91        input: &mut I,
92        ws_stream: &mut WS,
93        ws_sink: &mut O,
94    ) -> Result<(), ()>
95    where
96        I: Stream<Item = Event> + Unpin,
97        WS: Stream<Item = Result<Message, TungsteniteError>> + Unpin,
98        O: Sink<Message, Error = TungsteniteError> + Unpin,
99    {
100        const PING: &[u8] = b"PING";
101
102        // tokio::time::Interval panics if the period arg is zero. Since the struct members are
103        // using NonZeroU64 that is not something we need to account for.
104        let mut ping_interval = PingInterval::new(self.ping_interval.map(u64::from));
105
106        if let Err(error) = ws_sink.send(Message::Ping(PING.to_vec())).await {
107            emit!(WebSocketConnectionError { error });
108            return Err(());
109        }
110        let mut last_pong = Instant::now();
111
112        let bytes_sent = register!(BytesSent::from(Protocol("websocket".into())));
113        let events_sent = register!(EventsSent::from(Output(None)));
114        let encode_as_binary = self.should_encode_as_binary();
115
116        loop {
117            let result = tokio::select! {
118                _ = ping_interval.tick() => {
119                    match self.check_received_pong_time(last_pong) {
120                        Ok(()) => ws_sink.send(Message::Ping(PING.to_vec())).await.map(|_| ()),
121                        Err(e) => Err(e)
122                    }
123                },
124
125                Some(msg) = ws_stream.next() => {
126                    // Pongs are sent automatically by tungstenite during reading from the stream.
127                    match msg {
128                        Ok(Message::Pong(_)) => {
129                            last_pong = Instant::now();
130                            Ok(())
131                        },
132                        Ok(_) => Ok(()),
133                        Err(e) => Err(e)
134                    }
135                },
136
137                event = input.next() => {
138                    let mut event = if let Some(event) = event {
139                        event
140                    } else {
141                        break;
142                    };
143
144                    let finalizers = event.take_finalizers();
145
146                    self.transformer.transform(&mut event);
147
148                    let event_byte_size = event.estimated_json_encoded_size_of();
149
150                    let mut bytes = BytesMut::new();
151                    match self.encoder.encode(event, &mut bytes) {
152                        Ok(()) => {
153                            finalizers.update_status(EventStatus::Delivered);
154
155                            let message = if encode_as_binary {
156                                Message::binary(bytes)
157                            }
158                            else {
159                                Message::text(String::from_utf8_lossy(&bytes))
160                            };
161                            let message_len = message.len();
162
163                            ws_sink.send(message).await.map(|_| {
164                                events_sent.emit(CountByteSize(1, event_byte_size));
165                                bytes_sent.emit(ByteSize(message_len));
166                            })
167                        },
168                        Err(_) => {
169                            // Error is handled by `Encoder`.
170                            finalizers.update_status(EventStatus::Errored);
171                            Ok(())
172                        }
173                    }
174                },
175                else => break,
176            };
177
178            if let Err(error) = result {
179                if is_closed(&error) {
180                    emit!(WebSocketConnectionShutdown);
181                } else {
182                    emit!(WebSocketConnectionError { error });
183                }
184                return Err(());
185            }
186        }
187
188        Ok(())
189    }
190}
191
192#[async_trait]
193impl StreamSink<Event> for WebSocketSink {
194    async fn run(mut self: Box<Self>, input: BoxStream<'_, Event>) -> Result<(), ()> {
195        let input = input.fuse().peekable();
196        pin_mut!(input);
197
198        while input.as_mut().peek().await.is_some() {
199            let (ws_sink, ws_stream) = self.create_sink_and_stream().await;
200            pin_mut!(ws_sink);
201            pin_mut!(ws_stream);
202
203            let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
204
205            if self
206                .handle_events(&mut input, &mut ws_stream, &mut ws_sink)
207                .await
208                .is_ok()
209            {
210                _ = ws_sink.close().await;
211            }
212        }
213
214        Ok(())
215    }
216}
217
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use futures::{FutureExt, StreamExt, future};
    use serde_json::Value as JsonValue;
    use tokio::{time, time::timeout};
    use tokio_tungstenite::{
        accept_async, accept_hdr_async,
        tungstenite::{
            error::ProtocolError,
            handshake::server::{Request, Response},
        },
    };
    use vector_lib::codecs::JsonSerializerConfig;

    use super::*;
    use crate::{
        common::websocket::WebSocketCommonConfig,
        config::{SinkConfig, SinkContext},
        http::Auth,
        test_util::{
            CountReceiver,
            components::{SINK_TAGS, run_and_assert_sink_compliance},
            next_addr, random_lines_with_stream, trace_init,
        },
        tls::{self, MaybeTlsSettings, TlsConfig, TlsEnableableConfig},
    };

    #[tokio::test(flavor = "multi_thread")]
    async fn test_websocket() {
        trace_init();

        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, None).await;
    }

    /// The sink must send the configured credentials and the receiver must require
    /// them; otherwise this test would pass without exercising auth at all.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_websocket() {
        trace_init();

        let auth = Some(Auth::Bearer {
            token: "OiJIUzI1NiIsInR5cCI6IkpXVCJ".to_string().into(),
        });
        let auth_clone = auth.clone();
        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, auth_clone).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_tls_websocket() {
        trace_init();

        let addr = next_addr();
        let tls_config = Some(TlsEnableableConfig::test_config());
        let tls = MaybeTlsSettings::from_config(tls_config.as_ref(), true).unwrap();

        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("wss://{addr}"),
                tls: Some(TlsEnableableConfig {
                    enabled: Some(true),
                    options: TlsConfig {
                        verify_certificate: Some(false),
                        verify_hostname: Some(true),
                        ca_file: Some(tls::TEST_PEM_CRT_PATH.into()),
                        ..Default::default()
                    },
                }),
                ping_timeout: None,
                ping_interval: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };

        send_events_and_assert(addr, config, tls, None).await;
    }

    #[tokio::test]
    async fn test_websocket_reconnect() {
        trace_init();

        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        let mut receiver = create_count_receiver(addr, tls.clone(), true, None);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (_lines, events) = random_lines_with_stream(10, 100, None);
        let events = events.then(|event| async move {
            time::sleep(Duration::from_millis(10)).await;
            event
        });
        drop(tokio::spawn(sink.run(events)));

        receiver.connected().await;
        time::sleep(Duration::from_millis(500)).await;
        assert!(!receiver.await.is_empty());

        let mut receiver = create_count_receiver(addr, tls, false, None);
        assert!(
            timeout(Duration::from_secs(10), receiver.connected())
                .await
                .is_ok()
        );
    }

    /// Runs the sink built from `config` against a receiver on `addr`, then checks
    /// that every generated line arrived, in order, as the message field of a JSON
    /// document.
    async fn send_events_and_assert(
        addr: SocketAddr,
        config: WebSocketSinkConfig,
        tls: MaybeTlsSettings,
        auth: Option<Auth>,
    ) {
        let mut receiver = create_count_receiver(addr, tls, false, auth);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (lines, events) = random_lines_with_stream(10, 100, None);
        run_and_assert_sink_compliance(sink, events, &SINK_TAGS).await;

        receiver.connected().await;

        let output = receiver.await;
        assert_eq!(lines.len(), output.len());
        let message_key = crate::config::log_schema()
            .message_key()
            .expect("global log_schema.message_key to be valid path")
            .to_string();
        for (source, received) in lines.iter().zip(output) {
            let json = serde_json::from_str::<JsonValue>(&received).expect("Invalid JSON");
            let received = json.get(message_key.as_str()).unwrap().as_str().unwrap();
            assert_eq!(source, received);
        }
    }

    /// Starts a WebSocket server on `addr` that collects received text frames.
    ///
    /// When `auth` is `Some`, handshakes without an `Authorization` header, or with a
    /// wrong bearer token, are rejected. When `interrupt_stream` is true, the collected
    /// stream is additionally cut off by the receiver's tripwire.
    fn create_count_receiver(
        addr: SocketAddr,
        tls: MaybeTlsSettings,
        interrupt_stream: bool,
        auth: Option<Auth>,
    ) -> CountReceiver<String> {
        CountReceiver::receive_items_stream(move |tripwire, connected| async move {
            let listener = tls.bind(&addr).await.unwrap();
            let stream = listener.accept_stream();

            let tripwire = tripwire.map(|_| ()).shared();
            let stream_tripwire = tripwire.clone();
            let mut connected = Some(connected);

            let stream = stream
                .take_until(tripwire)
                .filter_map(move |maybe_tls_stream| {
                    let au = auth.clone();
                    async move {
                        let maybe_tls_stream = maybe_tls_stream.unwrap();
                        let ws_stream = match au {
                            Some(a) => {
                                let auth_callback = |req: &Request, res: Response| {
                                    // Credentials are expected, so a handshake that
                                    // carries none must not be accepted.
                                    let Some(h) = req.headers().get("Authorization") else {
                                        return Err(http::Response::<Option<String>>::new(None));
                                    };
                                    match a {
                                        Auth::Bearer { token } => {
                                            if format!("Bearer {}", token.inner())
                                                != h.to_str().unwrap()
                                            {
                                                return Err(
                                                    http::Response::<Option<String>>::new(None),
                                                );
                                            }
                                        }
                                        Auth::Basic {
                                            user: _user,
                                            password: _password,
                                        } => { /* Not needed for tests at the moment */ }
                                        #[cfg(feature = "aws-core")]
                                        _ => {}
                                    }
                                    Ok(res)
                                };
                                accept_hdr_async(maybe_tls_stream, auth_callback)
                                    .await
                                    .unwrap()
                            }
                            None => accept_async(maybe_tls_stream).await.unwrap(),
                        };

                        Some(
                            ws_stream
                                .filter_map(|msg| {
                                    future::ready(match msg {
                                        Ok(msg) if msg.is_text() => {
                                            Some(Ok(msg.into_text().unwrap()))
                                        }
                                        Err(TungsteniteError::Protocol(
                                            ProtocolError::ResetWithoutClosingHandshake,
                                        )) => None,
                                        Err(e) => Some(Err(e)),
                                        _ => None,
                                    })
                                })
                                .take_while(|msg| future::ready(msg.is_ok()))
                                .filter_map(|msg| future::ready(msg.ok())),
                        )
                    }
                })
                .map(move |ws_stream| {
                    connected.take().map(|trigger| trigger.send(()));
                    ws_stream
                })
                .flatten();

            match interrupt_stream {
                false => stream.boxed(),
                true => stream.take_until(stream_tripwire).boxed(),
            }
        })
    }
}