1use ipnet::IpNet;
2#[cfg(unix)]
3use std::os::unix::{fs::PermissionsExt, io::AsRawFd};
4use std::{
5 convert::TryInto,
6 fs,
7 marker::{Send, Sync},
8 net::SocketAddr,
9 path::PathBuf,
10 sync::{
11 atomic::{AtomicUsize, Ordering},
12 Arc, Mutex,
13 },
14 time::Duration,
15};
16
17use bytes::{Buf, Bytes, BytesMut};
18use futures::{
19 executor::block_on,
20 future::{self, OptionFuture},
21 sink::{Sink, SinkExt},
22 stream::{self, StreamExt, TryStreamExt},
23};
24use futures_util::{future::BoxFuture, Future, FutureExt};
25use listenfd::ListenFd;
26use tokio::{
27 self,
28 io::{AsyncRead, AsyncWrite},
29 net::{TcpStream, UnixListener},
30 task::JoinHandle,
31 time::sleep,
32};
33use tokio_stream::wrappers::UnixListenerStream;
34use tokio_util::codec::{length_delimited, Framed};
35use tracing::{field, Instrument, Span};
36use vector_lib::{
37 lookup::OwnedValuePath,
38 tcp::TcpKeepaliveConfig,
39 tls::{CertificateMetadata, MaybeTlsIncomingStream, MaybeTlsSettings},
40};
41
42use crate::{
43 event::Event,
44 internal_events::{
45 ConnectionOpen, OpenGauge, SocketBindError, SocketMode, SocketReceiveError,
46 TcpBytesReceived, TcpSocketError, TcpSocketTlsConnectionError, UnixSocketError,
47 UnixSocketFileDeleteError,
48 },
49 shutdown::ShutdownSignal,
50 sources::{
51 util::{
52 net::{try_bind_tcp_listener, MAX_IN_FLIGHT_EVENTS_TARGET},
53 AfterReadExt,
54 },
55 Source,
56 },
57 SourceSender,
58};
59
60use super::net::{RequestLimiter, SocketListenAddr};
61
62const FSTRM_CONTROL_FRAME_LENGTH_MAX: usize = 512;
63const FSTRM_CONTROL_FIELD_CONTENT_TYPE_LENGTH_MAX: usize = 256;
64
65const PERMIT_HOLD_TIMEOUT_MS: u64 = 10;
70
71pub type FrameStreamSink = Box<dyn Sink<Bytes, Error = std::io::Error> + Send + Unpin>;
72
73pub struct FrameStreamReader {
74 response_sink: Mutex<FrameStreamSink>,
75 expected_content_type: String,
76 state: FrameStreamState,
77}
78
79struct FrameStreamState {
80 expect_control_frame: bool,
81 control_state: ControlState,
82 is_bidirectional: bool,
83}
84impl FrameStreamState {
85 const fn new() -> Self {
86 FrameStreamState {
87 expect_control_frame: false,
88 control_state: ControlState::Initial,
90 is_bidirectional: true, }
92 }
93}
94
/// Stages of the Frame Streams control handshake as tracked by the reader.
#[derive(Debug, PartialEq)]
enum ControlState {
    /// No control frame processed yet.
    Initial,
    /// READY received and ACCEPT sent; waiting for START.
    GotReady,
    /// START received; data frames are forwarded.
    ReadingData,
    /// STOP received.
    Stopped,
}
102
/// Control frame types understood by the reader; wire values live in `to_u32`.
#[derive(Clone, Copy)]
enum ControlHeader {
    Accept,
    Start,
    Stop,
    Ready,
    Finish,
}
111
112impl ControlHeader {
113 fn from_u32(val: u32) -> Result<Self, ()> {
114 match val {
115 0x01 => Ok(ControlHeader::Accept),
116 0x02 => Ok(ControlHeader::Start),
117 0x03 => Ok(ControlHeader::Stop),
118 0x04 => Ok(ControlHeader::Ready),
119 0x05 => Ok(ControlHeader::Finish),
120 _ => {
121 error!("Don't know header value {} (expected 0x01 - 0x05).", val);
122 Err(())
123 }
124 }
125 }
126
127 const fn to_u32(self) -> u32 {
128 match self {
129 ControlHeader::Accept => 0x01,
130 ControlHeader::Start => 0x02,
131 ControlHeader::Stop => 0x03,
132 ControlHeader::Ready => 0x04,
133 ControlHeader::Finish => 0x05,
134 }
135 }
136}
137
/// Field types that may follow the header of a control frame.
enum ControlField {
    /// A CONTENT_TYPE field (wire value 0x01).
    ContentType,
}
141
142impl ControlField {
143 fn from_u32(val: u32) -> Result<Self, ()> {
144 match val {
145 0x01 => Ok(ControlField::ContentType),
146 _ => {
147 error!("Don't know field type {} (expected 0x01).", val);
148 Err(())
149 }
150 }
151 }
152 const fn to_u32(&self) -> u32 {
153 match self {
154 ControlField::ContentType => 0x01,
155 }
156 }
157}
158
159fn advance_u32(b: &mut Bytes) -> Result<u32, ()> {
160 if b.len() < 4 {
161 error!("Malformed frame.");
162 return Err(());
163 }
164 let a = b.split_to(4);
165 Ok(u32::from_be_bytes(a[..].try_into().unwrap()))
166}
167
168impl FrameStreamReader {
169 pub fn new(response_sink: FrameStreamSink, expected_content_type: String) -> Self {
170 FrameStreamReader {
171 response_sink: Mutex::new(response_sink),
172 expected_content_type,
173 state: FrameStreamState::new(),
174 }
175 }
176
177 pub fn handle_frame(&mut self, frame: Bytes) -> Option<Bytes> {
178 if frame.is_empty() {
179 self.state.expect_control_frame = true;
181 None
182 } else if self.state.expect_control_frame {
183 self.state.expect_control_frame = false;
184 _ = self.handle_control_frame(frame);
185 None
186 } else {
187 if self.state.control_state == ControlState::ReadingData {
189 Some(frame) } else {
191 error!(
192 "Received a data frame while in state {:?}.",
193 self.state.control_state
194 );
195 None
196 }
197 }
198 }
199
200 fn handle_control_frame(&mut self, mut frame: Bytes) -> Result<(), ()> {
201 if frame.len() > FSTRM_CONTROL_FRAME_LENGTH_MAX {
203 error!("Control frame is too long.");
204 }
205
206 let header = ControlHeader::from_u32(advance_u32(&mut frame)?)?;
207
208 match self.state.control_state {
210 ControlState::Initial => {
211 match header {
212 ControlHeader::Ready => {
213 let content_type = self.process_fields(header, &mut frame)?.unwrap();
214
215 self.send_control_frame(Self::make_frame(
216 ControlHeader::Accept,
217 Some(content_type),
218 ));
219 self.state.control_state = ControlState::GotReady; }
221 ControlHeader::Start => {
222 _ = self.process_fields(header, &mut frame)?;
224 self.state.control_state = ControlState::ReadingData;
226 self.state.is_bidirectional = false; }
228 _ => error!("Got wrong control frame, expected READY."),
229 }
230 }
231 ControlState::GotReady => {
232 match header {
233 ControlHeader::Start => {
234 _ = self.process_fields(header, &mut frame)?;
236 self.state.control_state = ControlState::ReadingData;
238 }
239 _ => error!("Got wrong control frame, expected START."),
240 }
241 }
242 ControlState::ReadingData => {
243 match header {
244 ControlHeader::Stop => {
245 _ = self.process_fields(header, &mut frame)?;
247 if self.state.is_bidirectional {
248 self.send_control_frame(Self::make_frame(ControlHeader::Finish, None));
250 }
251 self.state.control_state = ControlState::Stopped; }
253 _ => error!("Got wrong control frame, expected STOP."),
254 }
255 }
256 ControlState::Stopped => error!("Unexpected control frame, current state is STOPPED."),
257 };
258 Ok(())
259 }
260
261 fn process_fields(
262 &mut self,
263 header: ControlHeader,
264 frame: &mut Bytes,
265 ) -> Result<Option<String>, ()> {
266 match header {
267 ControlHeader::Ready => {
268 let is_start_frame = false;
271 let content_type = self.process_content_type(frame, is_start_frame)?;
272 Ok(Some(content_type))
273 }
274 ControlHeader::Start => {
275 if frame.is_empty() {
277 Ok(None)
278 } else {
279 let is_start_frame = true;
281 let content_type = self.process_content_type(frame, is_start_frame)?;
282 Ok(Some(content_type))
283 }
284 }
285 ControlHeader::Stop => {
286 if !frame.is_empty() {
288 error!("Unexpected fields in STOP header.");
289 Err(())
290 } else {
291 Ok(None)
292 }
293 }
294 _ => {
295 error!("Unexpected control header value {:?}.", header.to_u32());
296 Err(())
297 }
298 }
299 }
300
301 fn process_content_type(&self, frame: &mut Bytes, is_start_frame: bool) -> Result<String, ()> {
302 if frame.is_empty() {
303 error!("No fields in control frame.");
304 return Err(());
305 }
306
307 let mut content_types = vec![];
308 while !frame.is_empty() {
309 let field_val = advance_u32(frame)?;
311 let field_type = ControlField::from_u32(field_val)?;
312 match field_type {
313 ControlField::ContentType => {
314 let field_len = advance_u32(frame)? as usize;
316
317 if field_len > FSTRM_CONTROL_FIELD_CONTENT_TYPE_LENGTH_MAX {
319 error!("Content-Type string is too long.");
320 return Err(());
321 }
322
323 let content_type = std::str::from_utf8(&frame[..field_len]).unwrap();
324 content_types.push(content_type.to_string());
325 frame.advance(field_len);
326 }
327 }
328 }
329
330 if is_start_frame && content_types.len() > 1 {
331 error!(
332 "START control frame can only have one content-type provided (got {}).",
333 content_types.len()
334 );
335 return Err(());
336 }
337
338 for content_type in &content_types {
339 if *content_type == self.expected_content_type {
340 return Ok(content_type.clone());
341 }
342 }
343
344 error!(
345 "Content types did not match up. Expected {} got {:?}.",
346 self.expected_content_type, content_types
347 );
348 Err(())
349 }
350
351 fn make_frame(header: ControlHeader, content_type: Option<String>) -> Bytes {
352 let mut frame = BytesMut::new();
353 frame.extend(header.to_u32().to_be_bytes());
354 if let Some(s) = content_type {
355 frame.extend(ControlField::ContentType.to_u32().to_be_bytes()); frame.extend((s.len() as u32).to_be_bytes()); frame.extend(s.as_bytes());
358 }
359 Bytes::from(frame)
360 }
361
362 fn send_control_frame(&mut self, frame: Bytes) {
363 let empty_frame = Bytes::from(&b""[..]); let mut stream = stream::iter(vec![Ok(empty_frame), Ok(frame)]);
365
366 if let Err(e) = block_on(self.response_sink.lock().unwrap().send_all(&mut stream)) {
367 error!("Encountered error '{:#?}' while sending control frame.", e);
368 }
369 }
370}
371
372pub trait FrameHandler {
373 fn content_type(&self) -> String;
374 fn max_frame_length(&self) -> usize;
375 fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event>;
376 fn multithreaded(&self) -> bool;
377 fn max_frame_handling_tasks(&self) -> usize;
378 fn host_key(&self) -> &Option<OwnedValuePath>;
379 fn timestamp_key(&self) -> Option<&OwnedValuePath>;
380 fn source_type_key(&self) -> Option<&OwnedValuePath>;
381}
382
383pub trait UnixFrameHandler: FrameHandler {
384 fn socket_path(&self) -> PathBuf;
385 fn socket_file_mode(&self) -> Option<u32>;
386 fn socket_receive_buffer_size(&self) -> Option<usize>;
387 fn socket_send_buffer_size(&self) -> Option<usize>;
388}
389
390pub trait TcpFrameHandler: FrameHandler {
391 fn address(&self) -> SocketListenAddr;
392 fn keepalive(&self) -> Option<TcpKeepaliveConfig>;
393 fn shutdown_timeout_secs(&self) -> Duration;
394 fn tls(&self) -> MaybeTlsSettings;
395 fn tls_client_metadata_key(&self) -> Option<OwnedValuePath>;
396 fn receive_buffer_bytes(&self) -> Option<usize>;
397 fn max_connection_duration_secs(&self) -> Option<u64>;
398 fn max_connections(&self) -> Option<u32>;
399 fn allowed_origins(&self) -> Option<&[IpNet]>;
400 fn insert_tls_client_metadata(&mut self, metadata: Option<CertificateMetadata>);
401}
402
403pub fn build_framestream_tcp_source(
408 frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
409 shutdown: ShutdownSignal,
410 out: SourceSender,
411) -> crate::Result<Source> {
412 let addr = frame_handler.address();
413 let tls = frame_handler.tls();
414 let shutdown = shutdown.clone();
415 let out = out.clone();
416
417 Ok(Box::pin(async move {
418 let listenfd = ListenFd::from_env();
419 let listener = try_bind_tcp_listener(
420 addr,
421 listenfd,
422 &tls,
423 frame_handler
424 .allowed_origins()
425 .map(|origins| origins.to_vec()),
426 )
427 .await
428 .map_err(|error| {
429 emit!(SocketBindError {
430 mode: SocketMode::Tcp,
431 error: &error,
432 })
433 })?;
434
435 info!(
436 message = "Listening.",
437 addr = %listener
438 .local_addr()
439 .map(SocketListenAddr::SocketAddr)
440 .unwrap_or(addr)
441 );
442
443 let tripwire = shutdown.clone();
444 let shutdown_timeout_secs = frame_handler.shutdown_timeout_secs();
445 let tripwire = async move {
446 _ = tripwire.await;
447 sleep(shutdown_timeout_secs).await;
448 }
449 .shared();
450
451 let connection_gauge = OpenGauge::new();
452 let shutdown_clone = shutdown.clone();
453
454 let request_limiter = RequestLimiter::new(
455 MAX_IN_FLIGHT_EVENTS_TARGET,
456 frame_handler.max_frame_handling_tasks(),
457 );
458
459 listener
460 .accept_stream_limited(frame_handler.max_connections())
461 .take_until(shutdown_clone)
462 .for_each(move |(connection, tcp_connection_permit)| {
463 let shutdown_signal = shutdown.clone();
464 let tripwire = tripwire.clone();
465 let out = out.clone();
466 let connection_gauge = connection_gauge.clone();
467 let request_limiter = request_limiter.clone();
468 let frame_handler_clone = frame_handler.clone();
469
470 async move {
471 let socket = match connection {
472 Ok(socket) => socket,
473 Err(error) => {
474 emit!(SocketReceiveError {
475 mode: SocketMode::Tcp,
476 error: &error
477 });
478 return;
479 }
480 };
481
482 let peer_addr = socket.peer_addr();
483 let span = info_span!("connection", %peer_addr);
484
485 let tripwire = tripwire
486 .map(move |_| {
487 info!(
488 message = "Resetting connection (still open after seconds).",
489 seconds = ?shutdown_timeout_secs
490 );
491 })
492 .boxed();
493
494 span.clone().in_scope(|| {
495 debug!(message = "Accepted a new connection.", peer_addr = %peer_addr);
496
497 let open_token =
498 connection_gauge.open(|count| emit!(ConnectionOpen { count }));
499
500 let fut = handle_stream(
501 frame_handler_clone,
502 shutdown_signal,
503 socket,
504 tripwire,
505 peer_addr,
506 out,
507 request_limiter,
508 );
509
510 tokio::spawn(
511 fut.map(move |()| {
512 drop(open_token);
513 drop(tcp_connection_permit);
514 })
515 .instrument(span.or_current()),
516 );
517 });
518 }
519 })
520 .map(Ok)
521 .await
522 }))
523}
524
525#[allow(clippy::too_many_arguments)]
526async fn handle_stream(
527 mut frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
528 mut shutdown_signal: ShutdownSignal,
529 mut socket: MaybeTlsIncomingStream<TcpStream>,
530 mut tripwire: BoxFuture<'static, ()>,
531 peer_addr: SocketAddr,
532 out: SourceSender,
533 request_limiter: RequestLimiter,
534) {
535 tokio::select! {
536 result = socket.handshake() => {
537 if let Err(error) = result {
538 emit!(TcpSocketTlsConnectionError { error });
539 return;
540 }
541 },
542 _ = &mut shutdown_signal => {
543 return;
544 }
545 };
546
547 if let Some(keepalive) = frame_handler.keepalive() {
548 if let Err(error) = socket.set_keepalive(keepalive) {
549 warn!(message = "Failed configuring TCP keepalive.", %error);
550 }
551 }
552
553 if let Some(receive_buffer_bytes) = frame_handler.receive_buffer_bytes() {
554 if let Err(error) = socket.set_receive_buffer_bytes(receive_buffer_bytes) {
555 warn!(message = "Failed configuring receive buffer size on TCP socket.", %error);
556 }
557 }
558
559 let socket = socket.after_read(move |byte_size| {
560 emit!(TcpBytesReceived {
561 byte_size,
562 peer_addr
563 });
564 });
565
566 let certificate_metadata = socket
567 .get_ref()
568 .ssl_stream()
569 .and_then(|stream| stream.ssl().peer_certificate())
570 .map(CertificateMetadata::from);
571
572 frame_handler.insert_tls_client_metadata(certificate_metadata);
573
574 let span = info_span!("connection");
575 span.record("peer_addr", field::debug(&peer_addr));
576 let received_from: Option<Bytes> = Some(peer_addr.to_string().into());
577
578 let connection_close_timeout = OptionFuture::from(
579 frame_handler
580 .max_connection_duration_secs()
581 .map(|timeout_secs| tokio::time::sleep(Duration::from_secs(timeout_secs))),
582 );
583 tokio::pin!(connection_close_timeout);
584
585 let content_type = frame_handler.content_type();
586 let mut event_sink = out.clone();
587 let (sock_sink, sock_stream) = Framed::new(
588 socket,
589 length_delimited::Builder::new()
590 .max_frame_length(frame_handler.max_frame_length())
591 .new_codec(),
592 )
593 .split();
594 let mut reader = FrameStreamReader::new(Box::new(sock_sink), content_type);
595 let mut frames = sock_stream
596 .map_err(move |error| {
597 emit!(TcpSocketError {
598 error: &error,
599 peer_addr,
600 });
601 })
602 .filter_map(move |frame| {
603 future::ready(match frame {
604 Ok(f) => reader.handle_frame(Bytes::from(f)),
605 Err(_) => None,
606 })
607 });
608
609 let active_parsing_task_nums = Arc::new(AtomicUsize::new(0));
610 loop {
611 let mut permit = tokio::select! {
612 _ = &mut tripwire => break,
613 Some(_) = &mut connection_close_timeout => {
614 break;
615 },
616 _ = &mut shutdown_signal => {
617 break;
618 },
619 permit = request_limiter.acquire() => {
620 Some(permit)
621 }
622 else => break,
623 };
624
625 let timeout = tokio::time::sleep(Duration::from_millis(PERMIT_HOLD_TIMEOUT_MS));
626 tokio::pin!(timeout);
627
628 tokio::select! {
629 _ = &mut tripwire => break,
630 _ = &mut shutdown_signal => break,
631 _ = &mut timeout => {
632 continue;
635 }
636 res = frames.next() => {
637 match res {
638 Some(frame) => {
639 if let Some(permit) = &mut permit {
640 permit.decoding_finished(1);
644 };
645 handle_tcp_frame(&mut frame_handler, frame, &mut event_sink, received_from.clone(), Arc::clone(&active_parsing_task_nums)).await;
646 }
647 None => {
648 debug!("Connection closed.");
649 break
650 },
651 }
652 }
653 else => break,
654 }
655
656 drop(permit);
657 }
658}
659
660async fn handle_tcp_frame<T>(
661 frame_handler: &mut T,
662 frame: Bytes,
663 event_sink: &mut SourceSender,
664 received_from: Option<Bytes>,
665 active_parsing_task_nums: Arc<AtomicUsize>,
666) where
667 T: TcpFrameHandler + Send + Sync + Clone + 'static,
668{
669 if frame_handler.multithreaded() {
670 spawn_event_handling_tasks(
671 frame,
672 frame_handler.clone(),
673 event_sink.clone(),
674 received_from,
675 active_parsing_task_nums,
676 frame_handler.max_frame_handling_tasks(),
677 )
678 .await;
679 } else if let Some(event) = frame_handler.handle_event(received_from, frame) {
680 if let Err(e) = event_sink.send_event(event).await {
681 error!(
682 internal_log_rate_limit = true,
683 "Error sending event: {e:?}."
684 );
685 }
686 }
687}
688
689pub fn build_framestream_unix_source(
695 frame_handler: impl UnixFrameHandler + Send + Sync + Clone + 'static,
696 shutdown: ShutdownSignal,
697 out: SourceSender,
698) -> crate::Result<Source> {
699 let path = frame_handler.socket_path();
700
701 match fs::metadata(&path) {
703 Ok(_) => {
704 info!(message = "Deleting file.", ?path);
706 fs::remove_file(&path)?;
707 }
708 Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {} Err(e) => {
710 error!("Unable to get socket information; error = {:?}.", e);
711 return Err(Box::new(e));
712 }
713 };
714
715 let listener = UnixListener::bind(&path)?;
716
717 if let Some(socket_receive_buffer_size) = frame_handler.socket_receive_buffer_size() {
719 _ = nix::sys::socket::setsockopt(
720 listener.as_raw_fd(),
721 nix::sys::socket::sockopt::RcvBuf,
722 &(socket_receive_buffer_size),
723 );
724 let rcv_buf_size =
725 nix::sys::socket::getsockopt(listener.as_raw_fd(), nix::sys::socket::sockopt::RcvBuf);
726 info!(
727 "Unix socket receive buffer size modified to {}.",
728 rcv_buf_size.unwrap()
729 );
730 }
731
732 if let Some(socket_send_buffer_size) = frame_handler.socket_send_buffer_size() {
734 _ = nix::sys::socket::setsockopt(
735 listener.as_raw_fd(),
736 nix::sys::socket::sockopt::SndBuf,
737 &(socket_send_buffer_size),
738 );
739 let snd_buf_size =
740 nix::sys::socket::getsockopt(listener.as_raw_fd(), nix::sys::socket::sockopt::SndBuf);
741 info!(
742 "Unix socket buffer send size modified to {}.",
743 snd_buf_size.unwrap()
744 );
745 }
746
747 if let Some(socket_permission) = frame_handler.socket_file_mode() {
749 if !(448..=511).contains(&socket_permission) {
750 return Err(format!(
751 "Invalid Socket permission {socket_permission:#o}. Must between 0o700 and 0o777."
752 )
753 .into());
754 }
755 match fs::set_permissions(&path, fs::Permissions::from_mode(socket_permission)) {
756 Ok(_) => {
757 info!("Socket permissions updated to {:#o}.", socket_permission);
758 }
759 Err(e) => {
760 error!(
761 "Failed to update listener socket permissions; error = {:?}.",
762 e
763 );
764 return Err(Box::new(e));
765 }
766 };
767 };
768
769 let fut = async move {
770 let active_parsing_task_nums = Arc::new(AtomicUsize::new(0));
771
772 info!(message = "Listening...", ?path, r#type = "unix");
773
774 let mut stream = UnixListenerStream::new(listener).take_until(shutdown.clone());
775 while let Some(socket) = stream.next().await {
776 let socket = match socket {
777 Err(e) => {
778 error!("Failed to accept socket; error = {:?}.", e);
779 continue;
780 }
781 Ok(s) => s,
782 };
783 let peer_addr = socket.peer_addr().ok();
784 let listen_path = path.clone();
785 let active_task_nums_ = Arc::clone(&active_parsing_task_nums);
786
787 let span = info_span!("connection");
788 let path = if let Some(addr) = peer_addr {
789 if let Some(path) = addr.as_pathname().map(|e| e.to_owned()) {
790 span.record("peer_path", field::debug(&path));
791 Some(path)
792 } else {
793 None
794 }
795 } else {
796 None
797 };
798 let received_from: Option<Bytes> =
799 path.map(|p| p.to_string_lossy().into_owned().into());
800
801 build_framestream_source(
802 frame_handler.clone(),
803 socket,
804 received_from,
805 out.clone(),
806 shutdown.clone(),
807 span,
808 active_task_nums_,
809 move |error| {
810 emit!(UnixSocketError {
811 error: &error,
812 path: &listen_path,
813 });
814 },
815 );
816 }
817
818 drop(stream);
820
821 if let Err(error) = fs::remove_file(&path) {
823 emit!(UnixSocketFileDeleteError { path: &path, error });
824 }
825
826 Ok(())
827 };
828
829 Ok(Box::pin(fut))
830}
831
832#[allow(clippy::too_many_arguments)]
833fn build_framestream_source<T: Send + 'static>(
834 frame_handler: impl FrameHandler + Send + Sync + Clone + 'static,
835 socket: impl AsyncRead + AsyncWrite + Send + 'static,
836 received_from: Option<Bytes>,
837 out: SourceSender,
838 shutdown: impl Future<Output = T> + Unpin + Send + 'static,
839 span: Span,
840 active_task_nums: Arc<AtomicUsize>,
841 error_mapper: impl FnMut(std::io::Error) + Send + 'static,
842) {
843 let content_type = frame_handler.content_type();
844 let mut event_sink = out.clone();
845 let (sock_sink, sock_stream) = Framed::new(
846 socket,
847 length_delimited::Builder::new()
848 .max_frame_length(frame_handler.max_frame_length())
849 .new_codec(),
850 )
851 .split();
852 let mut fs_reader = FrameStreamReader::new(Box::new(sock_sink), content_type);
853 let frame_handler_copy = frame_handler.clone();
854 let frames = sock_stream
855 .take_until(shutdown)
856 .map_err(error_mapper)
857 .filter_map(move |frame| {
858 future::ready(match frame {
859 Ok(f) => fs_reader.handle_frame(Bytes::from(f)),
860 Err(_) => None,
861 })
862 });
863 if !frame_handler.multithreaded() {
864 let mut events = frames.filter_map(move |f| {
865 future::ready(frame_handler_copy.handle_event(received_from.clone(), f))
866 });
867
868 let handler = async move {
869 if let Err(e) = event_sink.send_event_stream(&mut events).await {
870 error!(
871 internal_log_rate_limit = true,
872 "Error sending event: {:?}.", e
873 );
874 }
875
876 info!("Finished sending.");
877 };
878 tokio::spawn(handler.instrument(span.or_current()));
879 } else {
880 let handler = async move {
881 frames
882 .for_each(move |f| {
883 let max_frame_handling_tasks = frame_handler_copy.max_frame_handling_tasks();
884 let f_handler = frame_handler_copy.clone();
885 let received_from_copy = received_from.clone();
886 let event_sink_copy = event_sink.clone();
887 let active_task_nums_copy = Arc::clone(&active_task_nums);
888
889 async move {
890 spawn_event_handling_tasks(
891 f,
892 f_handler,
893 event_sink_copy,
894 received_from_copy,
895 active_task_nums_copy,
896 max_frame_handling_tasks,
897 )
898 .await;
899 }
900 })
901 .await;
902 info!("Finished sending.");
903 };
904 tokio::spawn(handler.instrument(span.or_current()));
905 }
906}
907
908async fn spawn_event_handling_tasks(
909 event_data: Bytes,
910 event_handler: impl FrameHandler + Send + Sync + 'static,
911 mut event_sink: SourceSender,
912 received_from: Option<Bytes>,
913 active_task_nums: Arc<AtomicUsize>,
914 max_frame_handling_tasks: usize,
915) -> JoinHandle<()> {
916 wait_for_task_quota(&active_task_nums, max_frame_handling_tasks).await;
917
918 tokio::spawn(async move {
919 future::ready({
920 if let Some(evt) = event_handler.handle_event(received_from, event_data) {
921 if event_sink.send_event(evt).await.is_err() {
922 error!("Encountered error while sending event.");
923 }
924 }
925 active_task_nums.fetch_sub(1, Ordering::AcqRel);
926 })
927 .await;
928 })
929}
930
931async fn wait_for_task_quota(active_task_nums: &Arc<AtomicUsize>, max_tasks: usize) {
932 while max_tasks > 0 && max_tasks < active_task_nums.load(Ordering::Acquire) {
933 tokio::time::sleep(Duration::from_millis(3)).await;
934 }
935 active_task_nums.fetch_add(1, Ordering::AcqRel);
936}
937
938#[cfg(test)]
939mod test {
940 use futures_util::Stream;
941 use std::net::SocketAddr;
942 #[cfg(unix)]
943 use std::{
944 path::PathBuf,
945 sync::{
946 atomic::{AtomicUsize, Ordering},
947 Arc,
948 },
949 thread,
950 };
951 use tokio::net::TcpStream;
952
953 use bytes::{buf::Buf, Bytes, BytesMut};
954 use futures::{
955 future,
956 sink::{Sink, SinkExt},
957 stream::{self, StreamExt},
958 };
959 use ipnet::IpNet;
960 use tokio::{
961 self,
962 net::UnixStream,
963 task::JoinHandle,
964 time::{Duration, Instant},
965 };
966 use tokio_util::codec::{length_delimited, Framed};
967 use vector_lib::{
968 config::{LegacyKey, LogNamespace},
969 tcp::TcpKeepaliveConfig,
970 tls::{CertificateMetadata, MaybeTls},
971 };
972 use vector_lib::{
973 lookup::{owned_value_path, path, OwnedValuePath},
974 tls::MaybeTlsSettings,
975 };
976
977 use super::{
978 build_framestream_tcp_source, build_framestream_unix_source, spawn_event_handling_tasks,
979 ControlField, ControlHeader, FrameHandler, TcpFrameHandler, UnixFrameHandler,
980 };
981 use crate::{
982 config::{log_schema, ComponentKey},
983 event::{Event, LogEvent},
984 shutdown::SourceShutdownCoordinator,
985 sources::util::net::SocketListenAddr,
986 test_util::{collect_n, collect_n_stream, next_addr},
987 SourceSender,
988 };
989
990 #[derive(Clone)]
991 struct MockFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
992 content_type: String,
993 max_frame_length: usize,
994 multithreaded: bool,
995 max_frame_handling_tasks: usize,
996 extra_task_handling_routine: F,
997 host_key: Option<OwnedValuePath>,
998 timestamp_key: Option<OwnedValuePath>,
999 source_type_key: Option<OwnedValuePath>,
1000 log_namespace: LogNamespace,
1001 }
1002
1003 #[derive(Clone)]
1004 struct MockUnixFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
1005 frame_handler: MockFrameHandler<F>,
1006 socket_path: PathBuf,
1007 socket_file_mode: Option<u32>,
1008 socket_receive_buffer_size: Option<usize>,
1009 socket_send_buffer_size: Option<usize>,
1010 }
1011
1012 #[derive(Clone)]
1013 struct MockTcpFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
1014 frame_handler: MockFrameHandler<F>,
1015 address: SocketListenAddr,
1016 keepalive: Option<TcpKeepaliveConfig>,
1017 shutdown_timeout_secs: Duration,
1018 tls: MaybeTlsSettings,
1019 tls_client_metadata_key: Option<OwnedValuePath>,
1020 receive_buffer_bytes: Option<usize>,
1021 max_connection_duration_secs: Option<u64>,
1022 max_connections: Option<u32>,
1023 permit_origin: Option<Vec<IpNet>>,
1024 }
1025
1026 impl<F: Send + Sync + Clone + FnOnce() + 'static> MockTcpFrameHandler<F> {
1027 pub fn new(
1028 addr: SocketAddr,
1029 content_type: String,
1030 multithreaded: bool,
1031 extra_routine: F,
1032 permit_origin: Option<Vec<IpNet>>,
1033 ) -> Self {
1034 Self {
1035 frame_handler: MockFrameHandler::new(content_type, multithreaded, extra_routine),
1036 address: addr.into(),
1037 keepalive: None,
1038 shutdown_timeout_secs: Duration::from_secs(30),
1039 tls: MaybeTls::Raw(()),
1040 tls_client_metadata_key: None,
1041 receive_buffer_bytes: None,
1042 max_connection_duration_secs: None,
1043 max_connections: None,
1044 permit_origin,
1045 }
1046 }
1047 }
1048
1049 impl<F: Send + Sync + Clone + FnOnce() + 'static> MockUnixFrameHandler<F> {
1050 pub fn new(content_type: String, multithreaded: bool, extra_routine: F) -> Self {
1051 Self {
1052 frame_handler: MockFrameHandler::new(content_type, multithreaded, extra_routine),
1053 socket_path: tempfile::tempdir().unwrap().keep().join("unix_test"),
1054 socket_file_mode: None,
1055 socket_receive_buffer_size: None,
1056 socket_send_buffer_size: None,
1057 }
1058 }
1059 }
1060
1061 impl<F: Send + Sync + Clone + FnOnce() + 'static> MockFrameHandler<F> {
1062 pub fn new(content_type: String, multithreaded: bool, extra_routine: F) -> Self {
1063 Self {
1064 content_type,
1065 max_frame_length: bytesize::kib(100u64) as usize,
1066 multithreaded,
1067 max_frame_handling_tasks: 0,
1068 extra_task_handling_routine: extra_routine,
1069 host_key: Some(owned_value_path!("test_framestream")),
1070 timestamp_key: Some(owned_value_path!("my_timestamp")),
1071 source_type_key: Some(owned_value_path!("source_type")),
1072 log_namespace: LogNamespace::Legacy,
1073 }
1074 }
1075 }
1076
1077 impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockFrameHandler<F> {
1078 fn content_type(&self) -> String {
1079 self.content_type.clone()
1080 }
1081 fn max_frame_length(&self) -> usize {
1082 self.max_frame_length
1083 }
1084
1085 fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1086 let mut log_event = LogEvent::from(frame);
1087
1088 log_event.insert(
1089 log_schema().source_type_key_target_path().unwrap(),
1090 "framestream",
1091 );
1092 if let Some(host) = received_from {
1093 self.log_namespace.insert_source_metadata(
1094 "framestream",
1095 &mut log_event,
1096 self.host_key.as_ref().map(LegacyKey::Overwrite),
1097 path!("host"),
1098 host,
1099 )
1100 }
1101
1102 (self.extra_task_handling_routine.clone())();
1103
1104 Some(log_event.into())
1105 }
1106
1107 fn multithreaded(&self) -> bool {
1108 self.multithreaded
1109 }
1110 fn max_frame_handling_tasks(&self) -> usize {
1111 self.max_frame_handling_tasks
1112 }
1113
1114 fn host_key(&self) -> &Option<OwnedValuePath> {
1115 &self.host_key
1116 }
1117
1118 fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1119 self.timestamp_key.as_ref()
1120 }
1121
1122 fn source_type_key(&self) -> Option<&OwnedValuePath> {
1123 self.source_type_key.as_ref()
1124 }
1125 }
1126
1127 impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockUnixFrameHandler<F> {
1128 fn content_type(&self) -> String {
1129 self.frame_handler.content_type()
1130 }
1131
1132 fn max_frame_length(&self) -> usize {
1133 self.frame_handler.max_frame_length()
1134 }
1135
1136 fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1137 self.frame_handler.handle_event(received_from, frame)
1138 }
1139
1140 fn multithreaded(&self) -> bool {
1141 self.frame_handler.multithreaded()
1142 }
1143
1144 fn max_frame_handling_tasks(&self) -> usize {
1145 self.frame_handler.max_frame_handling_tasks()
1146 }
1147
1148 fn host_key(&self) -> &Option<OwnedValuePath> {
1149 self.frame_handler.host_key()
1150 }
1151
1152 fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1153 self.frame_handler.timestamp_key()
1154 }
1155
1156 fn source_type_key(&self) -> Option<&OwnedValuePath> {
1157 self.frame_handler.source_type_key()
1158 }
1159 }
1160
1161 impl<F: Send + Sync + Clone + FnOnce() + 'static> UnixFrameHandler for MockUnixFrameHandler<F> {
1162 fn socket_path(&self) -> PathBuf {
1163 self.socket_path.clone()
1164 }
1165
1166 fn socket_file_mode(&self) -> Option<u32> {
1167 self.socket_file_mode
1168 }
1169
1170 fn socket_receive_buffer_size(&self) -> Option<usize> {
1171 self.socket_receive_buffer_size
1172 }
1173
1174 fn socket_send_buffer_size(&self) -> Option<usize> {
1175 self.socket_send_buffer_size
1176 }
1177 }
1178
1179 impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockTcpFrameHandler<F> {
1180 fn content_type(&self) -> String {
1181 self.frame_handler.content_type()
1182 }
1183
1184 fn max_frame_length(&self) -> usize {
1185 self.frame_handler.max_frame_length()
1186 }
1187
1188 fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1189 self.frame_handler.handle_event(received_from, frame)
1190 }
1191
1192 fn multithreaded(&self) -> bool {
1193 self.frame_handler.multithreaded()
1194 }
1195
1196 fn max_frame_handling_tasks(&self) -> usize {
1197 self.frame_handler.max_frame_handling_tasks()
1198 }
1199
1200 fn host_key(&self) -> &Option<OwnedValuePath> {
1201 self.frame_handler.host_key()
1202 }
1203
1204 fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1205 self.frame_handler.timestamp_key()
1206 }
1207
1208 fn source_type_key(&self) -> Option<&OwnedValuePath> {
1209 self.frame_handler.source_type_key()
1210 }
1211 }
1212
1213 impl<F: Send + Sync + Clone + FnOnce() + 'static> TcpFrameHandler for MockTcpFrameHandler<F> {
1214 fn address(&self) -> SocketListenAddr {
1215 self.address
1216 }
1217
1218 fn keepalive(&self) -> Option<TcpKeepaliveConfig> {
1219 self.keepalive
1220 }
1221
1222 fn shutdown_timeout_secs(&self) -> Duration {
1223 self.shutdown_timeout_secs
1224 }
1225
1226 fn tls(&self) -> MaybeTlsSettings {
1227 self.tls.clone()
1228 }
1229
1230 fn tls_client_metadata_key(&self) -> Option<OwnedValuePath> {
1231 self.tls_client_metadata_key.clone()
1232 }
1233
1234 fn receive_buffer_bytes(&self) -> Option<usize> {
1235 self.receive_buffer_bytes
1236 }
1237
1238 fn max_connection_duration_secs(&self) -> Option<u64> {
1239 self.max_connection_duration_secs
1240 }
1241
1242 fn max_connections(&self) -> Option<u32> {
1243 self.max_connections
1244 }
1245
1246 fn insert_tls_client_metadata(&mut self, _: Option<CertificateMetadata>) {}
1247
1248 fn allowed_origins(&self) -> Option<&[IpNet]> {
1249 self.permit_origin.as_deref()
1250 }
1251 }
1252
1253 fn init_framestream_tcp(
1254 source_id: &str,
1255 addr: &SocketAddr,
1256 frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
1257 pipeline: SourceSender,
1258 ) -> (JoinHandle<Result<(), ()>>, SourceShutdownCoordinator) {
1259 let source_id = ComponentKey::from(source_id);
1260 let mut shutdown = SourceShutdownCoordinator::default();
1261 let (shutdown_signal, _) = shutdown.register_source(&source_id, false);
1262 let server = build_framestream_tcp_source(frame_handler, shutdown_signal, pipeline)
1263 .expect("Failed to build framestream tcp source.");
1264
1265 let join_handle = tokio::spawn(server);
1266
1267 while std::net::TcpStream::connect(addr).is_err() {
1268 thread::sleep(Duration::from_millis(2));
1269 }
1270
1271 (join_handle, shutdown)
1272 }
1273
1274 fn init_framestream_unix(
1275 source_id: &str,
1276 frame_handler: impl UnixFrameHandler + Send + Sync + Clone + 'static,
1277 pipeline: SourceSender,
1278 ) -> (
1279 PathBuf,
1280 JoinHandle<Result<(), ()>>,
1281 SourceShutdownCoordinator,
1282 ) {
1283 let source_id = ComponentKey::from(source_id);
1284 let socket_path = frame_handler.socket_path();
1285 let mut shutdown = SourceShutdownCoordinator::default();
1286 let (shutdown_signal, _) = shutdown.register_source(&source_id, false);
1287 let server = build_framestream_unix_source(frame_handler, shutdown_signal, pipeline)
1288 .expect("Failed to build framestream unix source.");
1289
1290 let join_handle = tokio::spawn(server);
1291
1292 while std::os::unix::net::UnixStream::connect(&socket_path).is_err() {
1294 thread::sleep(Duration::from_millis(2));
1295 }
1296
1297 (socket_path, join_handle, shutdown)
1298 }
1299
1300 async fn make_tcp_stream(
1301 addr: SocketAddr,
1302 ) -> Framed<TcpStream, length_delimited::LengthDelimitedCodec> {
1303 let socket = TcpStream::connect(&addr).await.unwrap();
1304 Framed::new(socket, length_delimited::Builder::new().new_codec())
1305 }
1306
1307 async fn make_unix_stream(
1308 path: PathBuf,
1309 ) -> Framed<UnixStream, length_delimited::LengthDelimitedCodec> {
1310 let socket = UnixStream::connect(&path).await.unwrap();
1311 Framed::new(socket, length_delimited::Builder::new().new_codec())
1312 }
1313
1314 async fn send_data_frames<S: Sink<Bytes, Error = std::io::Error> + Unpin>(
1315 sock_sink: &mut S,
1316 frames: Vec<Result<Bytes, std::io::Error>>,
1317 ) {
1318 let mut stream = stream::iter(frames.into_iter());
1319 _ = sock_sink.send_all(&mut stream).await;
1321 }
1322
1323 async fn send_control_frame<S: Sink<Bytes, Error = std::io::Error> + Unpin>(
1324 sock_sink: &mut S,
1325 frame: Bytes,
1326 ) {
1327 send_data_frames(sock_sink, vec![Ok(Bytes::new()), Ok(frame)]).await; }
1329
1330 fn create_control_frame(header: ControlHeader) -> Bytes {
1331 Bytes::from(header.to_u32().to_be_bytes().to_vec())
1332 }
1333
1334 fn create_control_frame_with_content(
1335 header: ControlHeader,
1336 content_types: Vec<Bytes>,
1337 ) -> Bytes {
1338 let mut frame = BytesMut::from(&header.to_u32().to_be_bytes()[..]);
1339 for content_type in content_types {
1340 frame.extend(ControlField::ContentType.to_u32().to_be_bytes());
1341 frame.extend((content_type.len() as u32).to_be_bytes());
1342 frame.extend(content_type.clone());
1343 }
1344 Bytes::from(frame)
1345 }
1346
1347 fn assert_accept_frame(frame: &mut BytesMut, expected_content_type: Bytes) {
1348 assert_eq!(&frame[..4], &ControlHeader::Accept.to_u32().to_be_bytes(),);
1351 frame.advance(4);
1352 assert_eq!(
1354 &frame[..4],
1355 &ControlField::ContentType.to_u32().to_be_bytes(),
1356 );
1357 frame.advance(4);
1358 assert_eq!(
1360 &frame[..4],
1361 &(expected_content_type.len() as u32).to_be_bytes(),
1362 );
1363 frame.advance(4);
1364 assert_eq!(&frame[..], &expected_content_type[..]);
1366 }
1367
1368 fn create_frame_handler(multithreaded: bool) -> impl UnixFrameHandler + Send + Sync + Clone {
1369 MockUnixFrameHandler::new("test_content".to_string(), multithreaded, move || {})
1370 }
1371
1372 fn create_tcp_frame_handler(
1373 addr: SocketAddr,
1374 multithreaded: bool,
1375 permit_origin: Option<Vec<IpNet>>,
1376 ) -> impl TcpFrameHandler + Send + Sync + Clone {
1377 MockTcpFrameHandler::new(
1378 addr,
1379 "test_content".to_string(),
1380 multithreaded,
1381 move || {},
1382 permit_origin,
1383 )
1384 }
1385
1386 async fn signal_shutdown(source_name: &str, shutdown: &mut SourceShutdownCoordinator) {
1387 let deadline = Instant::now() + Duration::from_secs(10);
1389 let id = ComponentKey::from(source_name);
1390 let shutdown_complete = shutdown.shutdown_source(&id, deadline);
1391 let shutdown_success = shutdown_complete.await;
1392 assert!(shutdown_success);
1393 }
1394
    /// Drives a complete framestream session against a running source and
    /// checks that the data frames arrive as log events.
    ///
    /// The session runs in this order:
    /// 1. Send READY offering `test_content`.
    /// 2. Read two frames back: a zero-length frame, then an ACCEPT for
    ///    `test_content`.
    /// 3. Send START, then the data frames "hello" and "world".
    /// 4. Wait for two events, then send STOP.
    ///
    /// Both payloads must appear (in any order) in the events' message field.
    /// Finally the response stream is dropped, the source is shut down via
    /// `signal_shutdown`, and its task is awaited.
    async fn test_normal_framestream<
        T: Sink<Bytes, Error = std::io::Error> + Unpin,
        U: Stream<Item = Result<BytesMut, std::io::Error>> + Unpin,
        V: Stream<Item = Event> + Unpin,
    >(
        source_name: &str,
        mut sock_sink: T,
        mut sock_stream: U,
        rx: V,
        mut shutdown: SourceShutdownCoordinator,
        source_handle: JoinHandle<Result<(), ()>>,
    ) {
        let content_type = Bytes::from(&b"test_content"[..]);
        // Offer our content type to the source.
        let ready_msg =
            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
        send_control_frame(&mut sock_sink, ready_msg).await;

        // The reply is read as two frames: a zero-length frame, then the ACCEPT payload.
        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);

        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Start)).await;

        send_data_frames(
            &mut sock_sink,
            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
        )
        .await;
        // Wait for both events before sending STOP.
        let events = collect_n(rx, 2).await;

        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;

        // Only presence is asserted (`any`), not order.
        let message_key = log_schema().message_key().unwrap().to_string();
        assert!(events
            .iter()
            .any(|e| e.as_log()[&message_key] == "hello".into()));
        assert!(events
            .iter()
            .any(|e| e.as_log()[&message_key] == "world".into()));

        drop(sock_stream);
        signal_shutdown(source_name, &mut shutdown).await;
        _ = source_handle.await.unwrap();
    }
