vector/sources/util/
framestream.rs

1use ipnet::IpNet;
2#[cfg(unix)]
3use std::os::unix::{fs::PermissionsExt, io::AsRawFd};
4use std::{
5    convert::TryInto,
6    fs,
7    marker::{Send, Sync},
8    net::SocketAddr,
9    path::PathBuf,
10    sync::{
11        atomic::{AtomicUsize, Ordering},
12        Arc, Mutex,
13    },
14    time::Duration,
15};
16
17use bytes::{Buf, Bytes, BytesMut};
18use futures::{
19    executor::block_on,
20    future::{self, OptionFuture},
21    sink::{Sink, SinkExt},
22    stream::{self, StreamExt, TryStreamExt},
23};
24use futures_util::{future::BoxFuture, Future, FutureExt};
25use listenfd::ListenFd;
26use tokio::{
27    self,
28    io::{AsyncRead, AsyncWrite},
29    net::{TcpStream, UnixListener},
30    task::JoinHandle,
31    time::sleep,
32};
33use tokio_stream::wrappers::UnixListenerStream;
34use tokio_util::codec::{length_delimited, Framed};
35use tracing::{field, Instrument, Span};
36use vector_lib::{
37    lookup::OwnedValuePath,
38    tcp::TcpKeepaliveConfig,
39    tls::{CertificateMetadata, MaybeTlsIncomingStream, MaybeTlsSettings},
40};
41
42use crate::{
43    event::Event,
44    internal_events::{
45        ConnectionOpen, OpenGauge, SocketBindError, SocketMode, SocketReceiveError,
46        TcpBytesReceived, TcpSocketError, TcpSocketTlsConnectionError, UnixSocketError,
47        UnixSocketFileDeleteError,
48    },
49    shutdown::ShutdownSignal,
50    sources::{
51        util::{
52            net::{try_bind_tcp_listener, MAX_IN_FLIGHT_EVENTS_TARGET},
53            AfterReadExt,
54        },
55        Source,
56    },
57    SourceSender,
58};
59
60use super::net::{RequestLimiter, SocketListenAddr};
61
/// Largest control frame (in bytes) this reader is willing to parse.
const FSTRM_CONTROL_FRAME_LENGTH_MAX: usize = 512;
/// Largest CONTENT_TYPE field payload (in bytes) accepted inside a control frame.
const FSTRM_CONTROL_FIELD_CONTENT_TYPE_LENGTH_MAX: usize = 256;

/// Idle time, in milliseconds, after which a connection gives back its request-limiter permit.
///
/// A connection that holds a permit but receives no data within this window releases it (and
/// then queues for a new one), so other connections get a chance to read. The window is
/// intentionally tiny: any arriving frame wins the race long before it expires, so in practice
/// it only affects connections that are sitting on a permit without consuming anything.
const PERMIT_HOLD_TIMEOUT_MS: u64 = 10;
70
71pub type FrameStreamSink = Box<dyn Sink<Bytes, Error = std::io::Error> + Send + Unpin>;
72
73pub struct FrameStreamReader {
74    response_sink: Mutex<FrameStreamSink>,
75    expected_content_type: String,
76    state: FrameStreamState,
77}
78
/// Protocol state kept by a [`FrameStreamReader`] for a single connection.
struct FrameStreamState {
    /// Set by a zero-length frame: the next frame must be parsed as a control frame.
    expect_control_frame: bool,
    /// Current position in the control handshake.
    control_state: ControlState,
    /// Whether the peer expects responses (ACCEPT / FINISH). Cleared when the first control
    /// frame turns out to be START instead of READY.
    is_bidirectional: bool,
}

impl FrameStreamState {
    /// State for a freshly accepted connection.
    ///
    /// A bidirectional peer opens with READY, a unidirectional one with START; bidirectional
    /// is assumed until proven otherwise.
    const fn new() -> Self {
        Self {
            control_state: ControlState::Initial,
            is_bidirectional: true,
            expect_control_frame: false,
        }
    }
}

