1#[cfg(unix)]
2use std::os::unix::{fs::PermissionsExt, io::AsRawFd};
3use std::{
4 convert::TryInto,
5 fs,
6 marker::{Send, Sync},
7 net::SocketAddr,
8 path::PathBuf,
9 sync::{
10 Arc, Mutex,
11 atomic::{AtomicUsize, Ordering},
12 },
13 time::Duration,
14};
15
16use bytes::{Buf, Bytes, BytesMut};
17use futures::{
18 executor::block_on,
19 future::{self, OptionFuture},
20 sink::{Sink, SinkExt},
21 stream::{self, StreamExt, TryStreamExt},
22};
23use futures_util::{Future, FutureExt, future::BoxFuture};
24use ipnet::IpNet;
25use listenfd::ListenFd;
26use tokio::{
27 self,
28 io::{AsyncRead, AsyncWrite},
29 net::{TcpStream, UnixListener},
30 task::JoinHandle,
31 time::sleep,
32};
33use tokio_stream::wrappers::UnixListenerStream;
34use tokio_util::codec::{Framed, length_delimited};
35use tracing::{Instrument, Span, field};
36use vector_lib::{
37 lookup::OwnedValuePath,
38 tcp::TcpKeepaliveConfig,
39 tls::{CertificateMetadata, MaybeTlsIncomingStream, MaybeTlsSettings},
40};
41
42use super::net::{RequestLimiter, SocketListenAddr};
43use crate::{
44 SourceSender,
45 event::Event,
46 internal_events::{
47 ConnectionOpen, OpenGauge, SocketBindError, SocketMode, SocketReceiveError,
48 TcpBytesReceived, TcpSocketError, TcpSocketTlsConnectionError, UnixSocketError,
49 UnixSocketFileDeleteError,
50 },
51 shutdown::ShutdownSignal,
52 sources::{
53 Source,
54 util::{
55 AfterReadExt,
56 net::{MAX_IN_FLIGHT_EVENTS_TARGET, try_bind_tcp_listener},
57 },
58 },
59};
60
/// Upper bound, in bytes, that `FrameStreamReader::handle_control_frame` compares a control
/// frame's length against.
const FSTRM_CONTROL_FRAME_LENGTH_MAX: usize = 512;
/// Upper bound, in bytes, on a single content-type field inside a control frame.
const FSTRM_CONTROL_FIELD_CONTENT_TYPE_LENGTH_MAX: usize = 256;

