1use std::{
2 io,
3 num::NonZeroU64,
4 time::{Duration, Instant},
5};
6
7use crate::{
8 codecs::{Encoder, Transformer},
9 common::websocket::{PingInterval, WebSocketConnector, is_closed},
10 event::{Event, EventStatus, Finalizable},
11 internal_events::{
12 ConnectionOpen, OpenGauge, WebSocketConnectionError, WebSocketConnectionShutdown,
13 },
14 sinks::{util::StreamSink, websocket::config::WebSocketSinkConfig},
15};
16use async_trait::async_trait;
17use bytes::BytesMut;
18use futures::{Sink, Stream, StreamExt, pin_mut, sink::SinkExt, stream::BoxStream};
19use tokio_tungstenite::tungstenite::{error::Error as TungsteniteError, protocol::Message};
20use tokio_util::codec::Encoder as _;
21use vector_lib::{
22 EstimatedJsonEncodedSizeOf, emit,
23 internal_event::{
24 ByteSize, BytesSent, CountByteSize, EventsSent, InternalEventHandle as _, Output, Protocol,
25 },
26};
27
28pub struct WebSocketSink {
29 transformer: Transformer,
30 encoder: Encoder<()>,
31 connector: WebSocketConnector,
32 ping_interval: Option<NonZeroU64>,
33 ping_timeout: Option<NonZeroU64>,
34}
35
36impl WebSocketSink {
37 pub(crate) fn new(
38 config: &WebSocketSinkConfig,
39 connector: WebSocketConnector,
40 ) -> crate::Result<Self> {
41 let transformer = config.encoding.transformer();
42 let serializer = config.encoding.build()?;
43 let encoder = Encoder::<()>::new(serializer);
44
45 Ok(Self {
46 transformer,
47 encoder,
48 connector,
49 ping_interval: config.common.ping_interval,
50 ping_timeout: config.common.ping_timeout,
51 })
52 }
53
54 async fn create_sink_and_stream(
55 &self,
56 ) -> (
57 impl Sink<Message, Error = TungsteniteError> + use<>,
58 impl Stream<Item = Result<Message, TungsteniteError>> + use<>,
59 ) {
60 let ws_stream = self.connector.connect_backoff().await;
61 ws_stream.split()
62 }
63
64 fn check_received_pong_time(&self, last_pong: Instant) -> Result<(), TungsteniteError> {
65 if let Some(ping_timeout) = self.ping_timeout
66 && last_pong.elapsed() > Duration::from_secs(ping_timeout.into())
67 {
68 return Err(TungsteniteError::Io(io::Error::new(
69 io::ErrorKind::TimedOut,
70 "Pong not received in time",
71 )));
72 }
73
74 Ok(())
75 }
76
77 async fn handle_events<I, WS, O>(
78 &mut self,
79 input: &mut I,
80 ws_stream: &mut WS,
81 ws_sink: &mut O,
82 ) -> Result<(), ()>
83 where
84 I: Stream<Item = Event> + Unpin,
85 WS: Stream<Item = Result<Message, TungsteniteError>> + Unpin,
86 O: Sink<Message, Error = TungsteniteError> + Unpin,
87 {
88 const PING: &[u8] = b"PING";
89
90 let mut ping_interval = PingInterval::new(self.ping_interval.map(u64::from));
93
94 if let Err(error) = ws_sink.send(Message::Ping(PING.to_vec())).await {
95 emit!(WebSocketConnectionError { error });
96 return Err(());
97 }
98 let mut last_pong = Instant::now();
99
100 let bytes_sent = register!(BytesSent::from(Protocol("websocket".into())));
101 let events_sent = register!(EventsSent::from(Output(None)));
102 let encode_as_binary = self.encoder.serializer().is_binary();
103
104 loop {
105 let result = tokio::select! {
106 _ = ping_interval.tick() => {
107 match self.check_received_pong_time(last_pong) {
108 Ok(()) => ws_sink.send(Message::Ping(PING.to_vec())).await.map(|_| ()),
109 Err(e) => Err(e)
110 }
111 },
112
113 Some(msg) = ws_stream.next() => {
114 match msg {
116 Ok(Message::Pong(_)) => {
117 last_pong = Instant::now();
118 Ok(())
119 },
120 Ok(_) => Ok(()),
121 Err(e) => Err(e)
122 }
123 },
124
125 event = input.next() => {
126 let mut event = if let Some(event) = event {
127 event
128 } else {
129 break;
130 };
131
132 let finalizers = event.take_finalizers();
133
134 self.transformer.transform(&mut event);
135
136 let event_byte_size = event.estimated_json_encoded_size_of();
137
138 let mut bytes = BytesMut::new();
139 match self.encoder.encode(event, &mut bytes) {
140 Ok(()) => {
141 finalizers.update_status(EventStatus::Delivered);
142
143 let message = if encode_as_binary {
144 Message::binary(bytes)
145 }
146 else {
147 Message::text(String::from_utf8_lossy(&bytes))
148 };
149 let message_len = message.len();
150
151 ws_sink.send(message).await.map(|_| {
152 events_sent.emit(CountByteSize(1, event_byte_size));
153 bytes_sent.emit(ByteSize(message_len));
154 })
155 },
156 Err(_) => {
157 finalizers.update_status(EventStatus::Errored);
159 Ok(())
160 }
161 }
162 },
163 else => break,
164 };
165
166 if let Err(error) = result {
167 if is_closed(&error) {
168 emit!(WebSocketConnectionShutdown);
169 } else {
170 emit!(WebSocketConnectionError { error });
171 }
172 return Err(());
173 }
174 }
175
176 Ok(())
177 }
178}
179
180#[async_trait]
181impl StreamSink<Event> for WebSocketSink {
182 async fn run(mut self: Box<Self>, input: BoxStream<'_, Event>) -> Result<(), ()> {
183 let input = input.fuse().peekable();
184 pin_mut!(input);
185
186 while input.as_mut().peek().await.is_some() {
187 let (ws_sink, ws_stream) = self.create_sink_and_stream().await;
188 pin_mut!(ws_sink);
189 pin_mut!(ws_stream);
190
191 let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
192
193 if self
194 .handle_events(&mut input, &mut ws_stream, &mut ws_sink)
195 .await
196 .is_ok()
197 {
198 _ = ws_sink.close().await;
199 }
200 }
201
202 Ok(())
203 }
204}
205
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use futures::{FutureExt, StreamExt, future};
    use serde_json::Value as JsonValue;
    use tokio::{time, time::timeout};
    use tokio_tungstenite::{
        accept_async, accept_hdr_async,
        tungstenite::{
            error::ProtocolError,
            handshake::server::{Request, Response},
        },
    };
    use vector_lib::codecs::JsonSerializerConfig;

    use super::*;
    use crate::{
        common::websocket::WebSocketCommonConfig,
        config::{SinkConfig, SinkContext},
        http::Auth,
        test_util::{
            CountReceiver,
            components::{SINK_TAGS, run_and_assert_sink_compliance},
            next_addr, random_lines_with_stream, trace_init,
        },
        tls::{self, MaybeTlsSettings, TlsConfig, TlsEnableableConfig},
    };

    /// Plain `ws://` connection without TLS or authentication.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_websocket() {
        trace_init();

        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, None).await;
    }

    /// The sink sends a bearer token, and the receiver verifies it during the handshake.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_websocket() {
        trace_init();

        let auth = Some(Auth::Bearer {
            token: "OiJIUzI1NiIsInR5cCI6IkpXVCJ".to_string().into(),
        });
        // The receiver gets its own copy so it can check what the sink sends.
        let auth_clone = auth.clone();
        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                // The sink must actually send the credentials the receiver expects.
                auth,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, auth_clone).await;
    }

    /// `wss://` connection against a receiver that uses the test certificate.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_tls_websocket() {
        trace_init();

        let addr = next_addr();
        let tls_config = Some(TlsEnableableConfig::test_config());
        let tls = MaybeTlsSettings::from_config(tls_config.as_ref(), true).unwrap();

        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("wss://{addr}"),
                tls: Some(TlsEnableableConfig {
                    enabled: Some(true),
                    options: TlsConfig {
                        verify_certificate: Some(false),
                        verify_hostname: Some(true),
                        ca_file: Some(tls::TEST_PEM_CRT_PATH.into()),
                        ..Default::default()
                    },
                }),
                ping_timeout: None,
                ping_interval: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };

        send_events_and_assert(addr, config, tls, None).await;
    }

    /// After the first receiver shuts down, the sink connects to a new receiver on the same
    /// address.
    #[tokio::test]
    async fn test_websocket_reconnect() {
        trace_init();

        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        let mut receiver = create_count_receiver(addr, tls.clone(), true, None);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (_lines, events) = random_lines_with_stream(10, 100, None);
        let events = events.then(|event| async move {
            time::sleep(Duration::from_millis(10)).await;
            event
        });
        drop(tokio::spawn(sink.run(events)));

        receiver.connected().await;
        time::sleep(Duration::from_millis(500)).await;
        assert!(!receiver.await.is_empty());

        let mut receiver = create_count_receiver(addr, tls, false, None);
        assert!(
            timeout(Duration::from_secs(10), receiver.connected())
                .await
                .is_ok()
        );
    }

    /// Runs the sink built from `config` over a batch of random lines against a local
    /// receiver. Checks that each line arrives in order as the message field of a JSON
    /// document.
    async fn send_events_and_assert(
        addr: SocketAddr,
        config: WebSocketSinkConfig,
        tls: MaybeTlsSettings,
        auth: Option<Auth>,
    ) {
        let mut receiver = create_count_receiver(addr, tls, false, auth);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (lines, events) = random_lines_with_stream(10, 100, None);
        run_and_assert_sink_compliance(sink, events, &SINK_TAGS).await;

        receiver.connected().await;

        let output = receiver.await;
        assert_eq!(lines.len(), output.len());
        let message_key = crate::config::log_schema()
            .message_key()
            .expect("global log_schema.message_key to be valid path")
            .to_string();
        for (source, received) in lines.iter().zip(output) {
            let json = serde_json::from_str::<JsonValue>(&received).expect("Invalid JSON");
            let received = json.get(message_key.as_str()).unwrap().as_str().unwrap();
            assert_eq!(source, received);
        }
    }

    /// Starts a WebSocket server on `addr` that collects every text frame it receives.
    ///
    /// When `auth` is set, the handshake is refused unless the request carries an
    /// `Authorization` header, and a bearer token must match exactly. When
    /// `interrupt_stream` is true, the tripwire also cuts off the collected message stream,
    /// not only the accept loop.
    fn create_count_receiver(
        addr: SocketAddr,
        tls: MaybeTlsSettings,
        interrupt_stream: bool,
        auth: Option<Auth>,
    ) -> CountReceiver<String> {
        CountReceiver::receive_items_stream(move |tripwire, connected| async move {
            let listener = tls.bind(&addr).await.unwrap();
            let stream = listener.accept_stream();

            let tripwire = tripwire.map(|_| ()).shared();
            let stream_tripwire = tripwire.clone();
            let mut connected = Some(connected);

            let stream = stream
                .take_until(tripwire)
                .filter_map(move |maybe_tls_stream| {
                    let au = auth.clone();
                    async move {
                        let maybe_tls_stream = maybe_tls_stream.unwrap();
                        let ws_stream = match au {
                            Some(a) => {
                                let auth_callback = |req: &Request, res: Response| {
                                    // Credentials are required: a client that sends none
                                    // must not be able to connect.
                                    let Some(h) = req.headers().get("Authorization") else {
                                        return Err(http::Response::<Option<String>>::new(None));
                                    };
                                    match a {
                                        Auth::Bearer { token } => {
                                            if format!("Bearer {}", token.inner())
                                                != h.to_str().unwrap()
                                            {
                                                return Err(
                                                    http::Response::<Option<String>>::new(None),
                                                );
                                            }
                                        }
                                        Auth::Basic {
                                            user: _user,
                                            password: _password,
                                        } => {
                                            // Basic credentials are not verified by any test yet.
                                        }
                                        #[cfg(feature = "aws-core")]
                                        _ => {}
                                    }
                                    Ok(res)
                                };
                                accept_hdr_async(maybe_tls_stream, auth_callback)
                                    .await
                                    .unwrap()
                            }
                            None => accept_async(maybe_tls_stream).await.unwrap(),
                        };

                        Some(
                            ws_stream
                                .filter_map(|msg| {
                                    future::ready(match msg {
                                        Ok(msg) if msg.is_text() => {
                                            Some(Ok(msg.into_text().unwrap()))
                                        }
                                        Err(TungsteniteError::Protocol(
                                            ProtocolError::ResetWithoutClosingHandshake,
                                        )) => None,
                                        Err(e) => Some(Err(e)),
                                        _ => None,
                                    })
                                })
                                .take_while(|msg| future::ready(msg.is_ok()))
                                .filter_map(|msg| future::ready(msg.ok())),
                        )
                    }
                })
                .map(move |ws_stream| {
                    connected.take().map(|trigger| trigger.send(()));
                    ws_stream
                })
                .flatten();

            match interrupt_stream {
                false => stream.boxed(),
                true => stream.take_until(stream_tripwire).boxed(),
            }
        })
    }
}