/// Steps of the Frame Streams receiver handshake.
#[derive(PartialEq, Debug)]
enum ControlState {
    /// Nothing received yet; READY (bidirectional) or START (unidirectional) is expected.
    Initial,
    /// READY seen and ACCEPT sent; START is expected next.
    GotReady,
    /// START seen; data frames are passed through.
    ReadingData,
    /// STOP seen; the stream is over.
    Stopped,
}
102
103#[derive(Copy, Clone)]
104enum ControlHeader {
105    Accept,
106    Start,
107    Stop,
108    Ready,
109    Finish,
110}
111
112impl ControlHeader {
113    fn from_u32(val: u32) -> Result<Self, ()> {
114        match val {
115            0x01 => Ok(ControlHeader::Accept),
116            0x02 => Ok(ControlHeader::Start),
117            0x03 => Ok(ControlHeader::Stop),
118            0x04 => Ok(ControlHeader::Ready),
119            0x05 => Ok(ControlHeader::Finish),
120            _ => {
121                error!("Don't know header value {} (expected 0x01 - 0x05).", val);
122                Err(())
123            }
124        }
125    }
126
127    const fn to_u32(self) -> u32 {
128        match self {
129            ControlHeader::Accept => 0x01,
130            ControlHeader::Start => 0x02,
131            ControlHeader::Stop => 0x03,
132            ControlHeader::Ready => 0x04,
133            ControlHeader::Finish => 0x05,
134        }
135    }
136}
137
138enum ControlField {
139    ContentType,
140}
141
142impl ControlField {
143    fn from_u32(val: u32) -> Result<Self, ()> {
144        match val {
145            0x01 => Ok(ControlField::ContentType),
146            _ => {
147                error!("Don't know field type {} (expected 0x01).", val);
148                Err(())
149            }
150        }
151    }
152    const fn to_u32(&self) -> u32 {
153        match self {
154            ControlField::ContentType => 0x01,
155        }
156    }
157}
158
159fn advance_u32(b: &mut Bytes) -> Result<u32, ()> {
160    if b.len() < 4 {
161        error!("Malformed frame.");
162        return Err(());
163    }
164    let a = b.split_to(4);
165    Ok(u32::from_be_bytes(a[..].try_into().unwrap()))
166}
167
168impl FrameStreamReader {
169    pub fn new(response_sink: FrameStreamSink, expected_content_type: String) -> Self {
170        FrameStreamReader {
171            response_sink: Mutex::new(response_sink),
172            expected_content_type,
173            state: FrameStreamState::new(),
174        }
175    }
176
177    pub fn handle_frame(&mut self, frame: Bytes) -> Option<Bytes> {
178        if frame.is_empty() {
179            //frame length of zero means the next frame is a control frame
180            self.state.expect_control_frame = true;
181            None
182        } else if self.state.expect_control_frame {
183            self.state.expect_control_frame = false;
184            _ = self.handle_control_frame(frame);
185            None
186        } else {
187            //data frame
188            if self.state.control_state == ControlState::ReadingData {
189                Some(frame) //return data frame
190            } else {
191                error!(
192                    "Received a data frame while in state {:?}.",
193                    self.state.control_state
194                );
195                None
196            }
197        }
198    }
199
200    fn handle_control_frame(&mut self, mut frame: Bytes) -> Result<(), ()> {
201        //enforce maximum control frame size
202        if frame.len() > FSTRM_CONTROL_FRAME_LENGTH_MAX {
203            error!("Control frame is too long.");
204        }
205
206        let header = ControlHeader::from_u32(advance_u32(&mut frame)?)?;
207
208        //match current state to received header
209        match self.state.control_state {
210            ControlState::Initial => {
211                match header {
212                    ControlHeader::Ready => {
213                        let content_type = self.process_fields(header, &mut frame)?.unwrap();
214
215                        self.send_control_frame(Self::make_frame(
216                            ControlHeader::Accept,
217                            Some(content_type),
218                        ));
219                        self.state.control_state = ControlState::GotReady; //waiting for a START control frame
220                    }
221                    ControlHeader::Start => {
222                        //check for content type
223                        _ = self.process_fields(header, &mut frame)?;
224                        //if didn't error, then we are ok to change state
225                        self.state.control_state = ControlState::ReadingData;
226                        self.state.is_bidirectional = false; //if first message was START then we are unidirectional (no responses)
227                    }
228                    _ => error!("Got wrong control frame, expected READY."),
229                }
230            }
231            ControlState::GotReady => {
232                match header {
233                    ControlHeader::Start => {
234                        //check for content type
235                        _ = self.process_fields(header, &mut frame)?;
236                        //if didn't error, then we are ok to change state
237                        self.state.control_state = ControlState::ReadingData;
238                    }
239                    _ => error!("Got wrong control frame, expected START."),
240                }
241            }
242            ControlState::ReadingData => {
243                match header {
244                    ControlHeader::Stop => {
245                        //check there aren't any fields
246                        _ = self.process_fields(header, &mut frame)?;
247                        if self.state.is_bidirectional {
248                            //send FINISH frame -- but only if we are bidirectional
249                            self.send_control_frame(Self::make_frame(ControlHeader::Finish, None));
250                        }
251                        self.state.control_state = ControlState::Stopped; //stream is now done
252                    }
253                    _ => error!("Got wrong control frame, expected STOP."),
254                }
255            }
256            ControlState::Stopped => error!("Unexpected control frame, current state is STOPPED."),
257        };
258        Ok(())
259    }
260
261    fn process_fields(
262        &mut self,
263        header: ControlHeader,
264        frame: &mut Bytes,
265    ) -> Result<Option<String>, ()> {
266        match header {
267            ControlHeader::Ready => {
268                //should provide 1+ content types
269                //should match expected content type
270                let is_start_frame = false;
271                let content_type = self.process_content_type(frame, is_start_frame)?;
272                Ok(Some(content_type))
273            }
274            ControlHeader::Start => {
275                //can take one or zero content types
276                if frame.is_empty() {
277                    Ok(None)
278                } else {
279                    //should match expected content type
280                    let is_start_frame = true;
281                    let content_type = self.process_content_type(frame, is_start_frame)?;
282                    Ok(Some(content_type))
283                }
284            }
285            ControlHeader::Stop => {
286                //check that there are no fields
287                if !frame.is_empty() {
288                    error!("Unexpected fields in STOP header.");
289                    Err(())
290                } else {
291                    Ok(None)
292                }
293            }
294            _ => {
295                error!("Unexpected control header value {:?}.", header.to_u32());
296                Err(())
297            }
298        }
299    }
300
301    fn process_content_type(&self, frame: &mut Bytes, is_start_frame: bool) -> Result<String, ()> {
302        if frame.is_empty() {
303            error!("No fields in control frame.");
304            return Err(());
305        }
306
307        let mut content_types = vec![];
308        while !frame.is_empty() {
309            //4 bytes of ControlField
310            let field_val = advance_u32(frame)?;
311            let field_type = ControlField::from_u32(field_val)?;
312            match field_type {
313                ControlField::ContentType => {
314                    //4 bytes giving length of content type
315                    let field_len = advance_u32(frame)? as usize;
316
317                    //enforce limit on content type string
318                    if field_len > FSTRM_CONTROL_FIELD_CONTENT_TYPE_LENGTH_MAX {
319                        error!("Content-Type string is too long.");
320                        return Err(());
321                    }
322
323                    let content_type = std::str::from_utf8(&frame[..field_len]).unwrap();
324                    content_types.push(content_type.to_string());
325                    frame.advance(field_len);
326                }
327            }
328        }
329
330        if is_start_frame && content_types.len() > 1 {
331            error!(
332                "START control frame can only have one content-type provided (got {}).",
333                content_types.len()
334            );
335            return Err(());
336        }
337
338        for content_type in &content_types {
339            if *content_type == self.expected_content_type {
340                return Ok(content_type.clone());
341            }
342        }
343
344        error!(
345            "Content types did not match up. Expected {} got {:?}.",
346            self.expected_content_type, content_types
347        );
348        Err(())
349    }
350
351    fn make_frame(header: ControlHeader, content_type: Option<String>) -> Bytes {
352        let mut frame = BytesMut::new();
353        frame.extend(header.to_u32().to_be_bytes());
354        if let Some(s) = content_type {
355            frame.extend(ControlField::ContentType.to_u32().to_be_bytes()); //field type: ContentType
356            frame.extend((s.len() as u32).to_be_bytes()); //length of type
357            frame.extend(s.as_bytes());
358        }
359        Bytes::from(frame)
360    }
361
362    fn send_control_frame(&mut self, frame: Bytes) {
363        let empty_frame = Bytes::from(&b""[..]); //send empty frame to say we are control frame
364        let mut stream = stream::iter(vec![Ok(empty_frame), Ok(frame)]);
365
366        if let Err(e) = block_on(self.response_sink.lock().unwrap().send_all(&mut stream)) {
367            error!("Encountered error '{:#?}' while sending control frame.", e);
368        }
369    }
370}
371
372pub trait FrameHandler {
373    fn content_type(&self) -> String;
374    fn max_frame_length(&self) -> usize;
375    fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event>;
376    fn multithreaded(&self) -> bool;
377    fn max_frame_handling_tasks(&self) -> usize;
378    fn host_key(&self) -> &Option<OwnedValuePath>;
379    fn timestamp_key(&self) -> Option<&OwnedValuePath>;
380    fn source_type_key(&self) -> Option<&OwnedValuePath>;
381}
382
383pub trait UnixFrameHandler: FrameHandler {
384    fn socket_path(&self) -> PathBuf;
385    fn socket_file_mode(&self) -> Option<u32>;
386    fn socket_receive_buffer_size(&self) -> Option<usize>;
387    fn socket_send_buffer_size(&self) -> Option<usize>;
388}
389
390pub trait TcpFrameHandler: FrameHandler {
391    fn address(&self) -> SocketListenAddr;
392    fn keepalive(&self) -> Option<TcpKeepaliveConfig>;
393    fn shutdown_timeout_secs(&self) -> Duration;
394    fn tls(&self) -> MaybeTlsSettings;
395    fn tls_client_metadata_key(&self) -> Option<OwnedValuePath>;
396    fn receive_buffer_bytes(&self) -> Option<usize>;
397    fn max_connection_duration_secs(&self) -> Option<u64>;
398    fn max_connections(&self) -> Option<u32>;
399    fn allowed_origins(&self) -> Option<&[IpNet]>;
400    fn insert_tls_client_metadata(&mut self, metadata: Option<CertificateMetadata>);
401}
402
403/**
404 * Based off of the build_framestream_unix_source function.
405 * Functions similarly, just uses TCP socket instead of unix socket
406 **/
407pub fn build_framestream_tcp_source(
408    frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
409    shutdown: ShutdownSignal,
410    out: SourceSender,
411) -> crate::Result<Source> {
412    let addr = frame_handler.address();
413    let tls = frame_handler.tls();
414    let shutdown = shutdown.clone();
415    let out = out.clone();
416
417    Ok(Box::pin(async move {
418        let listenfd = ListenFd::from_env();
419        let listener = try_bind_tcp_listener(
420            addr,
421            listenfd,
422            &tls,
423            frame_handler
424                .allowed_origins()
425                .map(|origins| origins.to_vec()),
426        )
427        .await
428        .map_err(|error| {
429            emit!(SocketBindError {
430                mode: SocketMode::Tcp,
431                error: &error,
432            })
433        })?;
434
435        info!(
436            message = "Listening.",
437            addr = %listener
438                .local_addr()
439                .map(SocketListenAddr::SocketAddr)
440                .unwrap_or(addr)
441        );
442
443        let tripwire = shutdown.clone();
444        let shutdown_timeout_secs = frame_handler.shutdown_timeout_secs();
445        let tripwire = async move {
446            _ = tripwire.await;
447            sleep(shutdown_timeout_secs).await;
448        }
449        .shared();
450
451        let connection_gauge = OpenGauge::new();
452        let shutdown_clone = shutdown.clone();
453
454        let request_limiter = RequestLimiter::new(
455            MAX_IN_FLIGHT_EVENTS_TARGET,
456            frame_handler.max_frame_handling_tasks(),
457        );
458
459        listener
460            .accept_stream_limited(frame_handler.max_connections())
461            .take_until(shutdown_clone)
462            .for_each(move |(connection, tcp_connection_permit)| {
463                let shutdown_signal = shutdown.clone();
464                let tripwire = tripwire.clone();
465                let out = out.clone();
466                let connection_gauge = connection_gauge.clone();
467                let request_limiter = request_limiter.clone();
468                let frame_handler_clone = frame_handler.clone();
469
470                async move {
471                    let socket = match connection {
472                        Ok(socket) => socket,
473                        Err(error) => {
474                            emit!(SocketReceiveError {
475                                mode: SocketMode::Tcp,
476                                error: &error
477                            });
478                            return;
479                        }
480                    };
481
482                    let peer_addr = socket.peer_addr();
483                    let span = info_span!("connection", %peer_addr);
484
485                    let tripwire = tripwire
486                        .map(move |_| {
487                            info!(
488                                message = "Resetting connection (still open after seconds).",
489                                seconds = ?shutdown_timeout_secs
490                            );
491                        })
492                        .boxed();
493
494                    span.clone().in_scope(|| {
495                        debug!(message = "Accepted a new connection.", peer_addr = %peer_addr);
496
497                        let open_token =
498                            connection_gauge.open(|count| emit!(ConnectionOpen { count }));
499
500                        let fut = handle_stream(
501                            frame_handler_clone,
502                            shutdown_signal,
503                            socket,
504                            tripwire,
505                            peer_addr,
506                            out,
507                            request_limiter,
508                        );
509
510                        tokio::spawn(
511                            fut.map(move |()| {
512                                drop(open_token);
513                                drop(tcp_connection_permit);
514                            })
515                            .instrument(span.or_current()),
516                        );
517                    });
518                }
519            })
520            .map(Ok)
521            .await
522    }))
523}
524
525#[allow(clippy::too_many_arguments)]
526async fn handle_stream(
527    mut frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
528    mut shutdown_signal: ShutdownSignal,
529    mut socket: MaybeTlsIncomingStream<TcpStream>,
530    mut tripwire: BoxFuture<'static, ()>,
531    peer_addr: SocketAddr,
532    out: SourceSender,
533    request_limiter: RequestLimiter,
534) {
535    tokio::select! {
536        result = socket.handshake() => {
537            if let Err(error) = result {
538                emit!(TcpSocketTlsConnectionError { error });
539                return;
540            }
541        },
542        _ = &mut shutdown_signal => {
543            return;
544        }
545    };
546
547    if let Some(keepalive) = frame_handler.keepalive() {
548        if let Err(error) = socket.set_keepalive(keepalive) {
549            warn!(message = "Failed configuring TCP keepalive.", %error);
550        }
551    }
552
553    if let Some(receive_buffer_bytes) = frame_handler.receive_buffer_bytes() {
554        if let Err(error) = socket.set_receive_buffer_bytes(receive_buffer_bytes) {
555            warn!(message = "Failed configuring receive buffer size on TCP socket.", %error);
556        }
557    }
558
559    let socket = socket.after_read(move |byte_size| {
560        emit!(TcpBytesReceived {
561            byte_size,
562            peer_addr
563        });
564    });
565
566    let certificate_metadata = socket
567        .get_ref()
568        .ssl_stream()
569        .and_then(|stream| stream.ssl().peer_certificate())
570        .map(CertificateMetadata::from);
571
572    frame_handler.insert_tls_client_metadata(certificate_metadata);
573
574    let span = info_span!("connection");
575    span.record("peer_addr", field::debug(&peer_addr));
576    let received_from: Option<Bytes> = Some(peer_addr.to_string().into());
577
578    let connection_close_timeout = OptionFuture::from(
579        frame_handler
580            .max_connection_duration_secs()
581            .map(|timeout_secs| tokio::time::sleep(Duration::from_secs(timeout_secs))),
582    );
583    tokio::pin!(connection_close_timeout);
584
585    let content_type = frame_handler.content_type();
586    let mut event_sink = out.clone();
587    let (sock_sink, sock_stream) = Framed::new(
588        socket,
589        length_delimited::Builder::new()
590            .max_frame_length(frame_handler.max_frame_length())
591            .new_codec(),
592    )
593    .split();
594    let mut reader = FrameStreamReader::new(Box::new(sock_sink), content_type);
595    let mut frames = sock_stream
596        .map_err(move |error| {
597            emit!(TcpSocketError {
598                error: &error,
599                peer_addr,
600            });
601        })
602        .filter_map(move |frame| {
603            future::ready(match frame {
604                Ok(f) => reader.handle_frame(Bytes::from(f)),
605                Err(_) => None,
606            })
607        });
608
609    let active_parsing_task_nums = Arc::new(AtomicUsize::new(0));
610    loop {
611        let mut permit = tokio::select! {
612            _ = &mut tripwire => break,
613            Some(_) = &mut connection_close_timeout  => {
614                break;
615            },
616            _ = &mut shutdown_signal => {
617                break;
618            },
619            permit = request_limiter.acquire() => {
620                Some(permit)
621            }
622            else => break,
623        };
624
625        let timeout = tokio::time::sleep(Duration::from_millis(PERMIT_HOLD_TIMEOUT_MS));
626        tokio::pin!(timeout);
627
628        tokio::select! {
629            _ = &mut tripwire => break,
630            _ = &mut shutdown_signal => break,
631            _ = &mut timeout => {
632                // This connection is currently holding a permit, but has not received data for some time. Release
633                // the permit to let another connection try
634                continue;
635            }
636            res = frames.next() => {
637                match res {
638                    Some(frame) => {
639                        if let Some(permit) = &mut permit {
640                            // Note that this is intentionally not the "number of events in a single request", but rather
641                            // the "number of events currently available". This may contain events from multiple events,
642                            // but it should always contain all events from each request.
643                            permit.decoding_finished(1);
644                        };
645                        handle_tcp_frame(&mut frame_handler, frame, &mut event_sink, received_from.clone(), Arc::clone(&active_parsing_task_nums)).await;
646                    }
647                    None => {
648                        debug!("Connection closed.");
649                        break
650                    },
651                }
652            }
653            else => break,
654        }
655
656        drop(permit);
657    }
658}
659
660async fn handle_tcp_frame<T>(
661    frame_handler: &mut T,
662    frame: Bytes,
663    event_sink: &mut SourceSender,
664    received_from: Option<Bytes>,
665    active_parsing_task_nums: Arc<AtomicUsize>,
666) where
667    T: TcpFrameHandler + Send + Sync + Clone + 'static,
668{
669    if frame_handler.multithreaded() {
670        spawn_event_handling_tasks(
671            frame,
672            frame_handler.clone(),
673            event_sink.clone(),
674            received_from,
675            active_parsing_task_nums,
676            frame_handler.max_frame_handling_tasks(),
677        )
678        .await;
679    } else if let Some(event) = frame_handler.handle_event(received_from, frame) {
680        if let Err(e) = event_sink.send_event(event).await {
681            error!(
682                internal_log_rate_limit = true,
683                "Error sending event: {e:?}."
684            );
685        }
686    }
687}
688
689/**
690 * Based off of the build_unix_source function.
691 * Functions similarly, but uses the FrameStreamReader to deal with
692 * framestream control packets, and responds appropriately.
693 **/
694pub fn build_framestream_unix_source(
695    frame_handler: impl UnixFrameHandler + Send + Sync + Clone + 'static,
696    shutdown: ShutdownSignal,
697    out: SourceSender,
698) -> crate::Result<Source> {
699    let path = frame_handler.socket_path();
700
701    //check if the path already exists (and try to delete it)
702    match fs::metadata(&path) {
703        Ok(_) => {
704            //exists, so try to delete it
705            info!(message = "Deleting file.", ?path);
706            fs::remove_file(&path)?;
707        }
708        Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {} //doesn't exist, do nothing
709        Err(e) => {
710            error!("Unable to get socket information; error = {:?}.", e);
711            return Err(Box::new(e));
712        }
713    };
714
715    let listener = UnixListener::bind(&path)?;
716
717    // system's 'net.core.rmem_max' might have to be changed if socket receive buffer is not updated properly
718    if let Some(socket_receive_buffer_size) = frame_handler.socket_receive_buffer_size() {
719        _ = nix::sys::socket::setsockopt(
720            listener.as_raw_fd(),
721            nix::sys::socket::sockopt::RcvBuf,
722            &(socket_receive_buffer_size),
723        );
724        let rcv_buf_size =
725            nix::sys::socket::getsockopt(listener.as_raw_fd(), nix::sys::socket::sockopt::RcvBuf);
726        info!(
727            "Unix socket receive buffer size modified to {}.",
728            rcv_buf_size.unwrap()
729        );
730    }
731
732    // system's 'net.core.wmem_max' might have to be changed if socket send buffer is not updated properly
733    if let Some(socket_send_buffer_size) = frame_handler.socket_send_buffer_size() {
734        _ = nix::sys::socket::setsockopt(
735            listener.as_raw_fd(),
736            nix::sys::socket::sockopt::SndBuf,
737            &(socket_send_buffer_size),
738        );
739        let snd_buf_size =
740            nix::sys::socket::getsockopt(listener.as_raw_fd(), nix::sys::socket::sockopt::SndBuf);
741        info!(
742            "Unix socket buffer send size modified to {}.",
743            snd_buf_size.unwrap()
744        );
745    }
746
747    // the permissions to unix socket are restricted from 0o700 to 0o777, which are 448 and 511 in decimal
748    if let Some(socket_permission) = frame_handler.socket_file_mode() {
749        if !(448..=511).contains(&socket_permission) {
750            return Err(format!(
751                "Invalid Socket permission {socket_permission:#o}. Must between 0o700 and 0o777."
752            )
753            .into());
754        }
755        match fs::set_permissions(&path, fs::Permissions::from_mode(socket_permission)) {
756            Ok(_) => {
757                info!("Socket permissions updated to {:#o}.", socket_permission);
758            }
759            Err(e) => {
760                error!(
761                    "Failed to update listener socket permissions; error = {:?}.",
762                    e
763                );
764                return Err(Box::new(e));
765            }
766        };
767    };
768
769    let fut = async move {
770        let active_parsing_task_nums = Arc::new(AtomicUsize::new(0));
771
772        info!(message = "Listening...", ?path, r#type = "unix");
773
774        let mut stream = UnixListenerStream::new(listener).take_until(shutdown.clone());
775        while let Some(socket) = stream.next().await {
776            let socket = match socket {
777                Err(e) => {
778                    error!("Failed to accept socket; error = {:?}.", e);
779                    continue;
780                }
781                Ok(s) => s,
782            };
783            let peer_addr = socket.peer_addr().ok();
784            let listen_path = path.clone();
785            let active_task_nums_ = Arc::clone(&active_parsing_task_nums);
786
787            let span = info_span!("connection");
788            let path = if let Some(addr) = peer_addr {
789                if let Some(path) = addr.as_pathname().map(|e| e.to_owned()) {
790                    span.record("peer_path", field::debug(&path));
791                    Some(path)
792                } else {
793                    None
794                }
795            } else {
796                None
797            };
798            let received_from: Option<Bytes> =
799                path.map(|p| p.to_string_lossy().into_owned().into());
800
801            build_framestream_source(
802                frame_handler.clone(),
803                socket,
804                received_from,
805                out.clone(),
806                shutdown.clone(),
807                span,
808                active_task_nums_,
809                move |error| {
810                    emit!(UnixSocketError {
811                        error: &error,
812                        path: &listen_path,
813                    });
814                },
815            );
816        }
817
818        // Cleanup
819        drop(stream);
820
821        // Delete socket file
822        if let Err(error) = fs::remove_file(&path) {
823            emit!(UnixSocketFileDeleteError { path: &path, error });
824        }
825
826        Ok(())
827    };
828
829    Ok(Box::pin(fut))
830}
831
832#[allow(clippy::too_many_arguments)]
833fn build_framestream_source<T: Send + 'static>(
834    frame_handler: impl FrameHandler + Send + Sync + Clone + 'static,
835    socket: impl AsyncRead + AsyncWrite + Send + 'static,
836    received_from: Option<Bytes>,
837    out: SourceSender,
838    shutdown: impl Future<Output = T> + Unpin + Send + 'static,
839    span: Span,
840    active_task_nums: Arc<AtomicUsize>,
841    error_mapper: impl FnMut(std::io::Error) + Send + 'static,
842) {
843    let content_type = frame_handler.content_type();
844    let mut event_sink = out.clone();
845    let (sock_sink, sock_stream) = Framed::new(
846        socket,
847        length_delimited::Builder::new()
848            .max_frame_length(frame_handler.max_frame_length())
849            .new_codec(),
850    )
851    .split();
852    let mut fs_reader = FrameStreamReader::new(Box::new(sock_sink), content_type);
853    let frame_handler_copy = frame_handler.clone();
854    let frames = sock_stream
855        .take_until(shutdown)
856        .map_err(error_mapper)
857        .filter_map(move |frame| {
858            future::ready(match frame {
859                Ok(f) => fs_reader.handle_frame(Bytes::from(f)),
860                Err(_) => None,
861            })
862        });
863    if !frame_handler.multithreaded() {
864        let mut events = frames.filter_map(move |f| {
865            future::ready(frame_handler_copy.handle_event(received_from.clone(), f))
866        });
867
868        let handler = async move {
869            if let Err(e) = event_sink.send_event_stream(&mut events).await {
870                error!(
871                    internal_log_rate_limit = true,
872                    "Error sending event: {:?}.", e
873                );
874            }
875
876            info!("Finished sending.");
877        };
878        tokio::spawn(handler.instrument(span.or_current()));
879    } else {
880        let handler = async move {
881            frames
882                .for_each(move |f| {
883                    let max_frame_handling_tasks = frame_handler_copy.max_frame_handling_tasks();
884                    let f_handler = frame_handler_copy.clone();
885                    let received_from_copy = received_from.clone();
886                    let event_sink_copy = event_sink.clone();
887                    let active_task_nums_copy = Arc::clone(&active_task_nums);
888
889                    async move {
890                        spawn_event_handling_tasks(
891                            f,
892                            f_handler,
893                            event_sink_copy,
894                            received_from_copy,
895                            active_task_nums_copy,
896                            max_frame_handling_tasks,
897                        )
898                        .await;
899                    }
900                })
901                .await;
902            info!("Finished sending.");
903        };
904        tokio::spawn(handler.instrument(span.or_current()));
905    }
906}
907
908async fn spawn_event_handling_tasks(
909    event_data: Bytes,
910    event_handler: impl FrameHandler + Send + Sync + 'static,
911    mut event_sink: SourceSender,
912    received_from: Option<Bytes>,
913    active_task_nums: Arc<AtomicUsize>,
914    max_frame_handling_tasks: usize,
915) -> JoinHandle<()> {
916    wait_for_task_quota(&active_task_nums, max_frame_handling_tasks).await;
917
918    tokio::spawn(async move {
919        future::ready({
920            if let Some(evt) = event_handler.handle_event(received_from, event_data) {
921                if event_sink.send_event(evt).await.is_err() {
922                    error!("Encountered error while sending event.");
923                }
924            }
925            active_task_nums.fetch_sub(1, Ordering::AcqRel);
926        })
927        .await;
928    })
929}
930
931async fn wait_for_task_quota(active_task_nums: &Arc<AtomicUsize>, max_tasks: usize) {
932    while max_tasks > 0 && max_tasks < active_task_nums.load(Ordering::Acquire) {
933        tokio::time::sleep(Duration::from_millis(3)).await;
934    }
935    active_task_nums.fetch_add(1, Ordering::AcqRel);
936}
937
938#[cfg(test)]
939mod test {
940    use futures_util::Stream;
941    use std::net::SocketAddr;
942    #[cfg(unix)]
943    use std::{
944        path::PathBuf,
945        sync::{
946            atomic::{AtomicUsize, Ordering},
947            Arc,
948        },
949        thread,
950    };
951    use tokio::net::TcpStream;
952
953    use bytes::{buf::Buf, Bytes, BytesMut};
954    use futures::{
955        future,
956        sink::{Sink, SinkExt},
957        stream::{self, StreamExt},
958    };
959    use ipnet::IpNet;
960    use tokio::{
961        self,
962        net::UnixStream,
963        task::JoinHandle,
964        time::{Duration, Instant},
965    };
966    use tokio_util::codec::{length_delimited, Framed};
967    use vector_lib::{
968        config::{LegacyKey, LogNamespace},
969        tcp::TcpKeepaliveConfig,
970        tls::{CertificateMetadata, MaybeTls},
971    };
972    use vector_lib::{
973        lookup::{owned_value_path, path, OwnedValuePath},
974        tls::MaybeTlsSettings,
975    };
976
977    use super::{
978        build_framestream_tcp_source, build_framestream_unix_source, spawn_event_handling_tasks,
979        ControlField, ControlHeader, FrameHandler, TcpFrameHandler, UnixFrameHandler,
980    };
981    use crate::{
982        config::{log_schema, ComponentKey},
983        event::{Event, LogEvent},
984        shutdown::SourceShutdownCoordinator,
985        sources::util::net::SocketListenAddr,
986        test_util::{collect_n, collect_n_stream, next_addr},
987        SourceSender,
988    };
989
990    #[derive(Clone)]
991    struct MockFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
992        content_type: String,
993        max_frame_length: usize,
994        multithreaded: bool,
995        max_frame_handling_tasks: usize,
996        extra_task_handling_routine: F,
997        host_key: Option<OwnedValuePath>,
998        timestamp_key: Option<OwnedValuePath>,
999        source_type_key: Option<OwnedValuePath>,
1000        log_namespace: LogNamespace,
1001    }
1002
1003    #[derive(Clone)]
1004    struct MockUnixFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
1005        frame_handler: MockFrameHandler<F>,
1006        socket_path: PathBuf,
1007        socket_file_mode: Option<u32>,
1008        socket_receive_buffer_size: Option<usize>,
1009        socket_send_buffer_size: Option<usize>,
1010    }
1011
1012    #[derive(Clone)]
1013    struct MockTcpFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
1014        frame_handler: MockFrameHandler<F>,
1015        address: SocketListenAddr,
1016        keepalive: Option<TcpKeepaliveConfig>,
1017        shutdown_timeout_secs: Duration,
1018        tls: MaybeTlsSettings,
1019        tls_client_metadata_key: Option<OwnedValuePath>,
1020        receive_buffer_bytes: Option<usize>,
1021        max_connection_duration_secs: Option<u64>,
1022        max_connections: Option<u32>,
1023        permit_origin: Option<Vec<IpNet>>,
1024    }
1025
1026    impl<F: Send + Sync + Clone + FnOnce() + 'static> MockTcpFrameHandler<F> {
1027        pub fn new(
1028            addr: SocketAddr,
1029            content_type: String,
1030            multithreaded: bool,
1031            extra_routine: F,
1032            permit_origin: Option<Vec<IpNet>>,
1033        ) -> Self {
1034            Self {
1035                frame_handler: MockFrameHandler::new(content_type, multithreaded, extra_routine),
1036                address: addr.into(),
1037                keepalive: None,
1038                shutdown_timeout_secs: Duration::from_secs(30),
1039                tls: MaybeTls::Raw(()),
1040                tls_client_metadata_key: None,
1041                receive_buffer_bytes: None,
1042                max_connection_duration_secs: None,
1043                max_connections: None,
1044                permit_origin,
1045            }
1046        }
1047    }
1048
1049    impl<F: Send + Sync + Clone + FnOnce() + 'static> MockUnixFrameHandler<F> {
1050        pub fn new(content_type: String, multithreaded: bool, extra_routine: F) -> Self {
1051            Self {
1052                frame_handler: MockFrameHandler::new(content_type, multithreaded, extra_routine),
1053                socket_path: tempfile::tempdir().unwrap().keep().join("unix_test"),
1054                socket_file_mode: None,
1055                socket_receive_buffer_size: None,
1056                socket_send_buffer_size: None,
1057            }
1058        }
1059    }
1060
1061    impl<F: Send + Sync + Clone + FnOnce() + 'static> MockFrameHandler<F> {
1062        pub fn new(content_type: String, multithreaded: bool, extra_routine: F) -> Self {
1063            Self {
1064                content_type,
1065                max_frame_length: bytesize::kib(100u64) as usize,
1066                multithreaded,
1067                max_frame_handling_tasks: 0,
1068                extra_task_handling_routine: extra_routine,
1069                host_key: Some(owned_value_path!("test_framestream")),
1070                timestamp_key: Some(owned_value_path!("my_timestamp")),
1071                source_type_key: Some(owned_value_path!("source_type")),
1072                log_namespace: LogNamespace::Legacy,
1073            }
1074        }
1075    }
1076
1077    impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockFrameHandler<F> {
1078        fn content_type(&self) -> String {
1079            self.content_type.clone()
1080        }
1081        fn max_frame_length(&self) -> usize {
1082            self.max_frame_length
1083        }
1084
1085        fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1086            let mut log_event = LogEvent::from(frame);
1087
1088            log_event.insert(
1089                log_schema().source_type_key_target_path().unwrap(),
1090                "framestream",
1091            );
1092            if let Some(host) = received_from {
1093                self.log_namespace.insert_source_metadata(
1094                    "framestream",
1095                    &mut log_event,
1096                    self.host_key.as_ref().map(LegacyKey::Overwrite),
1097                    path!("host"),
1098                    host,
1099                )
1100            }
1101
1102            (self.extra_task_handling_routine.clone())();
1103
1104            Some(log_event.into())
1105        }
1106
1107        fn multithreaded(&self) -> bool {
1108            self.multithreaded
1109        }
1110        fn max_frame_handling_tasks(&self) -> usize {
1111            self.max_frame_handling_tasks
1112        }
1113
1114        fn host_key(&self) -> &Option<OwnedValuePath> {
1115            &self.host_key
1116        }
1117
1118        fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1119            self.timestamp_key.as_ref()
1120        }
1121
1122        fn source_type_key(&self) -> Option<&OwnedValuePath> {
1123            self.source_type_key.as_ref()
1124        }
1125    }
1126
1127    impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockUnixFrameHandler<F> {
1128        fn content_type(&self) -> String {
1129            self.frame_handler.content_type()
1130        }
1131
1132        fn max_frame_length(&self) -> usize {
1133            self.frame_handler.max_frame_length()
1134        }
1135
1136        fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1137            self.frame_handler.handle_event(received_from, frame)
1138        }
1139
1140        fn multithreaded(&self) -> bool {
1141            self.frame_handler.multithreaded()
1142        }
1143
1144        fn max_frame_handling_tasks(&self) -> usize {
1145            self.frame_handler.max_frame_handling_tasks()
1146        }
1147
1148        fn host_key(&self) -> &Option<OwnedValuePath> {
1149            self.frame_handler.host_key()
1150        }
1151
1152        fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1153            self.frame_handler.timestamp_key()
1154        }
1155
1156        fn source_type_key(&self) -> Option<&OwnedValuePath> {
1157            self.frame_handler.source_type_key()
1158        }
1159    }
1160
1161    impl<F: Send + Sync + Clone + FnOnce() + 'static> UnixFrameHandler for MockUnixFrameHandler<F> {
1162        fn socket_path(&self) -> PathBuf {
1163            self.socket_path.clone()
1164        }
1165
1166        fn socket_file_mode(&self) -> Option<u32> {
1167            self.socket_file_mode
1168        }
1169
1170        fn socket_receive_buffer_size(&self) -> Option<usize> {
1171            self.socket_receive_buffer_size
1172        }
1173
1174        fn socket_send_buffer_size(&self) -> Option<usize> {
1175            self.socket_send_buffer_size
1176        }
1177    }
1178
1179    impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockTcpFrameHandler<F> {
1180        fn content_type(&self) -> String {
1181            self.frame_handler.content_type()
1182        }
1183
1184        fn max_frame_length(&self) -> usize {
1185            self.frame_handler.max_frame_length()
1186        }
1187
1188        fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1189            self.frame_handler.handle_event(received_from, frame)
1190        }
1191
1192        fn multithreaded(&self) -> bool {
1193            self.frame_handler.multithreaded()
1194        }
1195
1196        fn max_frame_handling_tasks(&self) -> usize {
1197            self.frame_handler.max_frame_handling_tasks()
1198        }
1199
1200        fn host_key(&self) -> &Option<OwnedValuePath> {
1201            self.frame_handler.host_key()
1202        }
1203
1204        fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1205            self.frame_handler.timestamp_key()
1206        }
1207
1208        fn source_type_key(&self) -> Option<&OwnedValuePath> {
1209            self.frame_handler.source_type_key()
1210        }
1211    }
1212
1213    impl<F: Send + Sync + Clone + FnOnce() + 'static> TcpFrameHandler for MockTcpFrameHandler<F> {
1214        fn address(&self) -> SocketListenAddr {
1215            self.address
1216        }
1217
1218        fn keepalive(&self) -> Option<TcpKeepaliveConfig> {
1219            self.keepalive
1220        }
1221
1222        fn shutdown_timeout_secs(&self) -> Duration {
1223            self.shutdown_timeout_secs
1224        }
1225
1226        fn tls(&self) -> MaybeTlsSettings {
1227            self.tls.clone()
1228        }
1229
1230        fn tls_client_metadata_key(&self) -> Option<OwnedValuePath> {
1231            self.tls_client_metadata_key.clone()
1232        }
1233
1234        fn receive_buffer_bytes(&self) -> Option<usize> {
1235            self.receive_buffer_bytes
1236        }
1237
1238        fn max_connection_duration_secs(&self) -> Option<u64> {
1239            self.max_connection_duration_secs
1240        }
1241
1242        fn max_connections(&self) -> Option<u32> {
1243            self.max_connections
1244        }
1245
1246        fn insert_tls_client_metadata(&mut self, _: Option<CertificateMetadata>) {}
1247
1248        fn allowed_origins(&self) -> Option<&[IpNet]> {
1249            self.permit_origin.as_deref()
1250        }
1251    }
1252
1253    fn init_framestream_tcp(
1254        source_id: &str,
1255        addr: &SocketAddr,
1256        frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
1257        pipeline: SourceSender,
1258    ) -> (JoinHandle<Result<(), ()>>, SourceShutdownCoordinator) {
1259        let source_id = ComponentKey::from(source_id);
1260        let mut shutdown = SourceShutdownCoordinator::default();
1261        let (shutdown_signal, _) = shutdown.register_source(&source_id, false);
1262        let server = build_framestream_tcp_source(frame_handler, shutdown_signal, pipeline)
1263            .expect("Failed to build framestream tcp source.");
1264
1265        let join_handle = tokio::spawn(server);
1266
1267        while std::net::TcpStream::connect(addr).is_err() {
1268            thread::sleep(Duration::from_millis(2));
1269        }
1270
1271        (join_handle, shutdown)
1272    }
1273
1274    fn init_framestream_unix(
1275        source_id: &str,
1276        frame_handler: impl UnixFrameHandler + Send + Sync + Clone + 'static,
1277        pipeline: SourceSender,
1278    ) -> (
1279        PathBuf,
1280        JoinHandle<Result<(), ()>>,
1281        SourceShutdownCoordinator,
1282    ) {
1283        let source_id = ComponentKey::from(source_id);
1284        let socket_path = frame_handler.socket_path();
1285        let mut shutdown = SourceShutdownCoordinator::default();
1286        let (shutdown_signal, _) = shutdown.register_source(&source_id, false);
1287        let server = build_framestream_unix_source(frame_handler, shutdown_signal, pipeline)
1288            .expect("Failed to build framestream unix source.");
1289
1290        let join_handle = tokio::spawn(server);
1291
1292        // Wait for server to accept traffic
1293        while std::os::unix::net::UnixStream::connect(&socket_path).is_err() {
1294            thread::sleep(Duration::from_millis(2));
1295        }
1296
1297        (socket_path, join_handle, shutdown)
1298    }
1299
1300    async fn make_tcp_stream(
1301        addr: SocketAddr,
1302    ) -> Framed<TcpStream, length_delimited::LengthDelimitedCodec> {
1303        let socket = TcpStream::connect(&addr).await.unwrap();
1304        Framed::new(socket, length_delimited::Builder::new().new_codec())
1305    }
1306
1307    async fn make_unix_stream(
1308        path: PathBuf,
1309    ) -> Framed<UnixStream, length_delimited::LengthDelimitedCodec> {
1310        let socket = UnixStream::connect(&path).await.unwrap();
1311        Framed::new(socket, length_delimited::Builder::new().new_codec())
1312    }
1313
1314    async fn send_data_frames<S: Sink<Bytes, Error = std::io::Error> + Unpin>(
1315        sock_sink: &mut S,
1316        frames: Vec<Result<Bytes, std::io::Error>>,
1317    ) {
1318        let mut stream = stream::iter(frames.into_iter());
1319        //send and send_all consume the sink
1320        _ = sock_sink.send_all(&mut stream).await;
1321    }
1322
1323    async fn send_control_frame<S: Sink<Bytes, Error = std::io::Error> + Unpin>(
1324        sock_sink: &mut S,
1325        frame: Bytes,
1326    ) {
1327        send_data_frames(sock_sink, vec![Ok(Bytes::new()), Ok(frame)]).await; //send empty frame to say we are control frame
1328    }
1329
1330    fn create_control_frame(header: ControlHeader) -> Bytes {
1331        Bytes::from(header.to_u32().to_be_bytes().to_vec())
1332    }
1333
1334    fn create_control_frame_with_content(
1335        header: ControlHeader,
1336        content_types: Vec<Bytes>,
1337    ) -> Bytes {
1338        let mut frame = BytesMut::from(&header.to_u32().to_be_bytes()[..]);
1339        for content_type in content_types {
1340            frame.extend(ControlField::ContentType.to_u32().to_be_bytes());
1341            frame.extend((content_type.len() as u32).to_be_bytes());
1342            frame.extend(content_type.clone());
1343        }
1344        Bytes::from(frame)
1345    }
1346
1347    fn assert_accept_frame(frame: &mut BytesMut, expected_content_type: Bytes) {
1348        //frame should start with 4 bytes saying ACCEPT
1349
1350        assert_eq!(&frame[..4], &ControlHeader::Accept.to_u32().to_be_bytes(),);
1351        frame.advance(4);
1352        //next should be content type field
1353        assert_eq!(
1354            &frame[..4],
1355            &ControlField::ContentType.to_u32().to_be_bytes(),
1356        );
1357        frame.advance(4);
1358        //next should be length of content_type
1359        assert_eq!(
1360            &frame[..4],
1361            &(expected_content_type.len() as u32).to_be_bytes(),
1362        );
1363        frame.advance(4);
1364        //rest should be content type
1365        assert_eq!(&frame[..], &expected_content_type[..]);
1366    }
1367
1368    fn create_frame_handler(multithreaded: bool) -> impl UnixFrameHandler + Send + Sync + Clone {
1369        MockUnixFrameHandler::new("test_content".to_string(), multithreaded, move || {})
1370    }
1371
1372    fn create_tcp_frame_handler(
1373        addr: SocketAddr,
1374        multithreaded: bool,
1375        permit_origin: Option<Vec<IpNet>>,
1376    ) -> impl TcpFrameHandler + Send + Sync + Clone {
1377        MockTcpFrameHandler::new(
1378            addr,
1379            "test_content".to_string(),
1380            multithreaded,
1381            move || {},
1382            permit_origin,
1383        )
1384    }
1385
1386    async fn signal_shutdown(source_name: &str, shutdown: &mut SourceShutdownCoordinator) {
1387        // Now signal to the Source to shut down.
1388        let deadline = Instant::now() + Duration::from_secs(10);
1389        let id = ComponentKey::from(source_name);
1390        let shutdown_complete = shutdown.shutdown_source(&id, deadline);
1391        let shutdown_success = shutdown_complete.await;
1392        assert!(shutdown_success);
1393    }
1394
    /// Drives a bidirectional Frame Streams session over an already-connected socket.
    ///
    /// The sequence is READY (content type `test_content`) → ACCEPT → START → two data frames
    /// → STOP. It asserts that both payloads arrive as events, in either order (the
    /// multithreaded mode does not keep order), then shuts the source down and joins it.
    async fn test_normal_framestream<
        T: Sink<Bytes, Error = std::io::Error> + Unpin,
        U: Stream<Item = Result<BytesMut, std::io::Error>> + Unpin,
        V: Stream<Item = Event> + Unpin,
    >(
        source_name: &str,
        mut sock_sink: T,
        mut sock_stream: U,
        rx: V,
        mut shutdown: SourceShutdownCoordinator,
        source_handle: JoinHandle<Result<(), ()>>,
    ) {
        //1 - send READY frame (with content_type)
        let content_type = Bytes::from(&b"test_content"[..]);
        let ready_msg =
            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
        send_control_frame(&mut sock_sink, ready_msg).await;

        //2 - wait for ACCEPT frame
        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
        //take second element, because first will be empty (signifying control frame)
        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);

        //3 - send START frame
        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Start)).await;

        //4 - send data
        send_data_frames(
            &mut sock_sink,
            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
        )
        .await;
        // Wait until both data frames have come out of the pipeline as events.
        let events = collect_n(rx, 2).await;

        //5 - send STOP frame
        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;

        // Order-independent checks: the multithreaded mode may reorder events.
        let message_key = log_schema().message_key().unwrap().to_string();
        assert!(events
            .iter()
            .any(|e| e.as_log()[&message_key] == "hello".into()));
        assert!(events
            .iter()
            .any(|e| e.as_log()[&message_key] == "world".into()));

        drop(sock_stream); // release the read half of the connection before shutting the source down

        // Ensure source actually shut down successfully.
        // The source's own `Result` is discarded; only a panicked or cancelled join fails here.
        signal_shutdown(source_name, &mut shutdown).await;
        _ = source_handle.await.unwrap();
    }