1447
    /// Checks content-type negotiation when the client offers several types.
    ///
    /// A READY listing `test_content2` and `test_content` must be answered with
    /// a zero-length frame followed by an ACCEPT naming only `test_content`.
    /// Afterwards the response stream is dropped and the source is shut down.
    async fn test_multiple_content_types<
        T: Sink<Bytes, Error = std::io::Error> + Unpin,
        U: Stream<Item = Result<BytesMut, std::io::Error>> + Unpin,
    >(
        source_name: &str,
        mut sock_sink: T,
        mut sock_stream: U,
        mut shutdown: SourceShutdownCoordinator,
        source_handle: JoinHandle<Result<(), ()>>,
    ) {
        let content_type = Bytes::from(&b"test_content"[..]);
        // Offer an unexpected type first, then the one the test handlers are built with.
        let ready_msg = create_control_frame_with_content(
            ControlHeader::Ready,
            vec![Bytes::from(&b"test_content2"[..]), content_type.clone()],
        );
        send_control_frame(&mut sock_sink, ready_msg).await;

        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;

        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);

        drop(sock_stream);
        signal_shutdown(source_name, &mut shutdown).await;
        _ = source_handle.await.unwrap();
    }
1479
1480 #[tokio::test(flavor = "multi_thread")]
1481 #[should_panic]
1482 async fn blocked_framestream_tcp() {
1483 let source_name = "test_source";
1484 let (tx, rx) = SourceSender::new_test();
1485 let addr = next_addr();
1486 let (source_handle, shutdown) = init_framestream_tcp(
1487 source_name,
1488 &addr,
1489 create_tcp_frame_handler(addr, false, Some(vec![])),
1490 tx,
1491 );
1492 let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1493
1494 test_normal_framestream(
1495 source_name,
1496 sock_sink,
1497 sock_stream,
1498 rx,
1499 shutdown,
1500 source_handle,
1501 )
1502 .await;
1503 }
1504
1505 #[tokio::test(flavor = "multi_thread")]
1506 async fn normal_framestream_singlethreaded_tcp() {
1507 let source_name = "test_source";
1508 let (tx, rx) = SourceSender::new_test();
1509 let addr = next_addr();
1510 let (source_handle, shutdown) = init_framestream_tcp(
1511 source_name,
1512 &addr,
1513 create_tcp_frame_handler(addr, false, None),
1514 tx,
1515 );
1516 let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1517
1518 test_normal_framestream(
1519 source_name,
1520 sock_sink,
1521 sock_stream,
1522 rx,
1523 shutdown,
1524 source_handle,
1525 )
1526 .await;
1527 }
1528
1529 #[tokio::test(flavor = "multi_thread")]
1530 async fn normal_framestream_singlethreaded_unix() {
1531 let source_name = "test_source";
1532 let (tx, rx) = SourceSender::new_test();
1533 let (path, source_handle, shutdown) =
1534 init_framestream_unix(source_name, create_frame_handler(false), tx);
1535 let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1536
1537 test_normal_framestream(
1538 source_name,
1539 sock_sink,
1540 sock_stream,
1541 rx,
1542 shutdown,
1543 source_handle,
1544 )
1545 .await;
1546 }
1547
1548 #[tokio::test(flavor = "multi_thread")]
1549 async fn normal_framestream_multithreaded_tcp() {
1550 let source_name = "test_source";
1551 let (tx, rx) = SourceSender::new_test();
1552 let addr = next_addr();
1553 let (source_handle, shutdown) = init_framestream_tcp(
1554 source_name,
1555 &addr,
1556 create_tcp_frame_handler(addr, true, None),
1557 tx,
1558 );
1559 let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1560
1561 test_normal_framestream(
1562 source_name,
1563 sock_sink,
1564 sock_stream,
1565 rx,
1566 shutdown,
1567 source_handle,
1568 )
1569 .await;
1570 }
1571
1572 #[tokio::test(flavor = "multi_thread")]
1573 async fn normal_framestream_multithreaded_unix() {
1574 let source_name = "test_source";
1575 let (tx, rx) = SourceSender::new_test();
1576 let (path, source_handle, shutdown) =
1577 init_framestream_unix(source_name, create_frame_handler(true), tx);
1578 let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1579
1580 test_normal_framestream(
1581 source_name,
1582 sock_sink,
1583 sock_stream,
1584 rx,
1585 shutdown,
1586 source_handle,
1587 )
1588 .await;
1589 }
1590
1591 #[tokio::test(flavor = "multi_thread")]
1592 async fn multiple_content_types_tcp() {
1593 let source_name = "test_source";
1594 let (tx, _) = SourceSender::new_test();
1595 let addr = next_addr();
1596 let (source_handle, shutdown) = init_framestream_tcp(
1597 source_name,
1598 &addr,
1599 create_tcp_frame_handler(addr, false, None),
1600 tx,
1601 );
1602 let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1603
1604 test_multiple_content_types(source_name, sock_sink, sock_stream, shutdown, source_handle)
1605 .await;
1606 }
1607
1608 #[tokio::test(flavor = "multi_thread")]
1609 async fn multiple_content_types_unix() {
1610 let source_name = "test_source";
1611 let (tx, _) = SourceSender::new_test();
1612 let (path, source_handle, shutdown) =
1613 init_framestream_unix(source_name, create_frame_handler(false), tx);
1614 let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1615
1616 test_multiple_content_types(source_name, sock_sink, sock_stream, shutdown, source_handle)
1617 .await;
1618 }
1619
    /// A READY offering only an unexpected content type (`test_content2`) is
    /// followed by a READY offering `test_content`. The first two frames read
    /// back must be a zero-length frame and an ACCEPT for `test_content`.
    #[tokio::test(flavor = "multi_thread")]
    async fn wrong_content_type() {
        let source_name = "test_source";
        let (tx, _) = SourceSender::new_test();
        let (path, source_handle, mut shutdown) =
            init_framestream_unix(source_name, create_frame_handler(false), tx);
        let (mut sock_sink, mut sock_stream) = make_unix_stream(path).await.split();

        // First READY: only a content type the handler was not built with.
        let ready_msg = create_control_frame_with_content(
            ControlHeader::Ready,
            vec![Bytes::from(&b"test_content2"[..])],
        );
        send_control_frame(&mut sock_sink, ready_msg).await;

        // Second READY: the expected content type.
        let content_type = Bytes::from(&b"test_content"[..]);
        let ready_msg =
            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
        send_control_frame(&mut sock_sink, ready_msg).await;

        // Presumably any reply to the first READY would arrive first and break
        // these assertions.
        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;

        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);

        drop(sock_stream);
        signal_shutdown(source_name, &mut shutdown).await;
        _ = source_handle.await.unwrap();
    }