/// Milliseconds `handle_stream` waits for the next frame while holding a request-limiter
/// permit before it gives the permit up and acquires a new one.
const PERMIT_HOLD_TIMEOUT_MS: u64 = 10;
69
/// Boxed sink of outgoing frames whose error type is `std::io::Error`; `FrameStreamReader`
/// writes its control-frame replies into one of these.
pub type FrameStreamSink = Box<dyn Sink<Bytes, Error = std::io::Error> + Send + Unpin>;
71
/// Per-connection reader that tracks the control-frame handshake and decides which incoming
/// frames are data frames to hand on.
pub struct FrameStreamReader {
    /// Sink used to send control frames (ACCEPT, FINISH) back to the peer.
    response_sink: Mutex<FrameStreamSink>,
    /// Content type a READY/START frame has to offer for it to be accepted.
    expected_content_type: String,
    /// Handshake progress for this connection.
    state: FrameStreamState,
}
77
78struct FrameStreamState {
79 expect_control_frame: bool,
80 control_state: ControlState,
81 is_bidirectional: bool,
82}
83impl FrameStreamState {
84 const fn new() -> Self {
85 FrameStreamState {
86 expect_control_frame: false,
87 control_state: ControlState::Initial,
89 is_bidirectional: true, }
91 }
92}
93
/// Stages of the control-frame handshake as tracked by `FrameStreamReader`.
#[derive(PartialEq, Debug)]
enum ControlState {
    /// No control frame has been accepted yet; READY or START may follow.
    Initial,
    /// READY was accepted and ACCEPT was sent; START is expected next.
    GotReady,
    /// START was accepted; data frames are passed through until STOP.
    ReadingData,
    /// STOP was accepted; any further control frame is reported as unexpected.
    Stopped,
}
101
102#[derive(Copy, Clone)]
103enum ControlHeader {
104 Accept,
105 Start,
106 Stop,
107 Ready,
108 Finish,
109}
110
111impl ControlHeader {
112 fn from_u32(val: u32) -> Result<Self, ()> {
113 match val {
114 0x01 => Ok(ControlHeader::Accept),
115 0x02 => Ok(ControlHeader::Start),
116 0x03 => Ok(ControlHeader::Stop),
117 0x04 => Ok(ControlHeader::Ready),
118 0x05 => Ok(ControlHeader::Finish),
119 _ => {
120 error!("Don't know header value {} (expected 0x01 - 0x05).", val);
121 Err(())
122 }
123 }
124 }
125
126 const fn to_u32(self) -> u32 {
127 match self {
128 ControlHeader::Accept => 0x01,
129 ControlHeader::Start => 0x02,
130 ControlHeader::Stop => 0x03,
131 ControlHeader::Ready => 0x04,
132 ControlHeader::Finish => 0x05,
133 }
134 }
135}
136
137enum ControlField {
138 ContentType,
139}
140
141impl ControlField {
142 fn from_u32(val: u32) -> Result<Self, ()> {
143 match val {
144 0x01 => Ok(ControlField::ContentType),
145 _ => {
146 error!("Don't know field type {} (expected 0x01).", val);
147 Err(())
148 }
149 }
150 }
151 const fn to_u32(&self) -> u32 {
152 match self {
153 ControlField::ContentType => 0x01,
154 }
155 }
156}
157
158fn advance_u32(b: &mut Bytes) -> Result<u32, ()> {
159 if b.len() < 4 {
160 error!("Malformed frame.");
161 return Err(());
162 }
163 let a = b.split_to(4);
164 Ok(u32::from_be_bytes(a[..].try_into().unwrap()))
165}
166
167impl FrameStreamReader {
168 pub fn new(response_sink: FrameStreamSink, expected_content_type: String) -> Self {
169 FrameStreamReader {
170 response_sink: Mutex::new(response_sink),
171 expected_content_type,
172 state: FrameStreamState::new(),
173 }
174 }
175
176 pub fn handle_frame(&mut self, frame: Bytes) -> Option<Bytes> {
177 if frame.is_empty() {
178 self.state.expect_control_frame = true;
180 None
181 } else if self.state.expect_control_frame {
182 self.state.expect_control_frame = false;
183 _ = self.handle_control_frame(frame);
184 None
185 } else {
186 if self.state.control_state == ControlState::ReadingData {
188 Some(frame) } else {
190 error!(
191 "Received a data frame while in state {:?}.",
192 self.state.control_state
193 );
194 None
195 }
196 }
197 }
198
199 fn handle_control_frame(&mut self, mut frame: Bytes) -> Result<(), ()> {
200 if frame.len() > FSTRM_CONTROL_FRAME_LENGTH_MAX {
202 error!("Control frame is too long.");
203 }
204
205 let header = ControlHeader::from_u32(advance_u32(&mut frame)?)?;
206
207 match self.state.control_state {
209 ControlState::Initial => {
210 match header {
211 ControlHeader::Ready => {
212 let content_type = self.process_fields(header, &mut frame)?.unwrap();
213
214 self.send_control_frame(Self::make_frame(
215 ControlHeader::Accept,
216 Some(content_type),
217 ));
218 self.state.control_state = ControlState::GotReady; }
220 ControlHeader::Start => {
221 _ = self.process_fields(header, &mut frame)?;
223 self.state.control_state = ControlState::ReadingData;
225 self.state.is_bidirectional = false; }
227 _ => error!("Got wrong control frame, expected READY."),
228 }
229 }
230 ControlState::GotReady => {
231 match header {
232 ControlHeader::Start => {
233 _ = self.process_fields(header, &mut frame)?;
235 self.state.control_state = ControlState::ReadingData;
237 }
238 _ => error!("Got wrong control frame, expected START."),
239 }
240 }
241 ControlState::ReadingData => {
242 match header {
243 ControlHeader::Stop => {
244 _ = self.process_fields(header, &mut frame)?;
246 if self.state.is_bidirectional {
247 self.send_control_frame(Self::make_frame(ControlHeader::Finish, None));
249 }
250 self.state.control_state = ControlState::Stopped; }
252 _ => error!("Got wrong control frame, expected STOP."),
253 }
254 }
255 ControlState::Stopped => error!("Unexpected control frame, current state is STOPPED."),
256 };
257 Ok(())
258 }
259
260 fn process_fields(
261 &mut self,
262 header: ControlHeader,
263 frame: &mut Bytes,
264 ) -> Result<Option<String>, ()> {
265 match header {
266 ControlHeader::Ready => {
267 let is_start_frame = false;
270 let content_type = self.process_content_type(frame, is_start_frame)?;
271 Ok(Some(content_type))
272 }
273 ControlHeader::Start => {
274 if frame.is_empty() {
276 Ok(None)
277 } else {
278 let is_start_frame = true;
280 let content_type = self.process_content_type(frame, is_start_frame)?;
281 Ok(Some(content_type))
282 }
283 }
284 ControlHeader::Stop => {
285 if !frame.is_empty() {
287 error!("Unexpected fields in STOP header.");
288 Err(())
289 } else {
290 Ok(None)
291 }
292 }
293 _ => {
294 error!("Unexpected control header value {:?}.", header.to_u32());
295 Err(())
296 }
297 }
298 }
299
300 fn process_content_type(&self, frame: &mut Bytes, is_start_frame: bool) -> Result<String, ()> {
301 if frame.is_empty() {
302 error!("No fields in control frame.");
303 return Err(());
304 }
305
306 let mut content_types = vec![];
307 while !frame.is_empty() {
308 let field_val = advance_u32(frame)?;
310 let field_type = ControlField::from_u32(field_val)?;
311 match field_type {
312 ControlField::ContentType => {
313 let field_len = advance_u32(frame)? as usize;
315
316 if field_len > FSTRM_CONTROL_FIELD_CONTENT_TYPE_LENGTH_MAX {
318 error!("Content-Type string is too long.");
319 return Err(());
320 }
321
322 let content_type = std::str::from_utf8(&frame[..field_len]).unwrap();
323 content_types.push(content_type.to_string());
324 frame.advance(field_len);
325 }
326 }
327 }
328
329 if is_start_frame && content_types.len() > 1 {
330 error!(
331 "START control frame can only have one content-type provided (got {}).",
332 content_types.len()
333 );
334 return Err(());
335 }
336
337 for content_type in &content_types {
338 if *content_type == self.expected_content_type {
339 return Ok(content_type.clone());
340 }
341 }
342
343 error!(
344 "Content types did not match up. Expected {} got {:?}.",
345 self.expected_content_type, content_types
346 );
347 Err(())
348 }
349
350 fn make_frame(header: ControlHeader, content_type: Option<String>) -> Bytes {
351 let mut frame = BytesMut::new();
352 frame.extend(header.to_u32().to_be_bytes());
353 if let Some(s) = content_type {
354 frame.extend(ControlField::ContentType.to_u32().to_be_bytes()); frame.extend((s.len() as u32).to_be_bytes()); frame.extend(s.as_bytes());
357 }
358 Bytes::from(frame)
359 }
360
361 fn send_control_frame(&mut self, frame: Bytes) {
362 let empty_frame = Bytes::from(&b""[..]); let mut stream = stream::iter(vec![Ok(empty_frame), Ok(frame)]);
364
365 if let Err(e) = block_on(self.response_sink.lock().unwrap().send_all(&mut stream)) {
366 error!("Encountered error '{:#?}' while sending control frame.", e);
367 }
368 }
369}
370
/// Behaviour shared by the TCP and Unix framestream sources: how frames are sized, matched
/// and turned into events.
pub trait FrameHandler {
    /// Content type the peer has to offer during the control handshake.
    fn content_type(&self) -> String;
    /// Maximum frame length handed to the length-delimited codec.
    fn max_frame_length(&self) -> usize;
    /// Converts one data frame into an event; `received_from` identifies the peer when known.
    /// Returning `None` drops the frame.
    fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event>;
    /// When `true`, each frame is handled in its own spawned task.
    fn multithreaded(&self) -> bool;
    /// Cap on concurrently running frame-handling tasks; `0` disables the cap.
    fn max_frame_handling_tasks(&self) -> usize;
    /// Not used within this section of the file — presumably the event field receiving the
    /// peer identity; confirm against implementors.
    fn host_key(&self) -> &Option<OwnedValuePath>;
    /// Not used within this section of the file — presumably the event timestamp field;
    /// confirm against implementors.
    fn timestamp_key(&self) -> Option<&OwnedValuePath>;
    /// Not used within this section of the file — presumably the event source-type field;
    /// confirm against implementors.
    fn source_type_key(&self) -> Option<&OwnedValuePath>;
}
381
/// Additional settings consumed by `build_framestream_unix_source`.
pub trait UnixFrameHandler: FrameHandler {
    /// Filesystem path the Unix listener binds to.
    fn socket_path(&self) -> PathBuf;
    /// Permission bits to apply to the socket file after binding, if any.
    fn socket_file_mode(&self) -> Option<u32>;
    /// Value for the listener's `RcvBuf` socket option, if any.
    fn socket_receive_buffer_size(&self) -> Option<usize>;
    /// Value for the listener's `SndBuf` socket option, if any.
    fn socket_send_buffer_size(&self) -> Option<usize>;
}
388
/// Additional settings consumed by `build_framestream_tcp_source` and `handle_stream`.
pub trait TcpFrameHandler: FrameHandler {
    /// Address the TCP listener binds to.
    fn address(&self) -> SocketListenAddr;
    /// Keepalive configuration applied to each accepted socket, if any.
    fn keepalive(&self) -> Option<TcpKeepaliveConfig>;
    /// Grace period after the shutdown signal before still-open connections are dropped.
    /// NOTE(review): despite the `_secs` suffix this is already a `Duration`.
    fn shutdown_timeout_secs(&self) -> Duration;
    /// TLS settings handed to the listener.
    fn tls(&self) -> MaybeTlsSettings;
    /// Not used within this section of the file — presumably where client certificate
    /// metadata is stored on events; confirm against implementors.
    fn tls_client_metadata_key(&self) -> Option<OwnedValuePath>;
    /// Receive buffer size applied to each accepted socket, if any.
    fn receive_buffer_bytes(&self) -> Option<usize>;
    /// Maximum lifetime of a connection in seconds, if limited.
    fn max_connection_duration_secs(&self) -> Option<u64>;
    /// Maximum number of simultaneously accepted connections, if limited.
    fn max_connections(&self) -> Option<u32>;
    /// Networks allowed to connect, if restricted.
    fn allowed_origins(&self) -> Option<&[IpNet]>;
    /// Called once per connection with the peer certificate metadata (`None` without one).
    fn insert_tls_client_metadata(&mut self, metadata: Option<CertificateMetadata>);
}
401
/// Builds the TCP framestream source future.
///
/// The returned future binds a listener on `frame_handler.address()`, accepts connections
/// until `shutdown` fires (at most `frame_handler.max_connections()` at once) and spawns one
/// `handle_stream` task per connection. Decoded events are sent to `out`.
///
/// # Errors
///
/// This function itself always returns `Ok`; a bind failure surfaces when the returned
/// future is polled, after a `SocketBindError` has been emitted.
pub fn build_framestream_tcp_source(
    frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
    shutdown: ShutdownSignal,
    out: SourceSender,
) -> crate::Result<Source> {
    let addr = frame_handler.address();
    let tls = frame_handler.tls();
    let shutdown = shutdown.clone();
    let out = out.clone();

    Ok(Box::pin(async move {
        let listenfd = ListenFd::from_env();
        let listener = try_bind_tcp_listener(
            addr,
            listenfd,
            &tls,
            frame_handler
                .allowed_origins()
                .map(|origins| origins.to_vec()),
        )
        .await
        .map_err(|error| {
            emit!(SocketBindError {
                mode: SocketMode::Tcp,
                error: &error,
            })
        })?;

        // Log the actual local address when available, falling back to the configured one.
        info!(
            message = "Listening.",
            addr = %listener
                .local_addr()
                .map(SocketListenAddr::SocketAddr)
                .unwrap_or(addr)
        );

        // The tripwire resolves `shutdown_timeout_secs` after the shutdown signal; it is
        // shared so every connection task can wait on the same deadline.
        let tripwire = shutdown.clone();
        let shutdown_timeout_secs = frame_handler.shutdown_timeout_secs();
        let tripwire = async move {
            _ = tripwire.await;
            sleep(shutdown_timeout_secs).await;
        }
        .shared();

        let connection_gauge = OpenGauge::new();
        let shutdown_clone = shutdown.clone();

        // One limiter is shared by all connections of this source.
        let request_limiter = RequestLimiter::new(
            MAX_IN_FLIGHT_EVENTS_TARGET,
            frame_handler.max_frame_handling_tasks(),
        );

        listener
            .accept_stream_limited(frame_handler.max_connections())
            .take_until(shutdown_clone)
            .for_each(move |(connection, tcp_connection_permit)| {
                // Per-connection clones, moved into the async block below.
                let shutdown_signal = shutdown.clone();
                let tripwire = tripwire.clone();
                let out = out.clone();
                let connection_gauge = connection_gauge.clone();
                let request_limiter = request_limiter.clone();
                let frame_handler_clone = frame_handler.clone();

                async move {
                    let socket = match connection {
                        Ok(socket) => socket,
                        Err(error) => {
                            emit!(SocketReceiveError {
                                mode: SocketMode::Tcp,
                                error: &error
                            });
                            return;
                        }
                    };

                    let peer_addr = socket.peer_addr();
                    let span = info_span!("connection", %peer_addr);

                    let tripwire = tripwire
                        .map(move |_| {
                            info!(
                                message = "Resetting connection (still open after seconds).",
                                seconds = ?shutdown_timeout_secs
                            );
                        })
                        .boxed();

                    span.clone().in_scope(|| {
                        debug!(message = "Accepted a new connection.", peer_addr = %peer_addr);

                        let open_token =
                            connection_gauge.open(|count| emit!(ConnectionOpen { count }));

                        let fut = handle_stream(
                            frame_handler_clone,
                            shutdown_signal,
                            socket,
                            tripwire,
                            peer_addr,
                            out,
                            request_limiter,
                        );

                        // The gauge token and the connection permit are released only once
                        // the connection task has finished.
                        tokio::spawn(
                            fut.map(move |()| {
                                drop(open_token);
                                drop(tcp_connection_permit);
                            })
                            .instrument(span.or_current()),
                        );
                    });
                }
            })
            .map(Ok)
            .await
    }))
}
523
/// Serves one accepted TCP connection until it closes, times out or the source shuts down.
///
/// Steps: complete the (possibly TLS) handshake, apply socket options, record the peer
/// certificate on `frame_handler`, then read length-delimited frames through a
/// `FrameStreamReader` and pass every data frame to `handle_tcp_frame`.
///
/// `tripwire` is the post-shutdown deadline and `shutdown_signal` the shutdown request;
/// either one ends the read loop. `request_limiter` hands out the permit held while waiting
/// for each frame.
#[allow(clippy::too_many_arguments)]
async fn handle_stream(
    mut frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
    mut shutdown_signal: ShutdownSignal,
    mut socket: MaybeTlsIncomingStream<TcpStream>,
    mut tripwire: BoxFuture<'static, ()>,
    peer_addr: SocketAddr,
    out: SourceSender,
    request_limiter: RequestLimiter,
) {
    // Give up on the handshake if shutdown is requested first.
    tokio::select! {
        result = socket.handshake() => {
            if let Err(error) = result {
                emit!(TcpSocketTlsConnectionError { error });
                return;
            }
        },
        _ = &mut shutdown_signal => {
            return;
        }
    };

    // Socket tuning is best effort: failures are logged and the connection continues.
    if let Some(keepalive) = frame_handler.keepalive()
        && let Err(error) = socket.set_keepalive(keepalive)
    {
        warn!(message = "Failed configuring TCP keepalive.", %error);
    }

    if let Some(receive_buffer_bytes) = frame_handler.receive_buffer_bytes()
        && let Err(error) = socket.set_receive_buffer_bytes(receive_buffer_bytes)
    {
        warn!(message = "Failed configuring receive buffer size on TCP socket.", %error);
    }

    // Report the size of every read together with the peer address.
    let socket = socket.after_read(move |byte_size| {
        emit!(TcpBytesReceived {
            byte_size,
            peer_addr
        });
    });

    let certificate_metadata = socket
        .get_ref()
        .ssl_stream()
        .and_then(|stream| stream.ssl().peer_certificate())
        .map(CertificateMetadata::from);

    frame_handler.insert_tls_client_metadata(certificate_metadata);

    // NOTE(review): this span is created and recorded into but never entered or attached
    // to anything in this function — confirm whether it is still needed.
    let span = info_span!("connection");
    span.record("peer_addr", field::debug(&peer_addr));
    let received_from: Option<Bytes> = Some(peer_addr.to_string().into());

    // Holds a sleep only when a maximum connection duration is configured.
    let connection_close_timeout = OptionFuture::from(
        frame_handler
            .max_connection_duration_secs()
            .map(|timeout_secs| tokio::time::sleep(Duration::from_secs(timeout_secs))),
    );
    tokio::pin!(connection_close_timeout);

    let content_type = frame_handler.content_type();
    let mut event_sink = out.clone();
    // The write half carries control-frame replies; the read half yields incoming frames.
    let (sock_sink, sock_stream) = Framed::new(
        socket,
        length_delimited::Builder::new()
            .max_frame_length(frame_handler.max_frame_length())
            .new_codec(),
    )
    .split();
    let mut reader = FrameStreamReader::new(Box::new(sock_sink), content_type);
    // Read errors are emitted and dropped; the reader filters out everything that is not a
    // data frame.
    let mut frames = sock_stream
        .map_err(move |error| {
            emit!(TcpSocketError {
                error: &error,
                peer_addr,
            });
        })
        .filter_map(move |frame| {
            future::ready(match frame {
                Ok(f) => reader.handle_frame(Bytes::from(f)),
                Err(_) => None,
            })
        });

    // Task counter for this connection only.
    let active_parsing_task_nums = Arc::new(AtomicUsize::new(0));
    loop {
        // Acquire a permit before reading, unless the connection has to end first.
        let mut permit = tokio::select! {
            _ = &mut tripwire => break,
            // The `Some(_)` pattern matches only when a maximum duration is configured.
            Some(_) = &mut connection_close_timeout => {
                break;
            },
            _ = &mut shutdown_signal => {
                break;
            },
            permit = request_limiter.acquire() => {
                Some(permit)
            }
            else => break,
        };

        let timeout = tokio::time::sleep(Duration::from_millis(PERMIT_HOLD_TIMEOUT_MS));
        tokio::pin!(timeout);

        tokio::select! {
            _ = &mut tripwire => break,
            _ = &mut shutdown_signal => break,
            // No frame arrived in time: start over, which drops this permit and acquires
            // a fresh one.
            _ = &mut timeout => {
                continue;
            }
            res = frames.next() => {
                match res {
                    Some(frame) => {
                        if let Some(permit) = &mut permit {
                            // Exactly one frame was decoded under this permit.
                            permit.decoding_finished(1);
                        };
                        handle_tcp_frame(&mut frame_handler, frame, &mut event_sink, received_from.clone(), Arc::clone(&active_parsing_task_nums)).await;
                    }
                    None => {
                        debug!("Connection closed.");
                        break
                    },
                }
            }
            else => break,
        }

        drop(permit);
    }
}
658
659async fn handle_tcp_frame<T>(
660 frame_handler: &mut T,
661 frame: Bytes,
662 event_sink: &mut SourceSender,
663 received_from: Option<Bytes>,
664 active_parsing_task_nums: Arc<AtomicUsize>,
665) where
666 T: TcpFrameHandler + Send + Sync + Clone + 'static,
667{
668 if frame_handler.multithreaded() {
669 spawn_event_handling_tasks(
670 frame,
671 frame_handler.clone(),
672 event_sink.clone(),
673 received_from,
674 active_parsing_task_nums,
675 frame_handler.max_frame_handling_tasks(),
676 )
677 .await;
678 } else if let Some(event) = frame_handler.handle_event(received_from, frame)
679 && let Err(e) = event_sink.send_event(event).await
680 {
681 error!("Error sending event: {e:?}.");
682 }
683}
684
685pub fn build_framestream_unix_source(
691 frame_handler: impl UnixFrameHandler + Send + Sync + Clone + 'static,
692 shutdown: ShutdownSignal,
693 out: SourceSender,
694) -> crate::Result<Source> {
695 let path = frame_handler.socket_path();
696
697 match fs::metadata(&path) {
699 Ok(_) => {
700 info!(message = "Deleting file.", ?path);
702 fs::remove_file(&path)?;
703 }
704 Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {} Err(e) => {
706 error!("Unable to get socket information; error = {:?}.", e);
707 return Err(Box::new(e));
708 }
709 };
710
711 let listener = UnixListener::bind(&path)?;
712
713 if let Some(socket_receive_buffer_size) = frame_handler.socket_receive_buffer_size() {
715 _ = nix::sys::socket::setsockopt(
716 listener.as_raw_fd(),
717 nix::sys::socket::sockopt::RcvBuf,
718 &(socket_receive_buffer_size),
719 );
720 let rcv_buf_size =
721 nix::sys::socket::getsockopt(listener.as_raw_fd(), nix::sys::socket::sockopt::RcvBuf);
722 info!(
723 "Unix socket receive buffer size modified to {}.",
724 rcv_buf_size.unwrap()
725 );
726 }
727
728 if let Some(socket_send_buffer_size) = frame_handler.socket_send_buffer_size() {
730 _ = nix::sys::socket::setsockopt(
731 listener.as_raw_fd(),
732 nix::sys::socket::sockopt::SndBuf,
733 &(socket_send_buffer_size),
734 );
735 let snd_buf_size =
736 nix::sys::socket::getsockopt(listener.as_raw_fd(), nix::sys::socket::sockopt::SndBuf);
737 info!(
738 "Unix socket buffer send size modified to {}.",
739 snd_buf_size.unwrap()
740 );
741 }
742
743 if let Some(socket_permission) = frame_handler.socket_file_mode() {
745 if !(448..=511).contains(&socket_permission) {
746 return Err(format!(
747 "Invalid Socket permission {socket_permission:#o}. Must between 0o700 and 0o777."
748 )
749 .into());
750 }
751 match fs::set_permissions(&path, fs::Permissions::from_mode(socket_permission)) {
752 Ok(_) => {
753 info!("Socket permissions updated to {:#o}.", socket_permission);
754 }
755 Err(e) => {
756 error!(
757 "Failed to update listener socket permissions; error = {:?}.",
758 e
759 );
760 return Err(Box::new(e));
761 }
762 };
763 };
764
765 let fut = async move {
766 let active_parsing_task_nums = Arc::new(AtomicUsize::new(0));
767
768 info!(message = "Listening...", ?path, r#type = "unix");
769
770 let mut stream = UnixListenerStream::new(listener).take_until(shutdown.clone());
771 while let Some(socket) = stream.next().await {
772 let socket = match socket {
773 Err(e) => {
774 error!("Failed to accept socket; error = {:?}.", e);
775 continue;
776 }
777 Ok(s) => s,
778 };
779 let peer_addr = socket.peer_addr().ok();
780 let listen_path = path.clone();
781 let active_task_nums_ = Arc::clone(&active_parsing_task_nums);
782
783 let span = info_span!("connection");
784 let path = if let Some(addr) = peer_addr {
785 if let Some(path) = addr.as_pathname().map(|e| e.to_owned()) {
786 span.record("peer_path", field::debug(&path));
787 Some(path)
788 } else {
789 None
790 }
791 } else {
792 None
793 };
794 let received_from: Option<Bytes> =
795 path.map(|p| p.to_string_lossy().into_owned().into());
796
797 build_framestream_source(
798 frame_handler.clone(),
799 socket,
800 received_from,
801 out.clone(),
802 shutdown.clone(),
803 span,
804 active_task_nums_,
805 move |error| {
806 emit!(UnixSocketError {
807 error: &error,
808 path: &listen_path,
809 });
810 },
811 );
812 }
813
814 drop(stream);
816
817 if let Err(error) = fs::remove_file(&path) {
819 emit!(UnixSocketFileDeleteError { path: &path, error });
820 }
821
822 Ok(())
823 };
824
825 Ok(Box::pin(fut))
826}
827
/// Spawns the task that serves one already-accepted connection.
///
/// `socket` is wrapped in a length-delimited codec and split. The write half carries
/// control-frame replies from a `FrameStreamReader`, and the read half is consumed until
/// `shutdown` resolves or the stream ends. Read errors are passed to `error_mapper` and then
/// dropped. Data frames become events on `out`, either inline or, in multithreaded mode,
/// through `spawn_event_handling_tasks` with `active_task_nums` as the quota counter.
///
/// The function returns immediately; the spawned task runs inside `span`.
#[allow(clippy::too_many_arguments)]
fn build_framestream_source<T: Send + 'static>(
    frame_handler: impl FrameHandler + Send + Sync + Clone + 'static,
    socket: impl AsyncRead + AsyncWrite + Send + 'static,
    received_from: Option<Bytes>,
    out: SourceSender,
    shutdown: impl Future<Output = T> + Unpin + Send + 'static,
    span: Span,
    active_task_nums: Arc<AtomicUsize>,
    error_mapper: impl FnMut(std::io::Error) + Send + 'static,
) {
    let content_type = frame_handler.content_type();
    let mut event_sink = out.clone();
    let (sock_sink, sock_stream) = Framed::new(
        socket,
        length_delimited::Builder::new()
            .max_frame_length(frame_handler.max_frame_length())
            .new_codec(),
    )
    .split();
    let mut fs_reader = FrameStreamReader::new(Box::new(sock_sink), content_type);
    let frame_handler_copy = frame_handler.clone();
    // Only data frames accepted by the reader survive this pipeline.
    let frames = sock_stream
        .take_until(shutdown)
        .map_err(error_mapper)
        .filter_map(move |frame| {
            future::ready(match frame {
                Ok(f) => fs_reader.handle_frame(Bytes::from(f)),
                Err(_) => None,
            })
        });
    if !frame_handler.multithreaded() {
        // Single-task mode: convert frames lazily and stream the events into the sink.
        let mut events = frames.filter_map(move |f| {
            future::ready(frame_handler_copy.handle_event(received_from.clone(), f))
        });

        let handler = async move {
            if let Err(e) = event_sink.send_event_stream(&mut events).await {
                error!("Error sending event: {:?}.", e);
            }

            info!("Finished sending.");
        };
        tokio::spawn(handler.instrument(span.or_current()));
    } else {
        // Multithreaded mode: one spawned task per frame, throttled by the task quota.
        let handler = async move {
            frames
                .for_each(move |f| {
                    let max_frame_handling_tasks = frame_handler_copy.max_frame_handling_tasks();
                    let f_handler = frame_handler_copy.clone();
                    let received_from_copy = received_from.clone();
                    let event_sink_copy = event_sink.clone();
                    let active_task_nums_copy = Arc::clone(&active_task_nums);

                    async move {
                        spawn_event_handling_tasks(
                            f,
                            f_handler,
                            event_sink_copy,
                            received_from_copy,
                            active_task_nums_copy,
                            max_frame_handling_tasks,
                        )
                        .await;
                    }
                })
                .await;
            info!("Finished sending.");
        };
        tokio::spawn(handler.instrument(span.or_current()));
    }
}
900
901async fn spawn_event_handling_tasks(
902 event_data: Bytes,
903 event_handler: impl FrameHandler + Send + Sync + 'static,
904 mut event_sink: SourceSender,
905 received_from: Option<Bytes>,
906 active_task_nums: Arc<AtomicUsize>,
907 max_frame_handling_tasks: usize,
908) -> JoinHandle<()> {
909 wait_for_task_quota(&active_task_nums, max_frame_handling_tasks).await;
910
911 tokio::spawn(async move {
912 future::ready({
913 if let Some(evt) = event_handler.handle_event(received_from, event_data)
914 && event_sink.send_event(evt).await.is_err()
915 {
916 error!("Encountered error while sending event.");
917 }
918 active_task_nums.fetch_sub(1, Ordering::AcqRel);
919 })
920 .await;
921 })
922}
923
924async fn wait_for_task_quota(active_task_nums: &Arc<AtomicUsize>, max_tasks: usize) {
925 while max_tasks > 0 && max_tasks < active_task_nums.load(Ordering::Acquire) {
926 tokio::time::sleep(Duration::from_millis(3)).await;
927 }
928 active_task_nums.fetch_add(1, Ordering::AcqRel);
929}
930
931#[cfg(test)]
932mod test {
933 use std::net::SocketAddr;
934 #[cfg(unix)]
935 use std::{
936 path::PathBuf,
937 sync::{
938 Arc,
939 atomic::{AtomicUsize, Ordering},
940 },
941 thread,
942 };
943
944 use bytes::{Bytes, BytesMut, buf::Buf};
945 use futures::{
946 future,
947 sink::{Sink, SinkExt},
948 stream::{self, StreamExt},
949 };
950 use futures_util::Stream;
951 use ipnet::IpNet;
952 use tokio::{
953 self,
954 net::{TcpStream, UnixStream},
955 task::JoinHandle,
956 time::{Duration, Instant},
957 };
958 use tokio_util::codec::{Framed, length_delimited};
959 use vector_lib::{
960 config::{LegacyKey, LogNamespace},
961 lookup::{OwnedValuePath, owned_value_path, path},
962 tcp::TcpKeepaliveConfig,
963 tls::{CertificateMetadata, MaybeTls, MaybeTlsSettings},
964 };
965
966 use super::{
967 ControlField, ControlHeader, FrameHandler, TcpFrameHandler, UnixFrameHandler,
968 build_framestream_tcp_source, build_framestream_unix_source, spawn_event_handling_tasks,
969 };
970 use crate::{
971 SourceSender,
972 config::{ComponentKey, log_schema},
973 event::{Event, LogEvent},
974 shutdown::SourceShutdownCoordinator,
975 sources::util::net::SocketListenAddr,
976 test_util::{collect_n, collect_n_stream, next_addr},
977 };
978
    /// Configurable `FrameHandler` for tests. `extra_task_handling_routine` is cloned and
    /// invoked on every `handle_event` call, letting a test inject work or observations.
    #[derive(Clone)]
    struct MockFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
        content_type: String,
        max_frame_length: usize,
        multithreaded: bool,
        max_frame_handling_tasks: usize,
        extra_task_handling_routine: F,
        host_key: Option<OwnedValuePath>,
        timestamp_key: Option<OwnedValuePath>,
        source_type_key: Option<OwnedValuePath>,
        log_namespace: LogNamespace,
    }