1447
    /// Sends a READY frame that offers two content types (`test_content2` and `test_content`).
    /// Asserts that the source answers with an ACCEPT naming only the one it supports
    /// (`test_content`), then shuts the source down and joins it.
    async fn test_multiple_content_types<
        T: Sink<Bytes, Error = std::io::Error> + Unpin,
        U: Stream<Item = Result<BytesMut, std::io::Error>> + Unpin,
    >(
        source_name: &str,
        mut sock_sink: T,
        mut sock_stream: U,
        mut shutdown: SourceShutdownCoordinator,
        source_handle: JoinHandle<Result<(), ()>>,
    ) {
        //1 - send READY frame (with content_type)
        let content_type = Bytes::from(&b"test_content"[..]);
        let ready_msg = create_control_frame_with_content(
            ControlHeader::Ready,
            vec![Bytes::from(&b"test_content2"[..]), content_type.clone()],
        ); // a READY frame may offer several content types; only the second is supported
        send_control_frame(&mut sock_sink, ready_msg).await;

        //2 - wait for ACCEPT frame
        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;

        //take second element, because first will be empty (signifying control frame)
        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);

        drop(sock_stream); // release the read half of the connection before shutting the source down

        // Ensure source actually shut down successfully.
        // The source's own `Result` is discarded; only a panicked or cancelled join fails here.
        signal_shutdown(source_name, &mut shutdown).await;
        _ = source_handle.await.unwrap();
    }
