1use std::{
2 io,
3 num::NonZeroU64,
4 time::{Duration, Instant},
5};
6
7use async_trait::async_trait;
8use bytes::BytesMut;
9use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt, stream::BoxStream};
10use tokio_tungstenite::tungstenite::{error::Error as TungsteniteError, protocol::Message};
11use tokio_util::codec::Encoder as _;
12use vector_lib::{
13 EstimatedJsonEncodedSizeOf, emit,
14 internal_event::{
15 ByteSize, BytesSent, CountByteSize, EventsSent, InternalEventHandle as _, Output, Protocol,
16 },
17};
18
19use crate::{
20 codecs::{Encoder, Transformer},
21 common::websocket::{PingInterval, WebSocketConnector, is_closed},
22 event::{Event, EventStatus, Finalizable},
23 internal_events::{
24 ConnectionOpen, OpenGauge, WebSocketConnectionError, WebSocketConnectionShutdown,
25 },
26 sinks::{util::StreamSink, websocket::config::WebSocketSinkConfig},
27};
28
29pub struct WebSocketSink {
30 transformer: Transformer,
31 encoder: Encoder<()>,
32 connector: WebSocketConnector,
33 ping_interval: Option<NonZeroU64>,
34 ping_timeout: Option<NonZeroU64>,
35}
36
37impl WebSocketSink {
38 pub(crate) fn new(
39 config: &WebSocketSinkConfig,
40 connector: WebSocketConnector,
41 ) -> crate::Result<Self> {
42 let transformer = config.encoding.transformer();
43 let serializer = config.encoding.build()?;
44 let encoder = Encoder::<()>::new(serializer);
45
46 Ok(Self {
47 transformer,
48 encoder,
49 connector,
50 ping_interval: config.common.ping_interval,
51 ping_timeout: config.common.ping_timeout,
52 })
53 }
54
55 async fn create_sink_and_stream(
56 &self,
57 ) -> (
58 impl Sink<Message, Error = TungsteniteError> + use<>,
59 impl Stream<Item = Result<Message, TungsteniteError>> + use<>,
60 ) {
61 let ws_stream = self.connector.connect_backoff().await;
62 ws_stream.split()
63 }
64
65 fn check_received_pong_time(&self, last_pong: Instant) -> Result<(), TungsteniteError> {
66 if let Some(ping_timeout) = self.ping_timeout
67 && last_pong.elapsed() > Duration::from_secs(ping_timeout.into())
68 {
69 return Err(TungsteniteError::Io(io::Error::new(
70 io::ErrorKind::TimedOut,
71 "Pong not received in time",
72 )));
73 }
74
75 Ok(())
76 }
77
78 const fn should_encode_as_binary(&self) -> bool {
79 use vector_lib::codecs::encoding::Serializer::{
80 Avro, Cef, Csv, Gelf, Json, Logfmt, Native, NativeJson, Protobuf, RawMessage, Text,
81 };
82
83 match self.encoder.serializer() {
84 RawMessage(_) | Avro(_) | Native(_) | Protobuf(_) => true,
85 Cef(_) | Csv(_) | Logfmt(_) | Gelf(_) | Json(_) | Text(_) | NativeJson(_) => false,
86 }
87 }
88
89 async fn handle_events<I, WS, O>(
90 &mut self,
91 input: &mut I,
92 ws_stream: &mut WS,
93 ws_sink: &mut O,
94 ) -> Result<(), ()>
95 where
96 I: Stream<Item = Event> + Unpin,
97 WS: Stream<Item = Result<Message, TungsteniteError>> + Unpin,
98 O: Sink<Message, Error = TungsteniteError> + Unpin,
99 {
100 const PING: &[u8] = b"PING";
101
102 let mut ping_interval = PingInterval::new(self.ping_interval.map(u64::from));
105
106 if let Err(error) = ws_sink.send(Message::Ping(PING.to_vec())).await {
107 emit!(WebSocketConnectionError { error });
108 return Err(());
109 }
110 let mut last_pong = Instant::now();
111
112 let bytes_sent = register!(BytesSent::from(Protocol("websocket".into())));
113 let events_sent = register!(EventsSent::from(Output(None)));
114 let encode_as_binary = self.should_encode_as_binary();
115
116 loop {
117 let result = tokio::select! {
118 _ = ping_interval.tick() => {
119 match self.check_received_pong_time(last_pong) {
120 Ok(()) => ws_sink.send(Message::Ping(PING.to_vec())).await.map(|_| ()),
121 Err(e) => Err(e)
122 }
123 },
124
125 Some(msg) = ws_stream.next() => {
126 match msg {
128 Ok(Message::Pong(_)) => {
129 last_pong = Instant::now();
130 Ok(())
131 },
132 Ok(_) => Ok(()),
133 Err(e) => Err(e)
134 }
135 },
136
137 event = input.next() => {
138 let mut event = if let Some(event) = event {
139 event
140 } else {
141 break;
142 };
143
144 let finalizers = event.take_finalizers();
145
146 self.transformer.transform(&mut event);
147
148 let event_byte_size = event.estimated_json_encoded_size_of();
149
150 let mut bytes = BytesMut::new();
151 match self.encoder.encode(event, &mut bytes) {
152 Ok(()) => {
153 finalizers.update_status(EventStatus::Delivered);
154
155 let message = if encode_as_binary {
156 Message::binary(bytes)
157 }
158 else {
159 Message::text(String::from_utf8_lossy(&bytes))
160 };
161 let message_len = message.len();
162
163 ws_sink.send(message).await.map(|_| {
164 events_sent.emit(CountByteSize(1, event_byte_size));
165 bytes_sent.emit(ByteSize(message_len));
166 })
167 },
168 Err(_) => {
169 finalizers.update_status(EventStatus::Errored);
171 Ok(())
172 }
173 }
174 },
175 else => break,
176 };
177
178 if let Err(error) = result {
179 if is_closed(&error) {
180 emit!(WebSocketConnectionShutdown);
181 } else {
182 emit!(WebSocketConnectionError { error });
183 }
184 return Err(());
185 }
186 }
187
188 Ok(())
189 }
190}
191
192#[async_trait]
193impl StreamSink<Event> for WebSocketSink {
194 async fn run(mut self: Box<Self>, input: BoxStream<'_, Event>) -> Result<(), ()> {
195 let input = input.fuse().peekable();
196 pin_mut!(input);
197
198 while input.as_mut().peek().await.is_some() {
199 let (ws_sink, ws_stream) = self.create_sink_and_stream().await;
200 pin_mut!(ws_sink);
201 pin_mut!(ws_stream);
202
203 let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
204
205 if self
206 .handle_events(&mut input, &mut ws_stream, &mut ws_sink)
207 .await
208 .is_ok()
209 {
210 _ = ws_sink.close().await;
211 }
212 }
213
214 Ok(())
215 }
216}
217
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use futures::{FutureExt, StreamExt, future};
    use serde_json::Value as JsonValue;
    use tokio::{time, time::timeout};
    use tokio_tungstenite::{
        accept_async, accept_hdr_async,
        tungstenite::{
            error::ProtocolError,
            handshake::server::{Request, Response},
        },
    };
    use vector_lib::codecs::JsonSerializerConfig;

    use super::*;
    use crate::{
        common::websocket::WebSocketCommonConfig,
        config::{SinkConfig, SinkContext},
        http::Auth,
        test_util::{
            CountReceiver,
            components::{SINK_TAGS, run_and_assert_sink_compliance},
            next_addr, random_lines_with_stream, trace_init,
        },
        tls::{self, MaybeTlsSettings, TlsConfig, TlsEnableableConfig},
    };

    /// Builds a JSON-encoding sink config for `uri` with pings disabled.
    fn sink_config(
        uri: String,
        tls: Option<TlsEnableableConfig>,
        auth: Option<Auth>,
    ) -> WebSocketSinkConfig {
        WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri,
                tls,
                ping_interval: None,
                ping_timeout: None,
                auth,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_websocket() {
        trace_init();

        let addr = next_addr();
        let config = sink_config(format!("ws://{addr}"), None, None);
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, None).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_websocket() {
        trace_init();

        let auth = Some(Auth::Bearer {
            token: "OiJIUzI1NiIsInR5cCI6IkpXVCJ".to_string().into(),
        });
        let addr = next_addr();
        // The sink must be configured with the same credentials the test server checks;
        // otherwise the handshake carries no `Authorization` header and auth is never tested.
        let config = sink_config(format!("ws://{addr}"), None, auth.clone());
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, auth).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_tls_websocket() {
        trace_init();

        let addr = next_addr();
        let tls_config = Some(TlsEnableableConfig::test_config());
        let tls = MaybeTlsSettings::from_config(tls_config.as_ref(), true).unwrap();

        let client_tls = TlsEnableableConfig {
            enabled: Some(true),
            options: TlsConfig {
                verify_certificate: Some(false),
                verify_hostname: Some(true),
                ca_file: Some(tls::TEST_PEM_CRT_PATH.into()),
                ..Default::default()
            },
        };
        let config = sink_config(format!("wss://{addr}"), Some(client_tls), None);

        send_events_and_assert(addr, config, tls, None).await;
    }

    #[tokio::test]
    async fn test_websocket_reconnect() {
        trace_init();

        let addr = next_addr();
        let config = sink_config(format!("ws://{addr}"), None, None);
        let tls = MaybeTlsSettings::Raw(());

        let mut receiver = create_count_receiver(addr, tls.clone(), true, None);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (_lines, events) = random_lines_with_stream(10, 100, None);
        let events = events.then(|event| async move {
            time::sleep(Duration::from_millis(10)).await;
            event
        });
        drop(tokio::spawn(sink.run(events)));

        receiver.connected().await;
        time::sleep(Duration::from_millis(500)).await;
        assert!(!receiver.await.is_empty());

        let mut receiver = create_count_receiver(addr, tls, false, None);
        assert!(
            timeout(Duration::from_secs(10), receiver.connected())
                .await
                .is_ok()
        );
    }

    /// Runs the sink built from `config` against a test server on `addr` and checks that
    /// every generated line arrives, in order, as the JSON message field.
    async fn send_events_and_assert(
        addr: SocketAddr,
        config: WebSocketSinkConfig,
        tls: MaybeTlsSettings,
        auth: Option<Auth>,
    ) {
        let mut receiver = create_count_receiver(addr, tls, false, auth);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (lines, events) = random_lines_with_stream(10, 100, None);
        run_and_assert_sink_compliance(sink, events, &SINK_TAGS).await;

        receiver.connected().await;

        let output = receiver.await;
        assert_eq!(lines.len(), output.len());
        let message_key = crate::config::log_schema()
            .message_key()
            .expect("global log_schema.message_key to be valid path")
            .to_string();
        for (source, received) in lines.iter().zip(output) {
            let json = serde_json::from_str::<JsonValue>(&received).expect("Invalid JSON");
            let received = json.get(message_key.as_str()).unwrap().as_str().unwrap();
            assert_eq!(source, received);
        }
    }

    /// Starts a WebSocket server on `addr` that collects received text frames.
    ///
    /// When `auth` is a bearer token, handshakes without a matching `Authorization` header
    /// are rejected. When `interrupt_stream` is true, the collected stream also ends when
    /// the receiver's tripwire fires.
    fn create_count_receiver(
        addr: SocketAddr,
        tls: MaybeTlsSettings,
        interrupt_stream: bool,
        auth: Option<Auth>,
    ) -> CountReceiver<String> {
        CountReceiver::receive_items_stream(move |tripwire, connected| async move {
            let listener = tls.bind(&addr).await.unwrap();
            let stream = listener.accept_stream();

            let tripwire = tripwire.map(|_| ()).shared();
            let stream_tripwire = tripwire.clone();
            let mut connected = Some(connected);

            let stream = stream
                .take_until(tripwire)
                .filter_map(move |maybe_tls_stream| {
                    let auth = auth.clone();
                    async move {
                        let maybe_tls_stream = maybe_tls_stream.unwrap();
                        let ws_stream = match auth {
                            Some(auth) => {
                                let auth_callback = move |req: &Request, res: Response| {
                                    // Require an exact bearer match; a missing header is a
                                    // rejection, not a pass.
                                    if let Auth::Bearer { token } = &auth {
                                        let expected = format!("Bearer {}", token.inner());
                                        let provided = req
                                            .headers()
                                            .get("Authorization")
                                            .and_then(|value| value.to_str().ok());
                                        if provided != Some(expected.as_str()) {
                                            return Err(
                                                http::Response::<Option<String>>::new(None),
                                            );
                                        }
                                    }
                                    Ok(res)
                                };
                                accept_hdr_async(maybe_tls_stream, auth_callback)
                                    .await
                                    .unwrap()
                            }
                            None => accept_async(maybe_tls_stream).await.unwrap(),
                        };

                        Some(
                            ws_stream
                                .filter_map(|msg| {
                                    future::ready(match msg {
                                        Ok(msg) if msg.is_text() => {
                                            Some(Ok(msg.into_text().unwrap()))
                                        }
                                        Err(TungsteniteError::Protocol(
                                            ProtocolError::ResetWithoutClosingHandshake,
                                        )) => None,
                                        Err(e) => Some(Err(e)),
                                        _ => None,
                                    })
                                })
                                .take_while(|msg| future::ready(msg.is_ok()))
                                .filter_map(|msg| future::ready(msg.ok())),
                        )
                    }
                })
                .map(move |ws_stream| {
                    if let Some(trigger) = connected.take() {
                        _ = trigger.send(());
                    }
                    ws_stream
                })
                .flatten();

            if interrupt_stream {
                stream.take_until(stream_tripwire).boxed()
            } else {
                stream.boxed()
            }
        })
    }
}