1654
    /// Sends data frames ("bad", "data") before any handshake, then performs the
    /// normal READY/ACCEPT/START exchange and sends "hello" and "world".
    /// The first two events received must be "hello" then "world", so the early
    /// frames were not emitted ahead of them.
    #[tokio::test(flavor = "multi_thread")]
    async fn data_too_soon() {
        let source_name = "test_source";
        let (tx, rx) = SourceSender::new_test();
        let (path, source_handle, mut shutdown) =
            init_framestream_unix(source_name, create_frame_handler(false), tx);
        let (mut sock_sink, mut sock_stream) = make_unix_stream(path).await.split();

        // Data sent before READY/START.
        send_data_frames(
            &mut sock_sink,
            vec![Ok(Bytes::from("bad")), Ok(Bytes::from("data"))],
        )
        .await;

        let content_type = Bytes::from(&b"test_content"[..]);
        let ready_msg =
            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
        send_control_frame(&mut sock_sink, ready_msg).await;

        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;

        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);

        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Start)).await;

        send_data_frames(
            &mut sock_sink,
            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
        )
        .await;
        let events = collect_n(rx, 2).await;

        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;

        // Unlike test_normal_framestream, order is asserted here.
        assert_eq!(
            events[0].as_log()[log_schema().message_key().unwrap().to_string()],
            "hello".into(),
        );
        assert_eq!(
            events[1].as_log()[log_schema().message_key().unwrap().to_string()],
            "world".into(),
        );

        drop(sock_stream);
        signal_shutdown(source_name, &mut shutdown).await;
        _ = source_handle.await.unwrap();
    }