1479
1480    #[tokio::test(flavor = "multi_thread")]
1481    #[should_panic]
1482    async fn blocked_framestream_tcp() {
1483        let source_name = "test_source";
1484        let (tx, rx) = SourceSender::new_test();
1485        let addr = next_addr();
1486        let (source_handle, shutdown) = init_framestream_tcp(
1487            source_name,
1488            &addr,
1489            create_tcp_frame_handler(addr, false, Some(vec![])),
1490            tx,
1491        );
1492        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1493
1494        test_normal_framestream(
1495            source_name,
1496            sock_sink,
1497            sock_stream,
1498            rx,
1499            shutdown,
1500            source_handle,
1501        )
1502        .await;
1503    }
1504
1505    #[tokio::test(flavor = "multi_thread")]
1506    async fn normal_framestream_singlethreaded_tcp() {
1507        let source_name = "test_source";
1508        let (tx, rx) = SourceSender::new_test();
1509        let addr = next_addr();
1510        let (source_handle, shutdown) = init_framestream_tcp(
1511            source_name,
1512            &addr,
1513            create_tcp_frame_handler(addr, false, None),
1514            tx,
1515        );
1516        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1517
1518        test_normal_framestream(
1519            source_name,
1520            sock_sink,
1521            sock_stream,
1522            rx,
1523            shutdown,
1524            source_handle,
1525        )
1526        .await;
1527    }
1528
1529    #[tokio::test(flavor = "multi_thread")]
1530    async fn normal_framestream_singlethreaded_unix() {
1531        let source_name = "test_source";
1532        let (tx, rx) = SourceSender::new_test();
1533        let (path, source_handle, shutdown) =
1534            init_framestream_unix(source_name, create_frame_handler(false), tx);
1535        let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1536
1537        test_normal_framestream(
1538            source_name,
1539            sock_sink,
1540            sock_stream,
1541            rx,
1542            shutdown,
1543            source_handle,
1544        )
1545        .await;
1546    }
1547
1548    #[tokio::test(flavor = "multi_thread")]
1549    async fn normal_framestream_multithreaded_tcp() {
1550        let source_name = "test_source";
1551        let (tx, rx) = SourceSender::new_test();
1552        let addr = next_addr();
1553        let (source_handle, shutdown) = init_framestream_tcp(
1554            source_name,
1555            &addr,
1556            create_tcp_frame_handler(addr, true, None),
1557            tx,
1558        );
1559        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1560
1561        test_normal_framestream(
1562            source_name,
1563            sock_sink,
1564            sock_stream,
1565            rx,
1566            shutdown,
1567            source_handle,
1568        )
1569        .await;
1570    }
1571
1572    #[tokio::test(flavor = "multi_thread")]
1573    async fn normal_framestream_multithreaded_unix() {
1574        let source_name = "test_source";
1575        let (tx, rx) = SourceSender::new_test();
1576        let (path, source_handle, shutdown) =
1577            init_framestream_unix(source_name, create_frame_handler(true), tx);
1578        let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1579
1580        test_normal_framestream(
1581            source_name,
1582            sock_sink,
1583            sock_stream,
1584            rx,
1585            shutdown,
1586            source_handle,
1587        )
1588        .await;
1589    }
1590
1591    #[tokio::test(flavor = "multi_thread")]
1592    async fn multiple_content_types_tcp() {
1593        let source_name = "test_source";
1594        let (tx, _) = SourceSender::new_test();
1595        let addr = next_addr();
1596        let (source_handle, shutdown) = init_framestream_tcp(
1597            source_name,
1598            &addr,
1599            create_tcp_frame_handler(addr, false, None),
1600            tx,
1601        );
1602        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1603
1604        test_multiple_content_types(source_name, sock_sink, sock_stream, shutdown, source_handle)
1605            .await;
1606    }
1607
1608    #[tokio::test(flavor = "multi_thread")]
1609    async fn multiple_content_types_unix() {
1610        let source_name = "test_source";
1611        let (tx, _) = SourceSender::new_test();
1612        let (path, source_handle, shutdown) =
1613            init_framestream_unix(source_name, create_frame_handler(false), tx);
1614        let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1615
1616        test_multiple_content_types(source_name, sock_sink, sock_stream, shutdown, source_handle)
1617            .await;
1618    }
1619
    /// Sends a READY with an unsupported content type, then a READY with the supported one.
    ///
    /// The first two frames read back must be the ACCEPT for the second READY, so the
    /// unsupported READY must not have produced any reply frames. The event receiver is
    /// dropped immediately because no data is sent.
    #[tokio::test(flavor = "multi_thread")]
    async fn wrong_content_type() {
        let source_name = "test_source";
        let (tx, _) = SourceSender::new_test();
        let (path, source_handle, mut shutdown) =
            init_framestream_unix(source_name, create_frame_handler(false), tx);
        let (mut sock_sink, mut sock_stream) = make_unix_stream(path).await.split();

        //1 - send READY frame (with WRONG content_type)
        let ready_msg = create_control_frame_with_content(
            ControlHeader::Ready,
            vec![Bytes::from(&b"test_content2"[..])],
        ); // only an unsupported content type is offered here
        send_control_frame(&mut sock_sink, ready_msg).await;

        //2 - send READY frame (with RIGHT content_type)
        let content_type = Bytes::from(&b"test_content"[..]);
        let ready_msg =
            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
        send_control_frame(&mut sock_sink, ready_msg).await;

        //3 - wait for ACCEPT frame
        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;

        //take second element, because first will be empty (signifying control frame)
        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);

        drop(sock_stream); // release the read half of the connection before shutting the source down

        // Ensure source actually shut down successfully.
        signal_shutdown(source_name, &mut shutdown).await;
        _ = source_handle.await.unwrap();
    }
