vector/sinks/websocket/
sink.rs

1use std::{
2    io,
3    num::NonZeroU64,
4    time::{Duration, Instant},
5};
6
7use crate::{
8    codecs::{Encoder, Transformer},
9    common::websocket::{PingInterval, WebSocketConnector, is_closed},
10    event::{Event, EventStatus, Finalizable},
11    internal_events::{
12        ConnectionOpen, OpenGauge, WebSocketConnectionError, WebSocketConnectionShutdown,
13    },
14    sinks::{util::StreamSink, websocket::config::WebSocketSinkConfig},
15};
16use async_trait::async_trait;
17use bytes::BytesMut;
18use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt, stream::BoxStream};
19use tokio_tungstenite::tungstenite::{error::Error as TungsteniteError, protocol::Message};
20use tokio_util::codec::Encoder as _;
21use vector_lib::{
22    EstimatedJsonEncodedSizeOf, emit,
23    internal_event::{
24        ByteSize, BytesSent, CountByteSize, EventsSent, InternalEventHandle as _, Output, Protocol,
25    },
26};
27
28pub struct WebSocketSink {
29    transformer: Transformer,
30    encoder: Encoder<()>,
31    connector: WebSocketConnector,
32    ping_interval: Option<NonZeroU64>,
33    ping_timeout: Option<NonZeroU64>,
34}
35
36impl WebSocketSink {
37    pub(crate) fn new(
38        config: &WebSocketSinkConfig,
39        connector: WebSocketConnector,
40    ) -> crate::Result<Self> {
41        let transformer = config.encoding.transformer();
42        let serializer = config.encoding.build()?;
43        let encoder = Encoder::<()>::new(serializer);
44
45        Ok(Self {
46            transformer,
47            encoder,
48            connector,
49            ping_interval: config.common.ping_interval,
50            ping_timeout: config.common.ping_timeout,
51        })
52    }
53
54    async fn create_sink_and_stream(
55        &self,
56    ) -> (
57        impl Sink<Message, Error = TungsteniteError> + use<>,
58        impl Stream<Item = Result<Message, TungsteniteError>> + use<>,
59    ) {
60        let ws_stream = self.connector.connect_backoff().await;
61        ws_stream.split()
62    }
63
64    fn check_received_pong_time(&self, last_pong: Instant) -> Result<(), TungsteniteError> {
65        if let Some(ping_timeout) = self.ping_timeout
66            && last_pong.elapsed() > Duration::from_secs(ping_timeout.into())
67        {
68            return Err(TungsteniteError::Io(io::Error::new(
69                io::ErrorKind::TimedOut,
70                "Pong not received in time",
71            )));
72        }
73
74        Ok(())
75    }
76
77    async fn handle_events<I, WS, O>(
78        &mut self,
79        input: &mut I,
80        ws_stream: &mut WS,
81        ws_sink: &mut O,
82    ) -> Result<(), ()>
83    where
84        I: Stream<Item = Event> + Unpin,
85        WS: Stream<Item = Result<Message, TungsteniteError>> + Unpin,
86        O: Sink<Message, Error = TungsteniteError> + Unpin,
87    {
88        const PING: &[u8] = b"PING";
89
90        // tokio::time::Interval panics if the period arg is zero. Since the struct members are
91        // using NonZeroU64 that is not something we need to account for.
92        let mut ping_interval = PingInterval::new(self.ping_interval.map(u64::from));
93
94        if let Err(error) = ws_sink.send(Message::Ping(PING.to_vec())).await {
95            emit!(WebSocketConnectionError { error });
96            return Err(());
97        }
98        let mut last_pong = Instant::now();
99
100        let bytes_sent = register!(BytesSent::from(Protocol("websocket".into())));
101        let events_sent = register!(EventsSent::from(Output(None)));
102        let encode_as_binary = self.encoder.serializer().is_binary();
103
104        loop {
105            let result = tokio::select! {
106                _ = ping_interval.tick() => {
107                    match self.check_received_pong_time(last_pong) {
108                        Ok(()) => ws_sink.send(Message::Ping(PING.to_vec())).await.map(|_| ()),
109                        Err(e) => Err(e)
110                    }
111                },
112
113                Some(msg) = ws_stream.next() => {
114                    // Pongs are sent automatically by tungstenite during reading from the stream.
115                    match msg {
116                        Ok(Message::Pong(_)) => {
117                            last_pong = Instant::now();
118                            Ok(())
119                        },
120                        Ok(_) => Ok(()),
121                        Err(e) => Err(e)
122                    }
123                },
124
125                event = input.next() => {
126                    let mut event = if let Some(event) = event {
127                        event
128                    } else {
129                        break;
130                    };
131
132                    let finalizers = event.take_finalizers();
133
134                    self.transformer.transform(&mut event);
135
136                    let event_byte_size = event.estimated_json_encoded_size_of();
137
138                    let mut bytes = BytesMut::new();
139                    match self.encoder.encode(event, &mut bytes) {
140                        Ok(()) => {
141                            finalizers.update_status(EventStatus::Delivered);
142
143                            let message = if encode_as_binary {
144                                Message::binary(bytes)
145                            }
146                            else {
147                                Message::text(String::from_utf8_lossy(&bytes))
148                            };
149                            let message_len = message.len();
150
151                            ws_sink.send(message).await.map(|_| {
152                                events_sent.emit(CountByteSize(1, event_byte_size));
153                                bytes_sent.emit(ByteSize(message_len));
154                            })
155                        },
156                        Err(_) => {
157                            // Error is handled by `Encoder`.
158                            finalizers.update_status(EventStatus::Errored);
159                            Ok(())
160                        }
161                    }
162                },
163                else => break,
164            };
165
166            if let Err(error) = result {
167                if is_closed(&error) {
168                    emit!(WebSocketConnectionShutdown);
169                } else {
170                    emit!(WebSocketConnectionError { error });
171                }
172                return Err(());
173            }
174        }
175
176        Ok(())
177    }
178}
179
180#[async_trait]
181impl StreamSink<Event> for WebSocketSink {
182    async fn run(mut self: Box<Self>, input: BoxStream<'_, Event>) -> Result<(), ()> {
183        let input = input.fuse().peekable();
184        pin_mut!(input);
185
186        while input.as_mut().peek().await.is_some() {
187            let (ws_sink, ws_stream) = self.create_sink_and_stream().await;
188            pin_mut!(ws_sink);
189            pin_mut!(ws_stream);
190
191            let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
192
193            if self
194                .handle_events(&mut input, &mut ws_stream, &mut ws_sink)
195                .await
196                .is_ok()
197            {
198                _ = ws_sink.close().await;
199            }
200        }
201
202        Ok(())
203    }
204}
205
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use futures::{FutureExt, StreamExt, future};
    use serde_json::Value as JsonValue;
    use tokio::{time, time::timeout};
    use tokio_tungstenite::{
        accept_async, accept_hdr_async,
        tungstenite::{
            error::ProtocolError,
            handshake::server::{Request, Response},
        },
    };
    use vector_lib::codecs::JsonSerializerConfig;

    use super::*;
    use crate::{
        common::websocket::WebSocketCommonConfig,
        config::{SinkConfig, SinkContext},
        http::Auth,
        test_util::{
            CountReceiver,
            addr::next_addr,
            components::{SINK_TAGS, run_and_assert_sink_compliance},
            random_lines_with_stream, trace_init,
        },
        tls::{self, MaybeTlsSettings, TlsConfig, TlsEnableableConfig},
    };

    /// Plain `ws://` connection without TLS or authentication.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_websocket() {
        trace_init();

        let (_guard, addr) = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, None).await;
    }

    /// Connection with bearer-token authentication: the sink is configured with the token
    /// and the test server is given the same token to compare against the `Authorization`
    /// header of the handshake request.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_websocket() {
        trace_init();

        let auth = Some(Auth::Bearer {
            token: "OiJIUzI1NiIsInR5cCI6IkpXVCJ".to_string().into(),
        });
        let auth_clone = auth.clone();
        let (_guard, addr) = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                // The sink itself must carry the credentials; otherwise no `Authorization`
                // header is sent and the server-side comparison is never reached.
                auth,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, auth_clone).await;
    }

    /// `wss://` connection against a TLS listener built from the test TLS configuration.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_tls_websocket() {
        trace_init();

        let (_guard, addr) = next_addr();
        let tls_config = Some(TlsEnableableConfig::test_config());
        let tls = MaybeTlsSettings::from_config(tls_config.as_ref(), true).unwrap();

        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("wss://{addr}"),
                tls: Some(TlsEnableableConfig {
                    enabled: Some(true),
                    options: TlsConfig {
                        verify_certificate: Some(false),
                        verify_hostname: Some(true),
                        ca_file: Some(tls::TEST_PEM_CRT_PATH.into()),
                        ..Default::default()
                    },
                }),
                ping_timeout: None,
                ping_interval: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };

        send_events_and_assert(addr, config, tls, None).await;
    }

    /// After the first receiver goes away mid-stream, the sink must connect to a second
    /// receiver bound to the same address within ten seconds.
    #[tokio::test]
    async fn test_websocket_reconnect() {
        trace_init();

        let (_guard, addr) = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        let mut receiver = create_count_receiver(addr, tls.clone(), true, None);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (_lines, events) = random_lines_with_stream(10, 100, None);
        // Space the events out so the input is still producing when the first receiver
        // is torn down.
        let events = events.then(|event| async move {
            time::sleep(Duration::from_millis(10)).await;
            event
        });
        drop(tokio::spawn(sink.run(events)));

        receiver.connected().await;
        time::sleep(Duration::from_millis(500)).await;
        assert!(!receiver.await.is_empty());

        let mut receiver = create_count_receiver(addr, tls, false, None);
        assert!(
            timeout(Duration::from_secs(10), receiver.connected())
                .await
                .is_ok()
        );
    }

    /// Starts a receiver on `addr`, runs a sink built from `config` over 100 random lines,
    /// and asserts that every line arrives, in order, as the message field of a JSON
    /// document.
    async fn send_events_and_assert(
        addr: SocketAddr,
        config: WebSocketSinkConfig,
        tls: MaybeTlsSettings,
        auth: Option<Auth>,
    ) {
        let mut receiver = create_count_receiver(addr, tls, false, auth);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (lines, events) = random_lines_with_stream(10, 100, None);
        run_and_assert_sink_compliance(sink, events, &SINK_TAGS).await;

        receiver.connected().await;

        let output = receiver.await;
        assert_eq!(lines.len(), output.len());
        let message_key = crate::config::log_schema()
            .message_key()
            .expect("global log_schema.message_key to be valid path")
            .to_string();
        for (source, received) in lines.iter().zip(output) {
            let json = serde_json::from_str::<JsonValue>(&received).expect("Invalid JSON");
            let received = json.get(message_key.as_str()).unwrap().as_str().unwrap();
            assert_eq!(source, received);
        }
    }

    /// Creates a WebSocket test server on `addr` that collects the text messages it
    /// receives.
    ///
    /// * `tls` - listener settings used to bind `addr`.
    /// * `interrupt_stream` - when `true`, the flattened message stream is additionally cut
    ///   off by the tripwire, not just the accept loop.
    /// * `auth` - when set, the handshake goes through a callback that rejects a request
    ///   whose `Authorization` header is present but does not match the expected bearer
    ///   token. A request without that header is still accepted.
    fn create_count_receiver(
        addr: SocketAddr,
        tls: MaybeTlsSettings,
        interrupt_stream: bool,
        auth: Option<Auth>,
    ) -> CountReceiver<String> {
        CountReceiver::receive_items_stream(move |tripwire, connected| async move {
            let listener = tls.bind(&addr).await.unwrap();
            let stream = listener.accept_stream();

            let tripwire = tripwire.map(|_| ()).shared();
            let stream_tripwire = tripwire.clone();
            let mut connected = Some(connected);

            let stream = stream
                .take_until(tripwire)
                .filter_map(move |maybe_tls_stream| {
                    let au = auth.clone();
                    async move {
                        let maybe_tls_stream = maybe_tls_stream.unwrap();
                        let ws_stream = match au {
                            Some(a) => {
                                let auth_callback = |req: &Request, res: Response| {
                                    let hdr = req.headers().get("Authorization");
                                    if let Some(h) = hdr {
                                        match a {
                                            Auth::Bearer { token } => {
                                                if format!("Bearer {}", token.inner())
                                                    != h.to_str().unwrap()
                                                {
                                                    return Err(
                                                        http::Response::<Option<String>>::new(None),
                                                    );
                                                }
                                            }
                                            Auth::Basic {
                                                user: _user,
                                                password: _password,
                                            } => { /* Not needed for tests at the moment */ }
                                            Auth::Custom { .. } => { /* Not needed for tests at the moment */ }
                                            #[cfg(feature = "aws-core")]
                                            _ => {}
                                        }
                                    }
                                    Ok(res)
                                };
                                accept_hdr_async(maybe_tls_stream, auth_callback)
                                    .await
                                    .unwrap()
                            }
                            None => accept_async(maybe_tls_stream).await.unwrap(),
                        };

                        // Keep only text frames; a reset without a closing handshake is
                        // dropped, and any other error ends this connection's stream.
                        Some(
                            ws_stream
                                .filter_map(|msg| {
                                    future::ready(match msg {
                                        Ok(msg) if msg.is_text() => {
                                            Some(Ok(msg.into_text().unwrap()))
                                        }
                                        Err(TungsteniteError::Protocol(
                                            ProtocolError::ResetWithoutClosingHandshake,
                                        )) => None,
                                        Err(e) => Some(Err(e)),
                                        _ => None,
                                    })
                                })
                                .take_while(|msg| future::ready(msg.is_ok()))
                                .filter_map(|msg| future::ready(msg.ok())),
                        )
                    }
                })
                .map(move |ws_stream| {
                    // Signal "connected" once, on the first accepted connection.
                    connected.take().map(|trigger| trigger.send(()));
                    ws_stream
                })
                .flatten();

            match interrupt_stream {
                false => stream.boxed(),
                true => stream.take_until(stream_tripwire).boxed(),
            }
        })
    }
}
468}