1use std::{
2 io,
3 num::NonZeroU64,
4 time::{Duration, Instant},
5};
6
7use async_trait::async_trait;
8use bytes::BytesMut;
9use futures::{pin_mut, sink::SinkExt, stream::BoxStream, Sink, Stream, StreamExt};
10use tokio_tungstenite::tungstenite::{error::Error as TungsteniteError, protocol::Message};
11use tokio_util::codec::Encoder as _;
12use vector_lib::{
13 emit,
14 internal_event::{
15 ByteSize, BytesSent, CountByteSize, EventsSent, InternalEventHandle as _, Output, Protocol,
16 },
17 EstimatedJsonEncodedSizeOf,
18};
19
20use crate::{
21 codecs::{Encoder, Transformer},
22 common::websocket::{is_closed, PingInterval, WebSocketConnector},
23 event::{Event, EventStatus, Finalizable},
24 internal_events::{
25 ConnectionOpen, OpenGauge, WebSocketConnectionError, WebSocketConnectionShutdown,
26 },
27 sinks::util::StreamSink,
28 sinks::websocket::config::WebSocketSinkConfig,
29};
30
31pub struct WebSocketSink {
32 transformer: Transformer,
33 encoder: Encoder<()>,
34 connector: WebSocketConnector,
35 ping_interval: Option<NonZeroU64>,
36 ping_timeout: Option<NonZeroU64>,
37}
38
39impl WebSocketSink {
40 pub(crate) fn new(
41 config: &WebSocketSinkConfig,
42 connector: WebSocketConnector,
43 ) -> crate::Result<Self> {
44 let transformer = config.encoding.transformer();
45 let serializer = config.encoding.build()?;
46 let encoder = Encoder::<()>::new(serializer);
47
48 Ok(Self {
49 transformer,
50 encoder,
51 connector,
52 ping_interval: config.common.ping_interval,
53 ping_timeout: config.common.ping_timeout,
54 })
55 }
56
57 async fn create_sink_and_stream(
58 &self,
59 ) -> (
60 impl Sink<Message, Error = TungsteniteError>,
61 impl Stream<Item = Result<Message, TungsteniteError>>,
62 ) {
63 let ws_stream = self.connector.connect_backoff().await;
64 ws_stream.split()
65 }
66
67 fn check_received_pong_time(&self, last_pong: Instant) -> Result<(), TungsteniteError> {
68 if let Some(ping_timeout) = self.ping_timeout {
69 if last_pong.elapsed() > Duration::from_secs(ping_timeout.into()) {
70 return Err(TungsteniteError::Io(io::Error::new(
71 io::ErrorKind::TimedOut,
72 "Pong not received in time",
73 )));
74 }
75 }
76
77 Ok(())
78 }
79
80 const fn should_encode_as_binary(&self) -> bool {
81 use vector_lib::codecs::encoding::Serializer::{
82 Avro, Cef, Csv, Gelf, Json, Logfmt, Native, NativeJson, Protobuf, RawMessage, Text,
83 };
84
85 match self.encoder.serializer() {
86 RawMessage(_) | Avro(_) | Native(_) | Protobuf(_) => true,
87 Cef(_) | Csv(_) | Logfmt(_) | Gelf(_) | Json(_) | Text(_) | NativeJson(_) => false,
88 }
89 }
90
91 async fn handle_events<I, WS, O>(
92 &mut self,
93 input: &mut I,
94 ws_stream: &mut WS,
95 ws_sink: &mut O,
96 ) -> Result<(), ()>
97 where
98 I: Stream<Item = Event> + Unpin,
99 WS: Stream<Item = Result<Message, TungsteniteError>> + Unpin,
100 O: Sink<Message, Error = TungsteniteError> + Unpin,
101 {
102 const PING: &[u8] = b"PING";
103
104 let mut ping_interval = PingInterval::new(self.ping_interval.map(u64::from));
107
108 if let Err(error) = ws_sink.send(Message::Ping(PING.to_vec())).await {
109 emit!(WebSocketConnectionError { error });
110 return Err(());
111 }
112 let mut last_pong = Instant::now();
113
114 let bytes_sent = register!(BytesSent::from(Protocol("websocket".into())));
115 let events_sent = register!(EventsSent::from(Output(None)));
116 let encode_as_binary = self.should_encode_as_binary();
117
118 loop {
119 let result = tokio::select! {
120 _ = ping_interval.tick() => {
121 match self.check_received_pong_time(last_pong) {
122 Ok(()) => ws_sink.send(Message::Ping(PING.to_vec())).await.map(|_| ()),
123 Err(e) => Err(e)
124 }
125 },
126
127 Some(msg) = ws_stream.next() => {
128 match msg {
130 Ok(Message::Pong(_)) => {
131 last_pong = Instant::now();
132 Ok(())
133 },
134 Ok(_) => Ok(()),
135 Err(e) => Err(e)
136 }
137 },
138
139 event = input.next() => {
140 let mut event = if let Some(event) = event {
141 event
142 } else {
143 break;
144 };
145
146 let finalizers = event.take_finalizers();
147
148 self.transformer.transform(&mut event);
149
150 let event_byte_size = event.estimated_json_encoded_size_of();
151
152 let mut bytes = BytesMut::new();
153 let res = match self.encoder.encode(event, &mut bytes) {
154 Ok(()) => {
155 finalizers.update_status(EventStatus::Delivered);
156
157 let message = if encode_as_binary {
158 Message::binary(bytes)
159 }
160 else {
161 Message::text(String::from_utf8_lossy(&bytes))
162 };
163 let message_len = message.len();
164
165 ws_sink.send(message).await.map(|_| {
166 events_sent.emit(CountByteSize(1, event_byte_size));
167 bytes_sent.emit(ByteSize(message_len));
168 })
169 },
170 Err(_) => {
171 finalizers.update_status(EventStatus::Errored);
173 Ok(())
174 }
175 };
176
177 res
178 },
179 else => break,
180 };
181
182 if let Err(error) = result {
183 if is_closed(&error) {
184 emit!(WebSocketConnectionShutdown);
185 } else {
186 emit!(WebSocketConnectionError { error });
187 }
188 return Err(());
189 }
190 }
191
192 Ok(())
193 }
194}
195
196#[async_trait]
197impl StreamSink<Event> for WebSocketSink {
198 async fn run(mut self: Box<Self>, input: BoxStream<'_, Event>) -> Result<(), ()> {
199 let input = input.fuse().peekable();
200 pin_mut!(input);
201
202 while input.as_mut().peek().await.is_some() {
203 let (ws_sink, ws_stream) = self.create_sink_and_stream().await;
204 pin_mut!(ws_sink);
205 pin_mut!(ws_stream);
206
207 let _open_token = OpenGauge::new().open(|count| emit!(ConnectionOpen { count }));
208
209 if self
210 .handle_events(&mut input, &mut ws_stream, &mut ws_sink)
211 .await
212 .is_ok()
213 {
214 _ = ws_sink.close().await;
215 }
216 }
217
218 Ok(())
219 }
220}
221
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use futures::{future, FutureExt, StreamExt};
    use serde_json::Value as JsonValue;
    use tokio::{time, time::timeout};
    use tokio_tungstenite::{
        accept_async, accept_hdr_async,
        tungstenite::{
            error::ProtocolError,
            handshake::server::{Request, Response},
        },
    };
    use vector_lib::codecs::JsonSerializerConfig;

    use super::*;
    use crate::{
        common::websocket::WebSocketCommonConfig,
        config::{SinkConfig, SinkContext},
        http::Auth,
        test_util::{
            components::{run_and_assert_sink_compliance, SINK_TAGS},
            next_addr, random_lines_with_stream, trace_init, CountReceiver,
        },
        tls::{self, MaybeTlsSettings, TlsConfig, TlsEnableableConfig},
    };

    #[tokio::test(flavor = "multi_thread")]
    async fn test_websocket() {
        trace_init();

        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, None).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_auth_websocket() {
        trace_init();

        let auth = Some(Auth::Bearer {
            token: "OiJIUzI1NiIsInR5cCI6IkpXVCJ".to_string().into(),
        });
        let auth_clone = auth.clone();
        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                // The sink must actually send these credentials; the test server rejects
                // handshakes that lack a matching `Authorization` header.
                auth,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        send_events_and_assert(addr, config, tls, auth_clone).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_tls_websocket() {
        trace_init();

        let addr = next_addr();
        let tls_config = Some(TlsEnableableConfig::test_config());
        let tls = MaybeTlsSettings::from_config(tls_config.as_ref(), true).unwrap();

        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("wss://{addr}"),
                tls: Some(TlsEnableableConfig {
                    enabled: Some(true),
                    options: TlsConfig {
                        verify_certificate: Some(false),
                        verify_hostname: Some(true),
                        ca_file: Some(tls::TEST_PEM_CRT_PATH.into()),
                        ..Default::default()
                    },
                }),
                ping_timeout: None,
                ping_interval: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };

        send_events_and_assert(addr, config, tls, None).await;
    }

    #[tokio::test]
    async fn test_websocket_reconnect() {
        trace_init();

        let addr = next_addr();
        let config = WebSocketSinkConfig {
            common: WebSocketCommonConfig {
                uri: format!("ws://{addr}"),
                tls: None,
                ping_interval: None,
                ping_timeout: None,
                auth: None,
            },
            encoding: JsonSerializerConfig::default().into(),
            acknowledgements: Default::default(),
        };
        let tls = MaybeTlsSettings::Raw(());

        let mut receiver = create_count_receiver(addr, tls.clone(), true, None);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (_lines, events) = random_lines_with_stream(10, 100, None);
        let events = events.then(|event| async move {
            time::sleep(Duration::from_millis(10)).await;
            event
        });
        drop(tokio::spawn(sink.run(events)));

        receiver.connected().await;
        time::sleep(Duration::from_millis(500)).await;
        assert!(!receiver.await.is_empty());

        // A second server on the same address must be reached by the sink's reconnect logic.
        let mut receiver = create_count_receiver(addr, tls, false, None);
        assert!(timeout(Duration::from_secs(10), receiver.connected())
            .await
            .is_ok());
    }

    /// Runs `config` against a local test server and checks every line arrives intact.
    async fn send_events_and_assert(
        addr: SocketAddr,
        config: WebSocketSinkConfig,
        tls: MaybeTlsSettings,
        auth: Option<Auth>,
    ) {
        let mut receiver = create_count_receiver(addr, tls, false, auth);

        let context = SinkContext::default();
        let (sink, _healthcheck) = config.build(context).await.unwrap();

        let (lines, events) = random_lines_with_stream(10, 100, None);
        run_and_assert_sink_compliance(sink, events, &SINK_TAGS).await;

        receiver.connected().await;

        let output = receiver.await;
        assert_eq!(lines.len(), output.len());
        let message_key = crate::config::log_schema()
            .message_key()
            .expect("global log_schema.message_key to be valid path")
            .to_string();
        for (source, received) in lines.iter().zip(output) {
            let json = serde_json::from_str::<JsonValue>(&received).expect("Invalid JSON");
            let received = json.get(message_key.as_str()).unwrap().as_str().unwrap();
            assert_eq!(source, received);
        }
    }

    /// Starts a WebSocket server on `addr` that collects every text frame it receives.
    ///
    /// When `auth` is set, handshakes without an `Authorization` header are rejected, and Bearer
    /// tokens must match exactly. With `interrupt_stream`, the collected stream is additionally
    /// cut off by the tripwire.
    fn create_count_receiver(
        addr: SocketAddr,
        tls: MaybeTlsSettings,
        interrupt_stream: bool,
        auth: Option<Auth>,
    ) -> CountReceiver<String> {
        CountReceiver::receive_items_stream(move |tripwire, connected| async move {
            let listener = tls.bind(&addr).await.unwrap();
            let stream = listener.accept_stream();

            let tripwire = tripwire.map(|_| ()).shared();
            let stream_tripwire = tripwire.clone();
            let mut connected = Some(connected);

            let stream = stream
                .take_until(tripwire)
                .filter_map(move |maybe_tls_stream| {
                    let auth = auth.clone();
                    async move {
                        let maybe_tls_stream = maybe_tls_stream.unwrap();
                        let ws_stream = match auth {
                            Some(expected) => {
                                let auth_callback = |req: &Request, res: Response| {
                                    // Credentials are expected, so a handshake without an
                                    // `Authorization` header is refused outright.
                                    let Some(header) = req.headers().get("Authorization") else {
                                        return Err(http::Response::<Option<String>>::new(None));
                                    };
                                    match expected {
                                        Auth::Bearer { token } => {
                                            if format!("Bearer {}", token.inner())
                                                != header.to_str().unwrap()
                                            {
                                                return Err(
                                                    http::Response::<Option<String>>::new(None),
                                                );
                                            }
                                        }
                                        // Only the presence of the header is checked here.
                                        Auth::Basic { .. } => {}
                                        #[cfg(feature = "aws-core")]
                                        _ => {}
                                    }
                                    Ok(res)
                                };
                                accept_hdr_async(maybe_tls_stream, auth_callback)
                                    .await
                                    .unwrap()
                            }
                            None => accept_async(maybe_tls_stream).await.unwrap(),
                        };

                        Some(
                            ws_stream
                                .filter_map(|msg| {
                                    future::ready(match msg {
                                        Ok(msg) if msg.is_text() => {
                                            Some(Ok(msg.into_text().unwrap()))
                                        }
                                        Err(TungsteniteError::Protocol(
                                            ProtocolError::ResetWithoutClosingHandshake,
                                        )) => None,
                                        Err(e) => Some(Err(e)),
                                        _ => None,
                                    })
                                })
                                .take_while(|msg| future::ready(msg.is_ok()))
                                .filter_map(|msg| future::ready(msg.ok())),
                        )
                    }
                })
                .map(move |ws_stream| {
                    // Signal the first accepted connection exactly once.
                    if let Some(trigger) = connected.take() {
                        _ = trigger.send(());
                    }
                    ws_stream
                })
                .flatten();

            match interrupt_stream {
                false => stream.boxed(),
                true => stream.take_until(stream_tripwire).boxed(),
            }
        })
    }
}
479}