1654
    /// Sends data frames before any handshake, then performs the normal READY/ACCEPT/START
    /// handshake and sends two more data frames.
    ///
    /// Only the post-handshake payloads must reach the pipeline. In single-task mode they must
    /// arrive in order, so the first two events are `hello` and `world`, not the early
    /// `bad`/`data`.
    #[tokio::test(flavor = "multi_thread")]
    async fn data_too_soon() {
        let source_name = "test_source";
        let (tx, rx) = SourceSender::new_test();
        let (path, source_handle, mut shutdown) =
            init_framestream_unix(source_name, create_frame_handler(false), tx);
        let (mut sock_sink, mut sock_stream) = make_unix_stream(path).await.split();

        //1 - send data frame (too soon!)
        send_data_frames(
            &mut sock_sink,
            vec![Ok(Bytes::from("bad")), Ok(Bytes::from("data"))],
        )
        .await;

        //2 - send READY frame (with content_type)
        let content_type = Bytes::from(&b"test_content"[..]);
        let ready_msg =
            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
        send_control_frame(&mut sock_sink, ready_msg).await;

        //3 - wait for ACCEPT frame
        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;

        //take second element, because first will be empty (signifying control frame)
        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);

        //4 - send START frame
        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Start)).await;

        //5 - send data (will go through)
        send_data_frames(
            &mut sock_sink,
            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
        )
        .await;
        let events = collect_n(rx, 2).await;

        //6 - send STOP frame
        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;

        // Order-sensitive checks: this handler is single-task, so events keep frame order.
        assert_eq!(
            events[0].as_log()[log_schema().message_key().unwrap().to_string()],
            "hello".into(),
        );
        assert_eq!(
            events[1].as_log()[log_schema().message_key().unwrap().to_string()],
            "world".into(),
        );

        drop(sock_stream); // release the read half of the connection before shutting the source down

        // Ensure source actually shut down successfully.
        signal_shutdown(source_name, &mut shutdown).await;
        _ = source_handle.await.unwrap();
    }