1712
    /// Starts a session with START carrying the content type directly (no
    /// READY/ACCEPT exchange), sends "hello" and "world", and checks they arrive
    /// as the first two events, in order.
    #[tokio::test(flavor = "multi_thread")]
    async fn unidirectional_framestream() {
        let source_name = "test_source";
        let (tx, rx) = SourceSender::new_test();
        let (path, source_handle, mut shutdown) =
            init_framestream_unix(source_name, create_frame_handler(false), tx);
        // The read half is discarded immediately; this test never reads replies.
        let (mut sock_sink, _) = make_unix_stream(path).await.split();

        let content_type = Bytes::from(&b"test_content"[..]);
        let start_msg = create_control_frame_with_content(ControlHeader::Start, vec![content_type]);
        send_control_frame(&mut sock_sink, start_msg).await;

        send_data_frames(
            &mut sock_sink,
            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
        )
        .await;
        let events = collect_n(rx, 2).await;

        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;

        assert_eq!(
            events[0].as_log()[log_schema().message_key().unwrap().to_string()],
            "hello".into(),
        );
        assert_eq!(
            events[1].as_log()[log_schema().message_key().unwrap().to_string()],
            "world".into(),
        );

        signal_shutdown(source_name, &mut shutdown).await;
        _ = source_handle.await.unwrap();
    }
