vector/sinks/websocket/
sink.rs

1use std::{
2    io,
3    num::NonZeroU64,
4    time::{Duration, Instant},
5};
6
7use async_trait::async_trait;
8use bytes::BytesMut;
9use futures::{pin_mut, sink::SinkExt, stream::BoxStream, Sink, Stream, StreamExt};
10use tokio_tungstenite::tungstenite::{error::Error as TungsteniteError, protocol::Message};
11use tokio_util::codec::Encoder as _;
12use vector_lib::{
13    emit,
14    internal_event::{
15        ByteSize, BytesSent, CountByteSize, EventsSent, InternalEventHandle as _, Output, Protocol,
16    },
17    EstimatedJsonEncodedSizeOf,
18};
19
20use crate::{
21    codecs::{Encoder, Transformer},
22    common::websocket::{is_closed, PingInterval, WebSocketConnector},
23    event::{Event, EventStatus, Finalizable},
24    internal_events::{
25        ConnectionOpen, OpenGauge, WebSocketConnectionError, WebSocketConnectionShutdown,
26    },
27    sinks::util::StreamSink,
28    sinks::websocket::config::WebSocketSinkConfig,
29};
30
31pub struct WebSocketSink {
32    transformer: Transformer,
33    encoder: Encoder<()>,
34    connector: WebSocketConnector,
35    ping_interval: Option<NonZeroU64>,
36    ping_timeout: Option<NonZeroU64>,
37}
38
39impl WebSocketSink {
40    pub(crate) fn new(
41        config: &WebSocketSinkConfig,
42        connector: WebSocketConnector,
43    ) -> crate::Result<Self> {
44        let transformer = config.encoding.transformer();
45        let serializer = config.encoding.build()?;
46        let encoder = Encoder::<()>::new(serializer);
47
48        Ok(Self {
49            transformer,
50            encoder,
51            connector,
52            ping_interval: config.common.ping_interval,
53            ping_timeout: config.common.ping_timeout,
54        })
55    }
56
57    async fn create_sink_and_stream(
58        &self,
59    ) -> (
60        impl Sink<Message, Error = TungsteniteError>,
61        impl Stream<Item = Result<Message, TungsteniteError>>,
62    ) {
63        let ws_stream = self.connector.connect_backoff().await;
64        ws_stream.split()
65    }
66
67    fn check_received_pong_time(&self, last_pong: Instant) -> Result<(), TungsteniteError> {
68        if let Some(ping_timeout) = self.ping_timeout {
69            if last_pong.elapsed() > Duration::from_secs(ping_timeout.into()) {
70                return Err(TungsteniteError::Io(io::Error::new(
71                    io::ErrorKind::TimedOut,
72                    "Pong not received in time",
73                )));
74            }
75        }
76
77        Ok(())
78    }
79
80    const fn should_encode_as_binary(&self) -> bool {
81        use vector_lib::codecs::encoding::Serializer::{
82            Avro, Cef, Csv, Gelf, Json, Logfmt, Native, NativeJson, Protobuf, RawMessage, Text,
83        };
84
85        match self.encoder.serializer() {
86            RawMessage(_) | Avro(_) | Native(_) | Protobuf(_) => true,
87            Cef(_) | Csv(_) | Logfmt(_) | Gelf(_) | Json(_) | Text(_) | NativeJson(_) => false,
88        }
89    }
90
91    async fn handle_events<I, WS, O>(
92        &mut self,
93        input: &mut I,
94        ws_stream: &mut WS,
95        ws_sink: &mut O,
96    ) -> Result<(), ()>
97    where
98        I: Stream<Item = Event> + Unpin,
99        WS: Stream<Item = Result<Message, TungsteniteError>> + Unpin,
100        O: Sink<Message, Error = TungsteniteError> + Unpin,
101    {
102        const PING: &[u8] = b"PING";
103
104        // tokio::time::Interval panics if the period arg is zero. Since the struct members are
105        // using NonZeroU64 that is not something we need to account for.
106        let mut ping_interval = PingInterval::new(self.ping_interval.map(u64::from));
107
108        if let Err(error) = ws_sink.send(Message::Ping(PING.to_vec())).await {
109            emit!(WebSocketConnectionError { error });
110            return Err(());
111        }
112        let mut last_pong = Instant::now();
113
114        let bytes_sent = register!(BytesSent::from(Protocol("websocket".into())));
115        let events_sent = register!(EventsSent::from(Output(None)));
116        let encode_as_binary = self.should_encode_as_binary();
117
118        loop {
119            let result = tokio::select! {
120                _ = ping_interval.tick() => {
121                    match self.check_received_pong_time(last_pong) {
122                        Ok(()) => ws_sink.send(Message::Ping(PING.to_vec())).await.map(|_| ()),
123                        Err(e) => Err(e)
124                    }
125                },
126
127                Some(msg) = ws_stream.next() => {
128                    // Pongs are sent automatically by tungstenite during reading from the stream.
129                    match msg {
130                        Ok(Message::Pong(_)) => {
131                            last_pong = Instant::now();
132                            Ok(())
133                        },
134                        Ok(_) => Ok(()),
135                        Err(e) => Err(e)
136                    }
137                },
138
139                event = input.next() => {
140                    let mut event = if let Some(event) = event {
141                        event
142                    } else {
143                        break;
144                    };
145
146                    let finalizers = event.take_finalizers();
147
148                    self.transformer.transform(&mut event);
149
150                    let event_byte_size = event.estimated_json_encoded_size_of();
151
152                    let mut bytes = BytesMut::new();
153                    let res = match self.encoder.encode(event, &mut bytes) {
154                        Ok(()) => {
155                            finalizers.update_status(EventStatus::Delivered);
156
157                            let message = if encode_as_binary {
158                                Message::binary(bytes)
159                            }
160                            else {
161                                Message::text(String::from_utf8_lossy(&bytes))
162                            };
163                            let message_len = message.len();
164
165                            ws_sink.send(message).await.map(|_| {
166                                events_sent.emit(CountByteSize(1, event_byte_size));
167                                bytes_sent.emit(ByteSize(message_len));
168                            })
169                        },
170                        Err(_) => {
171                            // Error is handled by `Encoder`.
172                            finalizers.update_status(EventStatus::Errored);
173                            Ok(())
174                        }
175                    };
176
177                    res
178                },
179                else => break,
180            };
181
182            if let Err(error) = result {
183                if is_closed(&error) {
184                    emit!(WebSocketConnectionShutdown);
185                } else {
186                    emit!(WebSocketConnectionError { error });
187                }
188                return Err(());
189            }
190        }
191
192        Ok(())
193    }
194}
195
196#[async_trait]
197impl StreamSink<Event> for WebSocketSink {
198    async fn run(mut self: Box<Self>, input: BoxStream<'_, Event>) -> Result<(), ()> {
199        let input = input.fuse().peekable();
200        pin_mut!(input);
201
202        while input.as_mut().peek().await.is_some() {
203            let (ws_sink, ws_stream) = self.create_sink_and_stream().await;
204            pin_mut!(ws_sink);
205            pin_mut!(ws_stream);
206
207            let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
208
209            if self
210                .handle_events(&mut input, &mut ws_stream, &mut ws_sink)
211                .await
212                .is_ok()
213            {
214                _ = ws_sink.close().await;
215            }
216        }
217
218        Ok(())
219    }
220}
221
222#[cfg(test)]
223mod tests {
224    use std::net::SocketAddr;
225
226    use futures::{future, FutureExt, StreamExt};
227    use serde_json::Value as JsonValue;
228    use tokio::{time, time::timeout};
229    use tokio_tungstenite::{
230        accept_async, accept_hdr_async,
231        tungstenite::{
232            error::ProtocolError,
233            handshake::server::{Request, Response},
234        },
235    };
236    use vector_lib::codecs::JsonSerializerConfig;
237
238    use super::*;
239    use crate::{
240        common::websocket::WebSocketCommonConfig,
241        config::{SinkConfig, SinkContext},
242        http::Auth,
243        test_util::{
244            components::{run_and_assert_sink_compliance, SINK_TAGS},
245            next_addr, random_lines_with_stream, trace_init, CountReceiver,
246        },
247        tls::{self, MaybeTlsSettings, TlsConfig, TlsEnableableConfig},
248    };
249
250    #[tokio::test(flavor = "multi_thread")]
251    async fn test_websocket() {
252        trace_init();
253
254        let addr = next_addr();
255        let config = WebSocketSinkConfig {
256            common: WebSocketCommonConfig {
257                uri: format!("ws://{addr}"),
258                tls: None,
259                ping_interval: None,
260                ping_timeout: None,
261                auth: None,
262            },
263            encoding: JsonSerializerConfig::default().into(),
264            acknowledgements: Default::default(),
265        };
266        let tls = MaybeTlsSettings::Raw(());
267
268        send_events_and_assert(addr, config, tls, None).await;
269    }
270
271    #[tokio::test(flavor = "multi_thread")]
272    async fn test_auth_websocket() {
273        trace_init();
274
275        let auth = Some(Auth::Bearer {
276            token: "OiJIUzI1NiIsInR5cCI6IkpXVCJ".to_string().into(),
277        });
278        let auth_clone = auth.clone();
279        let addr = next_addr();
280        let config = WebSocketSinkConfig {
281            common: WebSocketCommonConfig {
282                uri: format!("ws://{addr}"),
283                tls: None,
284                ping_interval: None,
285                ping_timeout: None,
286                auth: None,
287            },
288            encoding: JsonSerializerConfig::default().into(),
289            acknowledgements: Default::default(),
290        };
291        let tls = MaybeTlsSettings::Raw(());
292
293        send_events_and_assert(addr, config, tls, auth_clone).await;
294    }
295
296    #[tokio::test(flavor = "multi_thread")]
297    async fn test_tls_websocket() {
298        trace_init();
299
300        let addr = next_addr();
301        let tls_config = Some(TlsEnableableConfig::test_config());
302        let tls = MaybeTlsSettings::from_config(tls_config.as_ref(), true).unwrap();
303
304        let config = WebSocketSinkConfig {
305            common: WebSocketCommonConfig {
306                uri: format!("wss://{addr}"),
307                tls: Some(TlsEnableableConfig {
308                    enabled: Some(true),
309                    options: TlsConfig {
310                        verify_certificate: Some(false),
311                        verify_hostname: Some(true),
312                        ca_file: Some(tls::TEST_PEM_CRT_PATH.into()),
313                        ..Default::default()
314                    },
315                }),
316                ping_timeout: None,
317                ping_interval: None,
318                auth: None,
319            },
320            encoding: JsonSerializerConfig::default().into(),
321            acknowledgements: Default::default(),
322        };
323
324        send_events_and_assert(addr, config, tls, None).await;
325    }
326
327    #[tokio::test]
328    async fn test_websocket_reconnect() {
329        trace_init();
330
331        let addr = next_addr();
332        let config = WebSocketSinkConfig {
333            common: WebSocketCommonConfig {
334                uri: format!("ws://{addr}"),
335                tls: None,
336                ping_interval: None,
337                ping_timeout: None,
338                auth: None,
339            },
340            encoding: JsonSerializerConfig::default().into(),
341            acknowledgements: Default::default(),
342        };
343        let tls = MaybeTlsSettings::Raw(());
344
345        let mut receiver = create_count_receiver(addr, tls.clone(), true, None);
346
347        let context = SinkContext::default();
348        let (sink, _healthcheck) = config.build(context).await.unwrap();
349
350        let (_lines, events) = random_lines_with_stream(10, 100, None);
351        let events = events.then(|event| async move {
352            time::sleep(Duration::from_millis(10)).await;
353            event
354        });
355        drop(tokio::spawn(sink.run(events)));
356
357        receiver.connected().await;
358        time::sleep(Duration::from_millis(500)).await;
359        assert!(!receiver.await.is_empty());
360
361        let mut receiver = create_count_receiver(addr, tls, false, None);
362        assert!(timeout(Duration::from_secs(10), receiver.connected())
363            .await
364            .is_ok());
365    }
366
367    async fn send_events_and_assert(
368        addr: SocketAddr,
369        config: WebSocketSinkConfig,
370        tls: MaybeTlsSettings,
371        auth: Option<Auth>,
372    ) {
373        let mut receiver = create_count_receiver(addr, tls, false, auth);
374
375        let context = SinkContext::default();
376        let (sink, _healthcheck) = config.build(context).await.unwrap();
377
378        let (lines, events) = random_lines_with_stream(10, 100, None);
379        run_and_assert_sink_compliance(sink, events, &SINK_TAGS).await;
380
381        receiver.connected().await;
382
383        let output = receiver.await;
384        assert_eq!(lines.len(), output.len());
385        let message_key = crate::config::log_schema()
386            .message_key()
387            .expect("global log_schema.message_key to be valid path")
388            .to_string();
389        for (source, received) in lines.iter().zip(output) {
390            let json = serde_json::from_str::<JsonValue>(&received).expect("Invalid JSON");
391            let received = json.get(message_key.as_str()).unwrap().as_str().unwrap();
392            assert_eq!(source, received);
393        }
394    }
395
396    fn create_count_receiver(
397        addr: SocketAddr,
398        tls: MaybeTlsSettings,
399        interrupt_stream: bool,
400        auth: Option<Auth>,
401    ) -> CountReceiver<String> {
402        CountReceiver::receive_items_stream(move |tripwire, connected| async move {
403            let listener = tls.bind(&addr).await.unwrap();
404            let stream = listener.accept_stream();
405
406            let tripwire = tripwire.map(|_| ()).shared();
407            let stream_tripwire = tripwire.clone();
408            let mut connected = Some(connected);
409
410            let stream = stream
411                .take_until(tripwire)
412                .filter_map(move |maybe_tls_stream| {
413                    let au = auth.clone();
414                    async move {
415                        let maybe_tls_stream = maybe_tls_stream.unwrap();
416                        let ws_stream = match au {
417                            Some(a) => {
418                                let auth_callback = |req: &Request, res: Response| {
419                                    let hdr = req.headers().get("Authorization");
420                                    if let Some(h) = hdr {
421                                        match a {
422                                            Auth::Bearer { token } => {
423                                                if format!("Bearer {}", token.inner())
424                                                    != h.to_str().unwrap()
425                                                {
426                                                    return Err(
427                                                        http::Response::<Option<String>>::new(None),
428                                                    );
429                                                }
430                                            }
431                                            Auth::Basic {
432                                                user: _user,
433                                                password: _password,
434                                            } => { /* Not needed for tests at the moment */ }
435                                            #[cfg(feature = "aws-core")]
436                                            _ => {}
437                                        }
438                                    }
439                                    Ok(res)
440                                };
441                                accept_hdr_async(maybe_tls_stream, auth_callback)
442                                    .await
443                                    .unwrap()
444                            }
445                            None => accept_async(maybe_tls_stream).await.unwrap(),
446                        };
447
448                        Some(
449                            ws_stream
450                                .filter_map(|msg| {
451                                    future::ready(match msg {
452                                        Ok(msg) if msg.is_text() => {
453                                            Some(Ok(msg.into_text().unwrap()))
454                                        }
455                                        Err(TungsteniteError::Protocol(
456                                            ProtocolError::ResetWithoutClosingHandshake,
457                                        )) => None,
458                                        Err(e) => Some(Err(e)),
459                                        _ => None,
460                                    })
461                                })
462                                .take_while(|msg| future::ready(msg.is_ok()))
463                                .filter_map(|msg| future::ready(msg.ok())),
464                        )
465                    }
466                })
467                .map(move |ws_stream| {
468                    connected.take().map(|trigger| trigger.send(()));
469                    ws_stream
470                })
471                .flatten();
472
473            match interrupt_stream {
474                false => stream.boxed(),
475                true => stream.take_until(stream_tripwire).boxed(),
476            }
477        })
478    }
479}