1712
1713    #[tokio::test(flavor = "multi_thread")]
1714    async fn unidirectional_framestream() {
1715        let source_name = "test_source";
1716        let (tx, rx) = SourceSender::new_test();
1717        let (path, source_handle, mut shutdown) =
1718            init_framestream_unix(source_name, create_frame_handler(false), tx);
1719        let (mut sock_sink, _) = make_unix_stream(path).await.split();
1720
1721        //1 - send START frame (with content_type)
1722        let content_type = Bytes::from(&b"test_content"[..]);
1723        let start_msg = create_control_frame_with_content(ControlHeader::Start, vec![content_type]);
1724        send_control_frame(&mut sock_sink, start_msg).await;
1725
1726        //4 - send data
1727        send_data_frames(
1728            &mut sock_sink,
1729            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
1730        )
1731        .await;
1732        let events = collect_n(rx, 2).await;
1733
1734        //5 - send STOP frame
1735        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;
1736
1737        assert_eq!(
1738            events[0].as_log()[log_schema().message_key().unwrap().to_string()],
1739            "hello".into(),
1740        );
1741        assert_eq!(
1742            events[1].as_log()[log_schema().message_key().unwrap().to_string()],
1743            "world".into(),
1744        );
1745
1746        // Ensure source actually shut down successfully.
1747        signal_shutdown(source_name, &mut shutdown).await;
1748        _ = source_handle.await.unwrap();
1749    }
1750
1751    #[tokio::test(flavor = "multi_thread")]
1752    async fn test_spawn_event_handling_tasks() {
1753        let (out, rx) = SourceSender::new_test();
1754
1755        let max_frame_handling_tasks = 20;
1756        let active_task_nums = Arc::new(AtomicUsize::new(0));
1757        let active_task_nums_copy = Arc::clone(&active_task_nums);
1758        let max_task_nums_reached = Arc::new(AtomicUsize::new(0));
1759        let max_task_nums_reached_copy = Arc::clone(&max_task_nums_reached);
1760
1761        let mut join_handles = vec![];
1762        let active_task_nums_copy_2 = Arc::clone(&active_task_nums_copy);
1763        let extra_routine = move || {
1764            thread::sleep(Duration::from_millis(10));
1765            max_task_nums_reached_copy.fetch_max(
1766                active_task_nums_copy_2.load(Ordering::Acquire),
1767                Ordering::AcqRel,
1768            );
1769        };
1770
1771        let total_events = max_frame_handling_tasks * 10;
1772
1773        join_handles.push(tokio::spawn(async move {
1774            future::ready({
1775                let events = collect_n(rx, total_events).await;
1776                assert_eq!(total_events, events.len(), "Missed events");
1777            })
1778            .await;
1779        }));
1780
1781        for i in 0..total_events {
1782            join_handles.push(
1783                spawn_event_handling_tasks(
1784                    Bytes::from(format!("event_{i}")),
1785                    MockFrameHandler::new("test_content".to_string(), true, extra_routine.clone()),
1786                    out.clone(),
1787                    None,
1788                    Arc::clone(&active_task_nums_copy),
1789                    max_frame_handling_tasks,
1790                )
1791                .await,
1792            );
1793        }
1794
1795        future::join_all(join_handles).await;
1796
1797        let final_task_nums = active_task_nums.load(Ordering::Acquire);
1798        assert_eq!(
1799            0, final_task_nums,
1800            "There should be NO left-over tasks at the end"
1801        );
1802
1803        let max_task_nums_reached_value = max_task_nums_reached.load(Ordering::Acquire);
1804        assert!(
1805            max_task_nums_reached_value > 1,
1806            "MultiThreaded mode does NOT work"
1807        );
1808        assert!((max_task_nums_reached_value - max_frame_handling_tasks) < 2, "Max number of tasks at any given time should NOT Exceed max_frame_handling_tasks too much");
1809    }
1810}