991
    /// `UnixFrameHandler` for tests: a `MockFrameHandler` plus the Unix socket settings.
    #[derive(Clone)]
    struct MockUnixFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
        frame_handler: MockFrameHandler<F>,
        socket_path: PathBuf,
        socket_file_mode: Option<u32>,
        socket_receive_buffer_size: Option<usize>,
        socket_send_buffer_size: Option<usize>,
    }
1000
    /// `TcpFrameHandler` for tests: a `MockFrameHandler` plus the TCP listener settings.
    #[derive(Clone)]
    struct MockTcpFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
        frame_handler: MockFrameHandler<F>,
        address: SocketListenAddr,
        keepalive: Option<TcpKeepaliveConfig>,
        shutdown_timeout_secs: Duration,
        tls: MaybeTlsSettings,
        tls_client_metadata_key: Option<OwnedValuePath>,
        receive_buffer_bytes: Option<usize>,
        max_connection_duration_secs: Option<u64>,
        max_connections: Option<u32>,
        /// Returned from `allowed_origins`.
        permit_origin: Option<Vec<IpNet>>,
    }
1014
1015 impl<F: Send + Sync + Clone + FnOnce() + 'static> MockTcpFrameHandler<F> {
1016 pub fn new(
1017 addr: SocketAddr,
1018 content_type: String,
1019 multithreaded: bool,
1020 extra_routine: F,
1021 permit_origin: Option<Vec<IpNet>>,
1022 ) -> Self {
1023 Self {
1024 frame_handler: MockFrameHandler::new(content_type, multithreaded, extra_routine),
1025 address: addr.into(),
1026 keepalive: None,
1027 shutdown_timeout_secs: Duration::from_secs(30),
1028 tls: MaybeTls::Raw(()),
1029 tls_client_metadata_key: None,
1030 receive_buffer_bytes: None,
1031 max_connection_duration_secs: None,
1032 max_connections: None,
1033 permit_origin,
1034 }
1035 }
1036 }
1037
1038 impl<F: Send + Sync + Clone + FnOnce() + 'static> MockUnixFrameHandler<F> {
1039 pub fn new(content_type: String, multithreaded: bool, extra_routine: F) -> Self {
1040 Self {
1041 frame_handler: MockFrameHandler::new(content_type, multithreaded, extra_routine),
1042 socket_path: tempfile::tempdir().unwrap().keep().join("unix_test"),
1043 socket_file_mode: None,
1044 socket_receive_buffer_size: None,
1045 socket_send_buffer_size: None,
1046 }
1047 }
1048 }
1049
1050 impl<F: Send + Sync + Clone + FnOnce() + 'static> MockFrameHandler<F> {
1051 pub fn new(content_type: String, multithreaded: bool, extra_routine: F) -> Self {
1052 Self {
1053 content_type,
1054 max_frame_length: bytesize::kib(100u64) as usize,
1055 multithreaded,
1056 max_frame_handling_tasks: 0,
1057 extra_task_handling_routine: extra_routine,
1058 host_key: Some(owned_value_path!("test_framestream")),
1059 timestamp_key: Some(owned_value_path!("my_timestamp")),
1060 source_type_key: Some(owned_value_path!("source_type")),
1061 log_namespace: LogNamespace::Legacy,
1062 }
1063 }
1064 }
1065
1066 impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockFrameHandler<F> {
1067 fn content_type(&self) -> String {
1068 self.content_type.clone()
1069 }
1070 fn max_frame_length(&self) -> usize {
1071 self.max_frame_length
1072 }
1073
1074 fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1075 let mut log_event = LogEvent::from(frame);
1076
1077 log_event.insert(
1078 log_schema().source_type_key_target_path().unwrap(),
1079 "framestream",
1080 );
1081 if let Some(host) = received_from {
1082 self.log_namespace.insert_source_metadata(
1083 "framestream",
1084 &mut log_event,
1085 self.host_key.as_ref().map(LegacyKey::Overwrite),
1086 path!("host"),
1087 host,
1088 )
1089 }
1090
1091 (self.extra_task_handling_routine.clone())();
1092
1093 Some(log_event.into())
1094 }
1095
1096 fn multithreaded(&self) -> bool {
1097 self.multithreaded
1098 }
1099 fn max_frame_handling_tasks(&self) -> usize {
1100 self.max_frame_handling_tasks
1101 }
1102
1103 fn host_key(&self) -> &Option<OwnedValuePath> {
1104 &self.host_key
1105 }
1106
1107 fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1108 self.timestamp_key.as_ref()
1109 }
1110
1111 fn source_type_key(&self) -> Option<&OwnedValuePath> {
1112 self.source_type_key.as_ref()
1113 }
1114 }
1115
    /// Forwards every `FrameHandler` method to the wrapped `MockFrameHandler`.
    impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockUnixFrameHandler<F> {
        fn content_type(&self) -> String {
            self.frame_handler.content_type()
        }

        fn max_frame_length(&self) -> usize {
            self.frame_handler.max_frame_length()
        }

        fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
            self.frame_handler.handle_event(received_from, frame)
        }

        fn multithreaded(&self) -> bool {
            self.frame_handler.multithreaded()
        }

        fn max_frame_handling_tasks(&self) -> usize {
            self.frame_handler.max_frame_handling_tasks()
        }

        fn host_key(&self) -> &Option<OwnedValuePath> {
            self.frame_handler.host_key()
        }

        fn timestamp_key(&self) -> Option<&OwnedValuePath> {
            self.frame_handler.timestamp_key()
        }

        fn source_type_key(&self) -> Option<&OwnedValuePath> {
            self.frame_handler.source_type_key()
        }
    }