1750
1751 #[tokio::test(flavor = "multi_thread")]
1752 async fn test_spawn_event_handling_tasks() {
1753 let (out, rx) = SourceSender::new_test();
1754
1755 let max_frame_handling_tasks = 20;
1756 let active_task_nums = Arc::new(AtomicUsize::new(0));
1757 let active_task_nums_copy = Arc::clone(&active_task_nums);
1758 let max_task_nums_reached = Arc::new(AtomicUsize::new(0));
1759 let max_task_nums_reached_copy = Arc::clone(&max_task_nums_reached);
1760
1761 let mut join_handles = vec![];
1762 let active_task_nums_copy_2 = Arc::clone(&active_task_nums_copy);
1763 let extra_routine = move || {
1764 thread::sleep(Duration::from_millis(10));
1765 max_task_nums_reached_copy.fetch_max(
1766 active_task_nums_copy_2.load(Ordering::Acquire),
1767 Ordering::AcqRel,
1768 );
1769 };
1770
1771 let total_events = max_frame_handling_tasks * 10;
1772
1773 join_handles.push(tokio::spawn(async move {
1774 future::ready({
1775 let events = collect_n(rx, total_events).await;
1776 assert_eq!(total_events, events.len(), "Missed events");
1777 })
1778 .await;
1779 }));
1780
1781 for i in 0..total_events {
1782 join_handles.push(
1783 spawn_event_handling_tasks(
1784 Bytes::from(format!("event_{i}")),
1785 MockFrameHandler::new("test_content".to_string(), true, extra_routine.clone()),
1786 out.clone(),
1787 None,
1788 Arc::clone(&active_task_nums_copy),
1789 max_frame_handling_tasks,
1790 )
1791 .await,
1792 );
1793 }
1794
1795 future::join_all(join_handles).await;
1796
1797 let final_task_nums = active_task_nums.load(Ordering::Acquire);
1798 assert_eq!(
1799 0, final_task_nums,
1800 "There should be NO left-over tasks at the end"
1801 );
1802
1803 let max_task_nums_reached_value = max_task_nums_reached.load(Ordering::Acquire);
1804 assert!(
1805 max_task_nums_reached_value > 1,
1806 "MultiThreaded mode does NOT work"
1807 );
1808 assert!((max_task_nums_reached_value - max_frame_handling_tasks) < 2, "Max number of tasks at any given time should NOT Exceed max_frame_handling_tasks too much");
1809 }
1810}