1149
    /// Returns the Unix socket settings stored on the mock.
    impl<F: Send + Sync + Clone + FnOnce() + 'static> UnixFrameHandler for MockUnixFrameHandler<F> {
        fn socket_path(&self) -> PathBuf {
            self.socket_path.clone()
        }

        fn socket_file_mode(&self) -> Option<u32> {
            self.socket_file_mode
        }

        fn socket_receive_buffer_size(&self) -> Option<usize> {
            self.socket_receive_buffer_size
        }

        fn socket_send_buffer_size(&self) -> Option<usize> {
            self.socket_send_buffer_size
        }
    }
1167
    /// Forwards every `FrameHandler` method to the wrapped `MockFrameHandler`.
    impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockTcpFrameHandler<F> {
        fn content_type(&self) -> String {
            self.frame_handler.content_type()
        }

        fn max_frame_length(&self) -> usize {
            self.frame_handler.max_frame_length()
        }

        fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
            self.frame_handler.handle_event(received_from, frame)
        }

        fn multithreaded(&self) -> bool {
            self.frame_handler.multithreaded()
        }

        fn max_frame_handling_tasks(&self) -> usize {
            self.frame_handler.max_frame_handling_tasks()
        }

        fn host_key(&self) -> &Option<OwnedValuePath> {
            self.frame_handler.host_key()
        }

        fn timestamp_key(&self) -> Option<&OwnedValuePath> {
            self.frame_handler.timestamp_key()
        }

        fn source_type_key(&self) -> Option<&OwnedValuePath> {
            self.frame_handler.source_type_key()
        }
    }
1201
1202 impl<F: Send + Sync + Clone + FnOnce() + 'static> TcpFrameHandler for MockTcpFrameHandler<F> {
1203 fn address(&self) -> SocketListenAddr {
1204 self.address
1205 }
1206
1207 fn keepalive(&self) -> Option<TcpKeepaliveConfig> {
1208 self.keepalive
1209 }
1210
1211 fn shutdown_timeout_secs(&self) -> Duration {
1212 self.shutdown_timeout_secs
1213 }
1214
1215 fn tls(&self) -> MaybeTlsSettings {
1216 self.tls.clone()
1217 }
1218
1219 fn tls_client_metadata_key(&self) -> Option<OwnedValuePath> {
1220 self.tls_client_metadata_key.clone()
1221 }
1222
1223 fn receive_buffer_bytes(&self) -> Option<usize> {
1224 self.receive_buffer_bytes
1225 }
1226
1227 fn max_connection_duration_secs(&self) -> Option<u64> {
1228 self.max_connection_duration_secs
1229 }
1230
1231 fn max_connections(&self) -> Option<u32> {
1232 self.max_connections
1233 }
1234
1235 fn insert_tls_client_metadata(&mut self, _: Option<CertificateMetadata>) {}
1236
1237 fn allowed_origins(&self) -> Option<&[IpNet]> {
1238 self.permit_origin.as_deref()
1239 }
1240 }
1241
1242 fn init_framestream_tcp(
1243 source_id: &str,
1244 addr: &SocketAddr,
1245 frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
1246 pipeline: SourceSender,
1247 ) -> (JoinHandle<Result<(), ()>>, SourceShutdownCoordinator) {
1248 let source_id = ComponentKey::from(source_id);
1249 let mut shutdown = SourceShutdownCoordinator::default();
1250 let (shutdown_signal, _) = shutdown.register_source(&source_id, false);
1251 let server = build_framestream_tcp_source(frame_handler, shutdown_signal, pipeline)
1252 .expect("Failed to build framestream tcp source.");
1253
1254 let join_handle = tokio::spawn(server);
1255
1256 while std::net::TcpStream::connect(addr).is_err() {
1257 thread::sleep(Duration::from_millis(2));
1258 }
1259
1260 (join_handle, shutdown)
1261 }
1262
1263 fn init_framestream_unix(
1264 source_id: &str,
1265 frame_handler: impl UnixFrameHandler + Send + Sync + Clone + 'static,
1266 pipeline: SourceSender,
1267 ) -> (
1268 PathBuf,
1269 JoinHandle<Result<(), ()>>,
1270 SourceShutdownCoordinator,
1271 ) {
1272 let source_id = ComponentKey::from(source_id);
1273 let socket_path = frame_handler.socket_path();
1274 let mut shutdown = SourceShutdownCoordinator::default();
1275 let (shutdown_signal, _) = shutdown.register_source(&source_id, false);
1276 let server = build_framestream_unix_source(frame_handler, shutdown_signal, pipeline)
1277 .expect("Failed to build framestream unix source.");
1278
1279 let join_handle = tokio::spawn(server);
1280
1281 while std::os::unix::net::UnixStream::connect(&socket_path).is_err() {
1283 thread::sleep(Duration::from_millis(2));
1284 }
1285
1286 (socket_path, join_handle, shutdown)
1287 }
1288
1289 async fn make_tcp_stream(
1290 addr: SocketAddr,
1291 ) -> Framed<TcpStream, length_delimited::LengthDelimitedCodec> {
1292 let socket = TcpStream::connect(&addr).await.unwrap();
1293 Framed::new(socket, length_delimited::Builder::new().new_codec())
1294 }
1295
1296 async fn make_unix_stream(
1297 path: PathBuf,
1298 ) -> Framed<UnixStream, length_delimited::LengthDelimitedCodec> {
1299 let socket = UnixStream::connect(&path).await.unwrap();
1300 Framed::new(socket, length_delimited::Builder::new().new_codec())
1301 }
1302
1303 async fn send_data_frames<S: Sink<Bytes, Error = std::io::Error> + Unpin>(
1304 sock_sink: &mut S,
1305 frames: Vec<Result<Bytes, std::io::Error>>,
1306 ) {
1307 let mut stream = stream::iter(frames.into_iter());
1308 _ = sock_sink.send_all(&mut stream).await;
1310 }
1311
1312 async fn send_control_frame<S: Sink<Bytes, Error = std::io::Error> + Unpin>(
1313 sock_sink: &mut S,
1314 frame: Bytes,
1315 ) {
1316 send_data_frames(sock_sink, vec![Ok(Bytes::new()), Ok(frame)]).await; }
1318
1319 fn create_control_frame(header: ControlHeader) -> Bytes {
1320 Bytes::from(header.to_u32().to_be_bytes().to_vec())
1321 }
1322
1323 fn create_control_frame_with_content(
1324 header: ControlHeader,
1325 content_types: Vec<Bytes>,
1326 ) -> Bytes {
1327 let mut frame = BytesMut::from(&header.to_u32().to_be_bytes()[..]);
1328 for content_type in content_types {
1329 frame.extend(ControlField::ContentType.to_u32().to_be_bytes());
1330 frame.extend((content_type.len() as u32).to_be_bytes());
1331 frame.extend(content_type.clone());
1332 }
1333 Bytes::from(frame)
1334 }
1335
1336 fn assert_accept_frame(frame: &mut BytesMut, expected_content_type: Bytes) {
1337 assert_eq!(&frame[..4], &ControlHeader::Accept.to_u32().to_be_bytes(),);
1340 frame.advance(4);
1341 assert_eq!(
1343 &frame[..4],
1344 &ControlField::ContentType.to_u32().to_be_bytes(),
1345 );
1346 frame.advance(4);
1347 assert_eq!(
1349 &frame[..4],
1350 &(expected_content_type.len() as u32).to_be_bytes(),
1351 );
1352 frame.advance(4);
1353 assert_eq!(&frame[..], &expected_content_type[..]);
1355 }
1356
1357 fn create_frame_handler(multithreaded: bool) -> impl UnixFrameHandler + Send + Sync + Clone {
1358 MockUnixFrameHandler::new("test_content".to_string(), multithreaded, move || {})
1359 }
1360
1361 fn create_tcp_frame_handler(
1362 addr: SocketAddr,
1363 multithreaded: bool,
1364 permit_origin: Option<Vec<IpNet>>,
1365 ) -> impl TcpFrameHandler + Send + Sync + Clone {
1366 MockTcpFrameHandler::new(
1367 addr,
1368 "test_content".to_string(),
1369 multithreaded,
1370 move || {},
1371 permit_origin,
1372 )
1373 }
1374
1375 async fn signal_shutdown(source_name: &str, shutdown: &mut SourceShutdownCoordinator) {
1376 let deadline = Instant::now() + Duration::from_secs(10);
1378 let id = ComponentKey::from(source_name);
1379 let shutdown_complete = shutdown.shutdown_source(&id, deadline);
1380 let shutdown_success = shutdown_complete.await;
1381 assert!(shutdown_success);
1382 }
1383
1384 async fn test_normal_framestream<
1385 T: Sink<Bytes, Error = std::io::Error> + Unpin,
1386 U: Stream<Item = Result<BytesMut, std::io::Error>> + Unpin,
1387 V: Stream<Item = Event> + Unpin,
1388 >(
1389 source_name: &str,
1390 mut sock_sink: T,
1391 mut sock_stream: U,
1392 rx: V,
1393 mut shutdown: SourceShutdownCoordinator,
1394 source_handle: JoinHandle<Result<(), ()>>,
1395 ) {
1396 let content_type = Bytes::from(&b"test_content"[..]);
1398 let ready_msg =
1399 create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
1400 send_control_frame(&mut sock_sink, ready_msg).await;
1401
1402 let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1404 assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1406 assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1407
1408 send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Start)).await;
1410
1411 send_data_frames(
1413 &mut sock_sink,
1414 vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
1415 )
1416 .await;
1417 let events = collect_n(rx, 2).await;
1418
1419 send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;
1421
1422 let message_key = log_schema().message_key().unwrap().to_string();
1423 assert!(
1424 events
1425 .iter()
1426 .any(|e| e.as_log()[&message_key] == "hello".into())
1427 );
1428 assert!(
1429 events
1430 .iter()
1431 .any(|e| e.as_log()[&message_key] == "world".into())
1432 );
1433
1434 drop(sock_stream); signal_shutdown(source_name, &mut shutdown).await;
1438 _ = source_handle.await.unwrap();
1439 }
1440
1441 async fn test_multiple_content_types<
1442 T: Sink<Bytes, Error = std::io::Error> + Unpin,
1443 U: Stream<Item = Result<BytesMut, std::io::Error>> + Unpin,
1444 >(
1445 source_name: &str,
1446 mut sock_sink: T,
1447 mut sock_stream: U,
1448 mut shutdown: SourceShutdownCoordinator,
1449 source_handle: JoinHandle<Result<(), ()>>,
1450 ) {
1451 let content_type = Bytes::from(&b"test_content"[..]);
1453 let ready_msg = create_control_frame_with_content(
1454 ControlHeader::Ready,
1455 vec![Bytes::from(&b"test_content2"[..]), content_type.clone()],
1456 ); send_control_frame(&mut sock_sink, ready_msg).await;
1458
1459 let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1461
1462 assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1464 assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1465
1466 drop(sock_stream); signal_shutdown(source_name, &mut shutdown).await;
1470 _ = source_handle.await.unwrap();
1471 }
1472
1473 #[tokio::test(flavor = "multi_thread")]
1474 #[should_panic]
1475 async fn blocked_framestream_tcp() {
1476 let source_name = "test_source";
1477 let (tx, rx) = SourceSender::new_test();
1478 let addr = next_addr();
1479 let (source_handle, shutdown) = init_framestream_tcp(
1480 source_name,
1481 &addr,
1482 create_tcp_frame_handler(addr, false, Some(vec![])),
1483 tx,
1484 );
1485 let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1486
1487 test_normal_framestream(
1488 source_name,
1489 sock_sink,
1490 sock_stream,
1491 rx,
1492 shutdown,
1493 source_handle,
1494 )
1495 .await;
1496 }
1497
1498 #[tokio::test(flavor = "multi_thread")]
1499 async fn normal_framestream_singlethreaded_tcp() {
1500 let source_name = "test_source";
1501 let (tx, rx) = SourceSender::new_test();
1502 let addr = next_addr();
1503 let (source_handle, shutdown) = init_framestream_tcp(
1504 source_name,
1505 &addr,
1506 create_tcp_frame_handler(addr, false, None),
1507 tx,
1508 );
1509 let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1510
1511 test_normal_framestream(
1512 source_name,
1513 sock_sink,
1514 sock_stream,
1515 rx,
1516 shutdown,
1517 source_handle,
1518 )
1519 .await;
1520 }
1521
1522 #[tokio::test(flavor = "multi_thread")]
1523 async fn normal_framestream_singlethreaded_unix() {
1524 let source_name = "test_source";
1525 let (tx, rx) = SourceSender::new_test();
1526 let (path, source_handle, shutdown) =
1527 init_framestream_unix(source_name, create_frame_handler(false), tx);
1528 let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1529
1530 test_normal_framestream(
1531 source_name,
1532 sock_sink,
1533 sock_stream,
1534 rx,
1535 shutdown,
1536 source_handle,
1537 )
1538 .await;
1539 }
1540
1541 #[tokio::test(flavor = "multi_thread")]
1542 async fn normal_framestream_multithreaded_tcp() {
1543 let source_name = "test_source";
1544 let (tx, rx) = SourceSender::new_test();
1545 let addr = next_addr();
1546 let (source_handle, shutdown) = init_framestream_tcp(
1547 source_name,
1548 &addr,
1549 create_tcp_frame_handler(addr, true, None),
1550 tx,
1551 );
1552 let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1553
1554 test_normal_framestream(
1555 source_name,
1556 sock_sink,
1557 sock_stream,
1558 rx,
1559 shutdown,
1560 source_handle,
1561 )
1562 .await;
1563 }
1564
1565 #[tokio::test(flavor = "multi_thread")]
1566 async fn normal_framestream_multithreaded_unix() {
1567 let source_name = "test_source";
1568 let (tx, rx) = SourceSender::new_test();
1569 let (path, source_handle, shutdown) =
1570 init_framestream_unix(source_name, create_frame_handler(true), tx);
1571 let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1572
1573 test_normal_framestream(
1574 source_name,
1575 sock_sink,
1576 sock_stream,
1577 rx,
1578 shutdown,
1579 source_handle,
1580 )
1581 .await;
1582 }
1583
1584 #[tokio::test(flavor = "multi_thread")]
1585 async fn multiple_content_types_tcp() {
1586 let source_name = "test_source";
1587 let (tx, _) = SourceSender::new_test();
1588 let addr = next_addr();
1589 let (source_handle, shutdown) = init_framestream_tcp(
1590 source_name,
1591 &addr,
1592 create_tcp_frame_handler(addr, false, None),
1593 tx,
1594 );
1595 let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1596
1597 test_multiple_content_types(source_name, sock_sink, sock_stream, shutdown, source_handle)
1598 .await;
1599 }
1600
1601 #[tokio::test(flavor = "multi_thread")]
1602 async fn multiple_content_types_unix() {
1603 let source_name = "test_source";
1604 let (tx, _) = SourceSender::new_test();
1605 let (path, source_handle, shutdown) =
1606 init_framestream_unix(source_name, create_frame_handler(false), tx);
1607 let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1608
1609 test_multiple_content_types(source_name, sock_sink, sock_stream, shutdown, source_handle)
1610 .await;
1611 }
1612
1613 #[tokio::test(flavor = "multi_thread")]
1614 async fn wrong_content_type() {
1615 let source_name = "test_source";
1616 let (tx, _) = SourceSender::new_test();
1617 let (path, source_handle, mut shutdown) =
1618 init_framestream_unix(source_name, create_frame_handler(false), tx);
1619 let (mut sock_sink, mut sock_stream) = make_unix_stream(path).await.split();
1620
1621 let ready_msg = create_control_frame_with_content(
1623 ControlHeader::Ready,
1624 vec![Bytes::from(&b"test_content2"[..])],
1625 ); send_control_frame(&mut sock_sink, ready_msg).await;
1627
1628 let content_type = Bytes::from(&b"test_content"[..]);
1630 let ready_msg =
1631 create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
1632 send_control_frame(&mut sock_sink, ready_msg).await;
1633
1634 let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1636
1637 assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1639 assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1640
1641 drop(sock_stream); signal_shutdown(source_name, &mut shutdown).await;
1645 _ = source_handle.await.unwrap();
1646 }
1647
1648 #[tokio::test(flavor = "multi_thread")]
1649 async fn data_too_soon() {
1650 let source_name = "test_source";
1651 let (tx, rx) = SourceSender::new_test();
1652 let (path, source_handle, mut shutdown) =
1653 init_framestream_unix(source_name, create_frame_handler(false), tx);
1654 let (mut sock_sink, mut sock_stream) = make_unix_stream(path).await.split();
1655
1656 send_data_frames(
1658 &mut sock_sink,
1659 vec![Ok(Bytes::from("bad")), Ok(Bytes::from("data"))],
1660 )
1661 .await;
1662
1663 let content_type = Bytes::from(&b"test_content"[..]);
1665 let ready_msg =
1666 create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
1667 send_control_frame(&mut sock_sink, ready_msg).await;
1668
1669 let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1671
1672 assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1674 assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1675
1676 send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Start)).await;
1678
1679 send_data_frames(
1681 &mut sock_sink,
1682 vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
1683 )
1684 .await;
1685 let events = collect_n(rx, 2).await;
1686
1687 send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;
1689
1690 assert_eq!(
1691 events[0].as_log()[log_schema().message_key().unwrap().to_string()],
1692 "hello".into(),
1693 );
1694 assert_eq!(
1695 events[1].as_log()[log_schema().message_key().unwrap().to_string()],
1696 "world".into(),
1697 );
1698
1699 drop(sock_stream); signal_shutdown(source_name, &mut shutdown).await;
1703 _ = source_handle.await.unwrap();
1704 }
1705
1706 #[tokio::test(flavor = "multi_thread")]
1707 async fn unidirectional_framestream() {
1708 let source_name = "test_source";
1709 let (tx, rx) = SourceSender::new_test();
1710 let (path, source_handle, mut shutdown) =
1711 init_framestream_unix(source_name, create_frame_handler(false), tx);
1712 let (mut sock_sink, _) = make_unix_stream(path).await.split();
1713
1714 let content_type = Bytes::from(&b"test_content"[..]);
1716 let start_msg = create_control_frame_with_content(ControlHeader::Start, vec![content_type]);
1717 send_control_frame(&mut sock_sink, start_msg).await;
1718
1719 send_data_frames(
1721 &mut sock_sink,
1722 vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
1723 )
1724 .await;
1725 let events = collect_n(rx, 2).await;
1726
1727 send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;
1729
1730 assert_eq!(
1731 events[0].as_log()[log_schema().message_key().unwrap().to_string()],
1732 "hello".into(),
1733 );
1734 assert_eq!(
1735 events[1].as_log()[log_schema().message_key().unwrap().to_string()],
1736 "world".into(),
1737 );
1738
1739 signal_shutdown(source_name, &mut shutdown).await;
1741 _ = source_handle.await.unwrap();
1742 }
1743
1744 #[tokio::test(flavor = "multi_thread")]
1745 async fn test_spawn_event_handling_tasks() {
1746 let (out, rx) = SourceSender::new_test();
1747
1748 let max_frame_handling_tasks = 20;
1749 let active_task_nums = Arc::new(AtomicUsize::new(0));
1750 let active_task_nums_copy = Arc::clone(&active_task_nums);
1751 let max_task_nums_reached = Arc::new(AtomicUsize::new(0));
1752 let max_task_nums_reached_copy = Arc::clone(&max_task_nums_reached);
1753
1754 let mut join_handles = vec![];
1755 let active_task_nums_copy_2 = Arc::clone(&active_task_nums_copy);
1756 let extra_routine = move || {
1757 thread::sleep(Duration::from_millis(10));
1758 max_task_nums_reached_copy.fetch_max(
1759 active_task_nums_copy_2.load(Ordering::Acquire),
1760 Ordering::AcqRel,
1761 );
1762 };
1763
1764 let total_events = max_frame_handling_tasks * 10;
1765
1766 join_handles.push(tokio::spawn(async move {
1767 future::ready({
1768 let events = collect_n(rx, total_events).await;
1769 assert_eq!(total_events, events.len(), "Missed events");
1770 })
1771 .await;
1772 }));
1773
1774 for i in 0..total_events {
1775 join_handles.push(
1776 spawn_event_handling_tasks(
1777 Bytes::from(format!("event_{i}")),
1778 MockFrameHandler::new("test_content".to_string(), true, extra_routine.clone()),
1779 out.clone(),
1780 None,
1781 Arc::clone(&active_task_nums_copy),
1782 max_frame_handling_tasks,
1783 )
1784 .await,
1785 );
1786 }
1787
1788 future::join_all(join_handles).await;
1789
1790 let final_task_nums = active_task_nums.load(Ordering::Acquire);
1791 assert_eq!(
1792 0, final_task_nums,
1793 "There should be NO left-over tasks at the end"
1794 );
1795
1796 let max_task_nums_reached_value = max_task_nums_reached.load(Ordering::Acquire);
1797 assert!(
1798 max_task_nums_reached_value > 1,
1799 "MultiThreaded mode does NOT work"
1800 );
1801 assert!(
1802 (max_task_nums_reached_value - max_frame_handling_tasks) < 2,
1803 "Max number of tasks at any given time should NOT Exceed max_frame_handling_tasks too much"
1804 );
1805 }
1806}