vector/sources/util/
framestream.rs

1#[cfg(unix)]
2use std::os::unix::{fs::PermissionsExt, io::AsRawFd};
3use std::{
4    convert::TryInto,
5    fs,
6    marker::{Send, Sync},
7    net::SocketAddr,
8    path::PathBuf,
9    sync::{
10        Arc, Mutex,
11        atomic::{AtomicUsize, Ordering},
12    },
13    time::Duration,
14};
15
16use bytes::{Buf, Bytes, BytesMut};
17use futures::{
18    executor::block_on,
19    future::{self, OptionFuture},
20    sink::{Sink, SinkExt},
21    stream::{self, StreamExt, TryStreamExt},
22};
23use futures_util::{Future, FutureExt, future::BoxFuture};
24use ipnet::IpNet;
25use listenfd::ListenFd;
26use tokio::{
27    self,
28    io::{AsyncRead, AsyncWrite},
29    net::{TcpStream, UnixListener},
30    task::JoinHandle,
31    time::sleep,
32};
33use tokio_stream::wrappers::UnixListenerStream;
34use tokio_util::codec::{Framed, length_delimited};
35use tracing::{Instrument, Span, field};
36use vector_lib::{
37    lookup::OwnedValuePath,
38    tcp::TcpKeepaliveConfig,
39    tls::{CertificateMetadata, MaybeTlsIncomingStream, MaybeTlsSettings},
40};
41
42use super::net::{RequestLimiter, SocketListenAddr};
43use crate::{
44    SourceSender,
45    event::Event,
46    internal_events::{
47        ConnectionOpen, OpenGauge, SocketBindError, SocketMode, SocketReceiveError,
48        TcpBytesReceived, TcpSocketError, TcpSocketTlsConnectionError, UnixSocketError,
49        UnixSocketFileDeleteError,
50    },
51    shutdown::ShutdownSignal,
52    sources::{
53        Source,
54        util::{
55            AfterReadExt,
56            net::{MAX_IN_FLIGHT_EVENTS_TARGET, try_bind_tcp_listener},
57        },
58    },
59};
60
/// Upper bound, in bytes, against which the length of a received control frame is checked.
const FSTRM_CONTROL_FRAME_LENGTH_MAX: usize = 512;
/// Upper bound, in bytes, on a content-type string carried in a control frame field.
const FSTRM_CONTROL_FIELD_CONTENT_TYPE_LENGTH_MAX: usize = 256;

/// If a connection does not receive any data during this short timeout,
/// it should release its permit (and try to obtain a new one) allowing other connections to read.
/// It is very short because any incoming data will avoid this timeout,
/// so it mainly prevents holding permits without consuming any data.
const PERMIT_HOLD_TIMEOUT_MS: u64 = 10;

/// Boxed sink of [`Bytes`] frames; a `FrameStreamReader` writes its control
/// responses to the peer through a value of this type.
pub type FrameStreamSink = Box<dyn Sink<Bytes, Error = std::io::Error> + Send + Unpin>;
71
/// Per-connection Frame Streams protocol handler.
///
/// It is fed the length-delimited frames read from a peer, tracks the control
/// handshake, answers control frames through `response_sink`, and hands back
/// only the data frames that arrive while the stream is in its data-reading
/// state.
pub struct FrameStreamReader {
    // Sink used to send control responses (ACCEPT / FINISH) back to the peer.
    response_sink: Mutex<FrameStreamSink>,
    // Content type a peer must offer in its READY / START control frames.
    expected_content_type: String,
    // Control-protocol bookkeeping for this connection.
    state: FrameStreamState,
}
77
78struct FrameStreamState {
79    expect_control_frame: bool,
80    control_state: ControlState,
81    is_bidirectional: bool,
82}
83impl FrameStreamState {
84    const fn new() -> Self {
85        FrameStreamState {
86            expect_control_frame: false,
87            //first control frame should be READY (if bidirectional -- if unidirectional first will be START)
88            control_state: ControlState::Initial,
89            is_bidirectional: true, //assume
90        }
91    }
92}
93
/// Progress of a connection through the Frame Streams control handshake.
#[derive(PartialEq, Debug)]
enum ControlState {
    /// No control frame has been accepted yet.
    Initial,
    /// A READY frame was accepted; a START frame is expected next.
    GotReady,
    /// A START frame was accepted; data frames are passed through.
    ReadingData,
    /// A STOP frame was accepted; the stream is finished.
    Stopped,
}
101
102#[derive(Copy, Clone)]
103enum ControlHeader {
104    Accept,
105    Start,
106    Stop,
107    Ready,
108    Finish,
109}
110
111impl ControlHeader {
112    fn from_u32(val: u32) -> Result<Self, ()> {
113        match val {
114            0x01 => Ok(ControlHeader::Accept),
115            0x02 => Ok(ControlHeader::Start),
116            0x03 => Ok(ControlHeader::Stop),
117            0x04 => Ok(ControlHeader::Ready),
118            0x05 => Ok(ControlHeader::Finish),
119            _ => {
120                error!("Don't know header value {} (expected 0x01 - 0x05).", val);
121                Err(())
122            }
123        }
124    }
125
126    const fn to_u32(self) -> u32 {
127        match self {
128            ControlHeader::Accept => 0x01,
129            ControlHeader::Start => 0x02,
130            ControlHeader::Stop => 0x03,
131            ControlHeader::Ready => 0x04,
132            ControlHeader::Finish => 0x05,
133        }
134    }
135}
136
137enum ControlField {
138    ContentType,
139}
140
141impl ControlField {
142    fn from_u32(val: u32) -> Result<Self, ()> {
143        match val {
144            0x01 => Ok(ControlField::ContentType),
145            _ => {
146                error!("Don't know field type {} (expected 0x01).", val);
147                Err(())
148            }
149        }
150    }
151    const fn to_u32(&self) -> u32 {
152        match self {
153            ControlField::ContentType => 0x01,
154        }
155    }
156}
157
158fn advance_u32(b: &mut Bytes) -> Result<u32, ()> {
159    if b.len() < 4 {
160        error!("Malformed frame.");
161        return Err(());
162    }
163    let a = b.split_to(4);
164    Ok(u32::from_be_bytes(a[..].try_into().unwrap()))
165}
166
167impl FrameStreamReader {
168    pub fn new(response_sink: FrameStreamSink, expected_content_type: String) -> Self {
169        FrameStreamReader {
170            response_sink: Mutex::new(response_sink),
171            expected_content_type,
172            state: FrameStreamState::new(),
173        }
174    }
175
176    pub fn handle_frame(&mut self, frame: Bytes) -> Option<Bytes> {
177        if frame.is_empty() {
178            //frame length of zero means the next frame is a control frame
179            self.state.expect_control_frame = true;
180            None
181        } else if self.state.expect_control_frame {
182            self.state.expect_control_frame = false;
183            _ = self.handle_control_frame(frame);
184            None
185        } else {
186            //data frame
187            if self.state.control_state == ControlState::ReadingData {
188                Some(frame) //return data frame
189            } else {
190                error!(
191                    "Received a data frame while in state {:?}.",
192                    self.state.control_state
193                );
194                None
195            }
196        }
197    }
198
199    fn handle_control_frame(&mut self, mut frame: Bytes) -> Result<(), ()> {
200        //enforce maximum control frame size
201        if frame.len() > FSTRM_CONTROL_FRAME_LENGTH_MAX {
202            error!("Control frame is too long.");
203        }
204
205        let header = ControlHeader::from_u32(advance_u32(&mut frame)?)?;
206
207        //match current state to received header
208        match self.state.control_state {
209            ControlState::Initial => {
210                match header {
211                    ControlHeader::Ready => {
212                        let content_type = self.process_fields(header, &mut frame)?.unwrap();
213
214                        self.send_control_frame(Self::make_frame(
215                            ControlHeader::Accept,
216                            Some(content_type),
217                        ));
218                        self.state.control_state = ControlState::GotReady; //waiting for a START control frame
219                    }
220                    ControlHeader::Start => {
221                        //check for content type
222                        _ = self.process_fields(header, &mut frame)?;
223                        //if didn't error, then we are ok to change state
224                        self.state.control_state = ControlState::ReadingData;
225                        self.state.is_bidirectional = false; //if first message was START then we are unidirectional (no responses)
226                    }
227                    _ => error!("Got wrong control frame, expected READY."),
228                }
229            }
230            ControlState::GotReady => {
231                match header {
232                    ControlHeader::Start => {
233                        //check for content type
234                        _ = self.process_fields(header, &mut frame)?;
235                        //if didn't error, then we are ok to change state
236                        self.state.control_state = ControlState::ReadingData;
237                    }
238                    _ => error!("Got wrong control frame, expected START."),
239                }
240            }
241            ControlState::ReadingData => {
242                match header {
243                    ControlHeader::Stop => {
244                        //check there aren't any fields
245                        _ = self.process_fields(header, &mut frame)?;
246                        if self.state.is_bidirectional {
247                            //send FINISH frame -- but only if we are bidirectional
248                            self.send_control_frame(Self::make_frame(ControlHeader::Finish, None));
249                        }
250                        self.state.control_state = ControlState::Stopped; //stream is now done
251                    }
252                    _ => error!("Got wrong control frame, expected STOP."),
253                }
254            }
255            ControlState::Stopped => error!("Unexpected control frame, current state is STOPPED."),
256        };
257        Ok(())
258    }
259
260    fn process_fields(
261        &mut self,
262        header: ControlHeader,
263        frame: &mut Bytes,
264    ) -> Result<Option<String>, ()> {
265        match header {
266            ControlHeader::Ready => {
267                //should provide 1+ content types
268                //should match expected content type
269                let is_start_frame = false;
270                let content_type = self.process_content_type(frame, is_start_frame)?;
271                Ok(Some(content_type))
272            }
273            ControlHeader::Start => {
274                //can take one or zero content types
275                if frame.is_empty() {
276                    Ok(None)
277                } else {
278                    //should match expected content type
279                    let is_start_frame = true;
280                    let content_type = self.process_content_type(frame, is_start_frame)?;
281                    Ok(Some(content_type))
282                }
283            }
284            ControlHeader::Stop => {
285                //check that there are no fields
286                if !frame.is_empty() {
287                    error!("Unexpected fields in STOP header.");
288                    Err(())
289                } else {
290                    Ok(None)
291                }
292            }
293            _ => {
294                error!("Unexpected control header value {:?}.", header.to_u32());
295                Err(())
296            }
297        }
298    }
299
300    fn process_content_type(&self, frame: &mut Bytes, is_start_frame: bool) -> Result<String, ()> {
301        if frame.is_empty() {
302            error!("No fields in control frame.");
303            return Err(());
304        }
305
306        let mut content_types = vec![];
307        while !frame.is_empty() {
308            //4 bytes of ControlField
309            let field_val = advance_u32(frame)?;
310            let field_type = ControlField::from_u32(field_val)?;
311            match field_type {
312                ControlField::ContentType => {
313                    //4 bytes giving length of content type
314                    let field_len = advance_u32(frame)? as usize;
315
316                    //enforce limit on content type string
317                    if field_len > FSTRM_CONTROL_FIELD_CONTENT_TYPE_LENGTH_MAX {
318                        error!("Content-Type string is too long.");
319                        return Err(());
320                    }
321
322                    let content_type = std::str::from_utf8(&frame[..field_len]).unwrap();
323                    content_types.push(content_type.to_string());
324                    frame.advance(field_len);
325                }
326            }
327        }
328
329        if is_start_frame && content_types.len() > 1 {
330            error!(
331                "START control frame can only have one content-type provided (got {}).",
332                content_types.len()
333            );
334            return Err(());
335        }
336
337        for content_type in &content_types {
338            if *content_type == self.expected_content_type {
339                return Ok(content_type.clone());
340            }
341        }
342
343        error!(
344            "Content types did not match up. Expected {} got {:?}.",
345            self.expected_content_type, content_types
346        );
347        Err(())
348    }
349
350    fn make_frame(header: ControlHeader, content_type: Option<String>) -> Bytes {
351        let mut frame = BytesMut::new();
352        frame.extend(header.to_u32().to_be_bytes());
353        if let Some(s) = content_type {
354            frame.extend(ControlField::ContentType.to_u32().to_be_bytes()); //field type: ContentType
355            frame.extend((s.len() as u32).to_be_bytes()); //length of type
356            frame.extend(s.as_bytes());
357        }
358        Bytes::from(frame)
359    }
360
361    fn send_control_frame(&mut self, frame: Bytes) {
362        let empty_frame = Bytes::from(&b""[..]); //send empty frame to say we are control frame
363        let mut stream = stream::iter(vec![Ok(empty_frame), Ok(frame)]);
364
365        if let Err(e) = block_on(self.response_sink.lock().unwrap().send_all(&mut stream)) {
366            error!("Encountered error '{:#?}' while sending control frame.", e);
367        }
368    }
369}
370
/// Behaviour a source supplies to turn Frame Streams data frames into events.
pub trait FrameHandler {
    /// Content type handed to the `FrameStreamReader` as the one peers must offer.
    fn content_type(&self) -> String;
    /// Maximum frame length configured on the length-delimited codec.
    fn max_frame_length(&self) -> usize;
    /// Converts one data frame into an event, or `None` when it produces none.
    /// `received_from` identifies the peer when that is known.
    fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event>;
    /// Whether frames are handed to `spawn_event_handling_tasks` instead of
    /// being converted inline on the connection's own task.
    fn multithreaded(&self) -> bool;
    /// Limit passed to `spawn_event_handling_tasks` (and, for TCP, to the
    /// `RequestLimiter`).
    fn max_frame_handling_tasks(&self) -> usize;
    // NOTE(review): the three key accessors below are not called in this file;
    // presumably implementors use them to place host / timestamp / source type
    // on the produced event — confirm against the implementors.
    fn host_key(&self) -> &Option<OwnedValuePath>;
    fn timestamp_key(&self) -> Option<&OwnedValuePath>;
    fn source_type_key(&self) -> Option<&OwnedValuePath>;
}
381
/// Extra configuration needed by `build_framestream_unix_source`.
pub trait UnixFrameHandler: FrameHandler {
    /// Filesystem path the Unix listener is bound to.
    fn socket_path(&self) -> PathBuf;
    /// Optional permission bits applied to the socket file; the builder only
    /// accepts values between `0o700` and `0o777`.
    fn socket_file_mode(&self) -> Option<u32>;
    /// Optional value for the listener's `RcvBuf` socket option.
    fn socket_receive_buffer_size(&self) -> Option<usize>;
    /// Optional value for the listener's `SndBuf` socket option.
    fn socket_send_buffer_size(&self) -> Option<usize>;
}
388
/// Extra configuration needed by `build_framestream_tcp_source`.
pub trait TcpFrameHandler: FrameHandler {
    /// Address the TCP listener is bound to.
    fn address(&self) -> SocketListenAddr;
    /// Optional keepalive configuration applied to each accepted socket.
    fn keepalive(&self) -> Option<TcpKeepaliveConfig>;
    /// How long after the shutdown signal still-open connections are given
    /// before they are reset.
    fn shutdown_timeout_secs(&self) -> Duration;
    /// TLS settings used when binding the listener.
    fn tls(&self) -> MaybeTlsSettings;
    // NOTE(review): not called in this file; presumably the event path under
    // which client certificate metadata is stored — confirm.
    fn tls_client_metadata_key(&self) -> Option<OwnedValuePath>;
    /// Optional receive buffer size applied to each accepted socket.
    fn receive_buffer_bytes(&self) -> Option<usize>;
    /// Optional lifetime, in seconds, after which a connection's read loop ends.
    fn max_connection_duration_secs(&self) -> Option<u64>;
    /// Optional cap passed to the listener's `accept_stream_limited`.
    fn max_connections(&self) -> Option<u32>;
    /// Optional list of networks passed to `try_bind_tcp_listener`.
    fn allowed_origins(&self) -> Option<&[IpNet]>;
    /// Called once per connection with the peer certificate metadata, or
    /// `None` when the connection has no peer certificate.
    fn insert_tls_client_metadata(&mut self, metadata: Option<CertificateMetadata>);
}
401
402/**
403 * Based off of the build_framestream_unix_source function.
404 * Functions similarly, just uses TCP socket instead of unix socket
405 **/
406pub fn build_framestream_tcp_source(
407    frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
408    shutdown: ShutdownSignal,
409    out: SourceSender,
410) -> crate::Result<Source> {
411    let addr = frame_handler.address();
412    let tls = frame_handler.tls();
413    let shutdown = shutdown.clone();
414    let out = out.clone();
415
416    Ok(Box::pin(async move {
417        let listenfd = ListenFd::from_env();
418        let listener = try_bind_tcp_listener(
419            addr,
420            listenfd,
421            &tls,
422            frame_handler
423                .allowed_origins()
424                .map(|origins| origins.to_vec()),
425        )
426        .await
427        .map_err(|error| {
428            emit!(SocketBindError {
429                mode: SocketMode::Tcp,
430                error: &error,
431            })
432        })?;
433
434        info!(
435            message = "Listening.",
436            addr = %listener
437                .local_addr()
438                .map(SocketListenAddr::SocketAddr)
439                .unwrap_or(addr)
440        );
441
442        let tripwire = shutdown.clone();
443        let shutdown_timeout_secs = frame_handler.shutdown_timeout_secs();
444        let tripwire = async move {
445            _ = tripwire.await;
446            sleep(shutdown_timeout_secs).await;
447        }
448        .shared();
449
450        let connection_gauge = OpenGauge::new();
451        let shutdown_clone = shutdown.clone();
452
453        let request_limiter = RequestLimiter::new(
454            MAX_IN_FLIGHT_EVENTS_TARGET,
455            frame_handler.max_frame_handling_tasks(),
456        );
457
458        listener
459            .accept_stream_limited(frame_handler.max_connections())
460            .take_until(shutdown_clone)
461            .for_each(move |(connection, tcp_connection_permit)| {
462                let shutdown_signal = shutdown.clone();
463                let tripwire = tripwire.clone();
464                let out = out.clone();
465                let connection_gauge = connection_gauge.clone();
466                let request_limiter = request_limiter.clone();
467                let frame_handler_clone = frame_handler.clone();
468
469                async move {
470                    let socket = match connection {
471                        Ok(socket) => socket,
472                        Err(error) => {
473                            emit!(SocketReceiveError {
474                                mode: SocketMode::Tcp,
475                                error: &error
476                            });
477                            return;
478                        }
479                    };
480
481                    let peer_addr = socket.peer_addr();
482                    let span = info_span!("connection", %peer_addr);
483
484                    let tripwire = tripwire
485                        .map(move |_| {
486                            info!(
487                                message = "Resetting connection (still open after seconds).",
488                                seconds = ?shutdown_timeout_secs
489                            );
490                        })
491                        .boxed();
492
493                    span.clone().in_scope(|| {
494                        debug!(message = "Accepted a new connection.", peer_addr = %peer_addr);
495
496                        let open_token =
497                            connection_gauge.open(|count| emit!(ConnectionOpen { count }));
498
499                        let fut = handle_stream(
500                            frame_handler_clone,
501                            shutdown_signal,
502                            socket,
503                            tripwire,
504                            peer_addr,
505                            out,
506                            request_limiter,
507                        );
508
509                        tokio::spawn(
510                            fut.map(move |()| {
511                                drop(open_token);
512                                drop(tcp_connection_permit);
513                            })
514                            .instrument(span.or_current()),
515                        );
516                    });
517                }
518            })
519            .map(Ok)
520            .await
521    }))
522}
523
524#[allow(clippy::too_many_arguments)]
525async fn handle_stream(
526    mut frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
527    mut shutdown_signal: ShutdownSignal,
528    mut socket: MaybeTlsIncomingStream<TcpStream>,
529    mut tripwire: BoxFuture<'static, ()>,
530    peer_addr: SocketAddr,
531    out: SourceSender,
532    request_limiter: RequestLimiter,
533) {
534    tokio::select! {
535        result = socket.handshake() => {
536            if let Err(error) = result {
537                emit!(TcpSocketTlsConnectionError { error });
538                return;
539            }
540        },
541        _ = &mut shutdown_signal => {
542            return;
543        }
544    };
545
546    if let Some(keepalive) = frame_handler.keepalive()
547        && let Err(error) = socket.set_keepalive(keepalive)
548    {
549        warn!(message = "Failed configuring TCP keepalive.", %error);
550    }
551
552    if let Some(receive_buffer_bytes) = frame_handler.receive_buffer_bytes()
553        && let Err(error) = socket.set_receive_buffer_bytes(receive_buffer_bytes)
554    {
555        warn!(message = "Failed configuring receive buffer size on TCP socket.", %error);
556    }
557
558    let socket = socket.after_read(move |byte_size| {
559        emit!(TcpBytesReceived {
560            byte_size,
561            peer_addr
562        });
563    });
564
565    let certificate_metadata = socket
566        .get_ref()
567        .ssl_stream()
568        .and_then(|stream| stream.ssl().peer_certificate())
569        .map(CertificateMetadata::from);
570
571    frame_handler.insert_tls_client_metadata(certificate_metadata);
572
573    let span = info_span!("connection");
574    span.record("peer_addr", field::debug(&peer_addr));
575    let received_from: Option<Bytes> = Some(peer_addr.to_string().into());
576
577    let connection_close_timeout = OptionFuture::from(
578        frame_handler
579            .max_connection_duration_secs()
580            .map(|timeout_secs| tokio::time::sleep(Duration::from_secs(timeout_secs))),
581    );
582    tokio::pin!(connection_close_timeout);
583
584    let content_type = frame_handler.content_type();
585    let mut event_sink = out.clone();
586    let (sock_sink, sock_stream) = Framed::new(
587        socket,
588        length_delimited::Builder::new()
589            .max_frame_length(frame_handler.max_frame_length())
590            .new_codec(),
591    )
592    .split();
593    let mut reader = FrameStreamReader::new(Box::new(sock_sink), content_type);
594    let mut frames = sock_stream
595        .map_err(move |error| {
596            emit!(TcpSocketError {
597                error: &error,
598                peer_addr,
599            });
600        })
601        .filter_map(move |frame| {
602            future::ready(match frame {
603                Ok(f) => reader.handle_frame(Bytes::from(f)),
604                Err(_) => None,
605            })
606        });
607
608    let active_parsing_task_nums = Arc::new(AtomicUsize::new(0));
609    loop {
610        let mut permit = tokio::select! {
611            _ = &mut tripwire => break,
612            Some(_) = &mut connection_close_timeout  => {
613                break;
614            },
615            _ = &mut shutdown_signal => {
616                break;
617            },
618            permit = request_limiter.acquire() => {
619                Some(permit)
620            }
621            else => break,
622        };
623
624        let timeout = tokio::time::sleep(Duration::from_millis(PERMIT_HOLD_TIMEOUT_MS));
625        tokio::pin!(timeout);
626
627        tokio::select! {
628            _ = &mut tripwire => break,
629            _ = &mut shutdown_signal => break,
630            _ = &mut timeout => {
631                // This connection is currently holding a permit, but has not received data for some time. Release
632                // the permit to let another connection try
633                continue;
634            }
635            res = frames.next() => {
636                match res {
637                    Some(frame) => {
638                        if let Some(permit) = &mut permit {
639                            // Note that this is intentionally not the "number of events in a single request", but rather
640                            // the "number of events currently available". This may contain events from multiple events,
641                            // but it should always contain all events from each request.
642                            permit.decoding_finished(1);
643                        };
644                        handle_tcp_frame(&mut frame_handler, frame, &mut event_sink, received_from.clone(), Arc::clone(&active_parsing_task_nums)).await;
645                    }
646                    None => {
647                        debug!("Connection closed.");
648                        break
649                    },
650                }
651            }
652            else => break,
653        }
654
655        drop(permit);
656    }
657}
658
659async fn handle_tcp_frame<T>(
660    frame_handler: &mut T,
661    frame: Bytes,
662    event_sink: &mut SourceSender,
663    received_from: Option<Bytes>,
664    active_parsing_task_nums: Arc<AtomicUsize>,
665) where
666    T: TcpFrameHandler + Send + Sync + Clone + 'static,
667{
668    if frame_handler.multithreaded() {
669        spawn_event_handling_tasks(
670            frame,
671            frame_handler.clone(),
672            event_sink.clone(),
673            received_from,
674            active_parsing_task_nums,
675            frame_handler.max_frame_handling_tasks(),
676        )
677        .await;
678    } else if let Some(event) = frame_handler.handle_event(received_from, frame)
679        && let Err(e) = event_sink.send_event(event).await
680    {
681        error!("Error sending event: {e:?}.");
682    }
683}
684
685/**
686 * Based off of the build_unix_source function.
687 * Functions similarly, but uses the FrameStreamReader to deal with
688 * framestream control packets, and responds appropriately.
689 **/
690pub fn build_framestream_unix_source(
691    frame_handler: impl UnixFrameHandler + Send + Sync + Clone + 'static,
692    shutdown: ShutdownSignal,
693    out: SourceSender,
694) -> crate::Result<Source> {
695    let path = frame_handler.socket_path();
696
697    //check if the path already exists (and try to delete it)
698    match fs::metadata(&path) {
699        Ok(_) => {
700            //exists, so try to delete it
701            info!(message = "Deleting file.", ?path);
702            fs::remove_file(&path)?;
703        }
704        Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {} //doesn't exist, do nothing
705        Err(e) => {
706            error!("Unable to get socket information; error = {:?}.", e);
707            return Err(Box::new(e));
708        }
709    };
710
711    let listener = UnixListener::bind(&path)?;
712
713    // system's 'net.core.rmem_max' might have to be changed if socket receive buffer is not updated properly
714    if let Some(socket_receive_buffer_size) = frame_handler.socket_receive_buffer_size() {
715        _ = nix::sys::socket::setsockopt(
716            listener.as_raw_fd(),
717            nix::sys::socket::sockopt::RcvBuf,
718            &(socket_receive_buffer_size),
719        );
720        let rcv_buf_size =
721            nix::sys::socket::getsockopt(listener.as_raw_fd(), nix::sys::socket::sockopt::RcvBuf);
722        info!(
723            "Unix socket receive buffer size modified to {}.",
724            rcv_buf_size.unwrap()
725        );
726    }
727
728    // system's 'net.core.wmem_max' might have to be changed if socket send buffer is not updated properly
729    if let Some(socket_send_buffer_size) = frame_handler.socket_send_buffer_size() {
730        _ = nix::sys::socket::setsockopt(
731            listener.as_raw_fd(),
732            nix::sys::socket::sockopt::SndBuf,
733            &(socket_send_buffer_size),
734        );
735        let snd_buf_size =
736            nix::sys::socket::getsockopt(listener.as_raw_fd(), nix::sys::socket::sockopt::SndBuf);
737        info!(
738            "Unix socket buffer send size modified to {}.",
739            snd_buf_size.unwrap()
740        );
741    }
742
743    // the permissions to unix socket are restricted from 0o700 to 0o777, which are 448 and 511 in decimal
744    if let Some(socket_permission) = frame_handler.socket_file_mode() {
745        if !(448..=511).contains(&socket_permission) {
746            return Err(format!(
747                "Invalid Socket permission {socket_permission:#o}. Must between 0o700 and 0o777."
748            )
749            .into());
750        }
751        match fs::set_permissions(&path, fs::Permissions::from_mode(socket_permission)) {
752            Ok(_) => {
753                info!("Socket permissions updated to {:#o}.", socket_permission);
754            }
755            Err(e) => {
756                error!(
757                    "Failed to update listener socket permissions; error = {:?}.",
758                    e
759                );
760                return Err(Box::new(e));
761            }
762        };
763    };
764
765    let fut = async move {
766        let active_parsing_task_nums = Arc::new(AtomicUsize::new(0));
767
768        info!(message = "Listening...", ?path, r#type = "unix");
769
770        let mut stream = UnixListenerStream::new(listener).take_until(shutdown.clone());
771        while let Some(socket) = stream.next().await {
772            let socket = match socket {
773                Err(e) => {
774                    error!("Failed to accept socket; error = {:?}.", e);
775                    continue;
776                }
777                Ok(s) => s,
778            };
779            let peer_addr = socket.peer_addr().ok();
780            let listen_path = path.clone();
781            let active_task_nums_ = Arc::clone(&active_parsing_task_nums);
782
783            let span = info_span!("connection");
784            let path = if let Some(addr) = peer_addr {
785                if let Some(path) = addr.as_pathname().map(|e| e.to_owned()) {
786                    span.record("peer_path", field::debug(&path));
787                    Some(path)
788                } else {
789                    None
790                }
791            } else {
792                None
793            };
794            let received_from: Option<Bytes> =
795                path.map(|p| p.to_string_lossy().into_owned().into());
796
797            build_framestream_source(
798                frame_handler.clone(),
799                socket,
800                received_from,
801                out.clone(),
802                shutdown.clone(),
803                span,
804                active_task_nums_,
805                move |error| {
806                    emit!(UnixSocketError {
807                        error: &error,
808                        path: &listen_path,
809                    });
810                },
811            );
812        }
813
814        // Cleanup
815        drop(stream);
816
817        // Delete socket file
818        if let Err(error) = fs::remove_file(&path) {
819            emit!(UnixSocketFileDeleteError { path: &path, error });
820        }
821
822        Ok(())
823    };
824
825    Ok(Box::pin(fut))
826}
827
828#[allow(clippy::too_many_arguments)]
829fn build_framestream_source<T: Send + 'static>(
830    frame_handler: impl FrameHandler + Send + Sync + Clone + 'static,
831    socket: impl AsyncRead + AsyncWrite + Send + 'static,
832    received_from: Option<Bytes>,
833    out: SourceSender,
834    shutdown: impl Future<Output = T> + Unpin + Send + 'static,
835    span: Span,
836    active_task_nums: Arc<AtomicUsize>,
837    error_mapper: impl FnMut(std::io::Error) + Send + 'static,
838) {
839    let content_type = frame_handler.content_type();
840    let mut event_sink = out.clone();
841    let (sock_sink, sock_stream) = Framed::new(
842        socket,
843        length_delimited::Builder::new()
844            .max_frame_length(frame_handler.max_frame_length())
845            .new_codec(),
846    )
847    .split();
848    let mut fs_reader = FrameStreamReader::new(Box::new(sock_sink), content_type);
849    let frame_handler_copy = frame_handler.clone();
850    let frames = sock_stream
851        .take_until(shutdown)
852        .map_err(error_mapper)
853        .filter_map(move |frame| {
854            future::ready(match frame {
855                Ok(f) => fs_reader.handle_frame(Bytes::from(f)),
856                Err(_) => None,
857            })
858        });
859    if !frame_handler.multithreaded() {
860        let mut events = frames.filter_map(move |f| {
861            future::ready(frame_handler_copy.handle_event(received_from.clone(), f))
862        });
863
864        let handler = async move {
865            if let Err(e) = event_sink.send_event_stream(&mut events).await {
866                error!("Error sending event: {:?}.", e);
867            }
868
869            info!("Finished sending.");
870        };
871        tokio::spawn(handler.instrument(span.or_current()));
872    } else {
873        let handler = async move {
874            frames
875                .for_each(move |f| {
876                    let max_frame_handling_tasks = frame_handler_copy.max_frame_handling_tasks();
877                    let f_handler = frame_handler_copy.clone();
878                    let received_from_copy = received_from.clone();
879                    let event_sink_copy = event_sink.clone();
880                    let active_task_nums_copy = Arc::clone(&active_task_nums);
881
882                    async move {
883                        spawn_event_handling_tasks(
884                            f,
885                            f_handler,
886                            event_sink_copy,
887                            received_from_copy,
888                            active_task_nums_copy,
889                            max_frame_handling_tasks,
890                        )
891                        .await;
892                    }
893                })
894                .await;
895            info!("Finished sending.");
896        };
897        tokio::spawn(handler.instrument(span.or_current()));
898    }
899}
900
901async fn spawn_event_handling_tasks(
902    event_data: Bytes,
903    event_handler: impl FrameHandler + Send + Sync + 'static,
904    mut event_sink: SourceSender,
905    received_from: Option<Bytes>,
906    active_task_nums: Arc<AtomicUsize>,
907    max_frame_handling_tasks: usize,
908) -> JoinHandle<()> {
909    wait_for_task_quota(&active_task_nums, max_frame_handling_tasks).await;
910
911    tokio::spawn(async move {
912        future::ready({
913            if let Some(evt) = event_handler.handle_event(received_from, event_data)
914                && event_sink.send_event(evt).await.is_err()
915            {
916                error!("Encountered error while sending event.");
917            }
918            active_task_nums.fetch_sub(1, Ordering::AcqRel);
919        })
920        .await;
921    })
922}
923
924async fn wait_for_task_quota(active_task_nums: &Arc<AtomicUsize>, max_tasks: usize) {
925    while max_tasks > 0 && max_tasks < active_task_nums.load(Ordering::Acquire) {
926        tokio::time::sleep(Duration::from_millis(3)).await;
927    }
928    active_task_nums.fetch_add(1, Ordering::AcqRel);
929}
930
931#[cfg(test)]
932mod test {
933    use std::net::SocketAddr;
934    #[cfg(unix)]
935    use std::{
936        path::PathBuf,
937        sync::{
938            Arc,
939            atomic::{AtomicUsize, Ordering},
940        },
941        thread,
942    };
943
944    use bytes::{Bytes, BytesMut, buf::Buf};
945    use futures::{
946        future,
947        sink::{Sink, SinkExt},
948        stream::{self, StreamExt},
949    };
950    use futures_util::Stream;
951    use ipnet::IpNet;
952    use tokio::{
953        self,
954        net::{TcpStream, UnixStream},
955        task::JoinHandle,
956        time::{Duration, Instant},
957    };
958    use tokio_util::codec::{Framed, length_delimited};
959    use vector_lib::{
960        config::{LegacyKey, LogNamespace},
961        lookup::{OwnedValuePath, owned_value_path, path},
962        tcp::TcpKeepaliveConfig,
963        tls::{CertificateMetadata, MaybeTls, MaybeTlsSettings},
964    };
965
966    use super::{
967        ControlField, ControlHeader, FrameHandler, TcpFrameHandler, UnixFrameHandler,
968        build_framestream_tcp_source, build_framestream_unix_source, spawn_event_handling_tasks,
969    };
970    use crate::{
971        SourceSender,
972        config::{ComponentKey, log_schema},
973        event::{Event, LogEvent},
974        shutdown::SourceShutdownCoordinator,
975        sources::util::net::SocketListenAddr,
976        test_util::{collect_n, collect_n_stream, next_addr},
977    };
978
979    #[derive(Clone)]
980    struct MockFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
981        content_type: String,
982        max_frame_length: usize,
983        multithreaded: bool,
984        max_frame_handling_tasks: usize,
985        extra_task_handling_routine: F,
986        host_key: Option<OwnedValuePath>,
987        timestamp_key: Option<OwnedValuePath>,
988        source_type_key: Option<OwnedValuePath>,
989        log_namespace: LogNamespace,
990    }
991
992    #[derive(Clone)]
993    struct MockUnixFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
994        frame_handler: MockFrameHandler<F>,
995        socket_path: PathBuf,
996        socket_file_mode: Option<u32>,
997        socket_receive_buffer_size: Option<usize>,
998        socket_send_buffer_size: Option<usize>,
999    }
1000
1001    #[derive(Clone)]
1002    struct MockTcpFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
1003        frame_handler: MockFrameHandler<F>,
1004        address: SocketListenAddr,
1005        keepalive: Option<TcpKeepaliveConfig>,
1006        shutdown_timeout_secs: Duration,
1007        tls: MaybeTlsSettings,
1008        tls_client_metadata_key: Option<OwnedValuePath>,
1009        receive_buffer_bytes: Option<usize>,
1010        max_connection_duration_secs: Option<u64>,
1011        max_connections: Option<u32>,
1012        permit_origin: Option<Vec<IpNet>>,
1013    }
1014
1015    impl<F: Send + Sync + Clone + FnOnce() + 'static> MockTcpFrameHandler<F> {
1016        pub fn new(
1017            addr: SocketAddr,
1018            content_type: String,
1019            multithreaded: bool,
1020            extra_routine: F,
1021            permit_origin: Option<Vec<IpNet>>,
1022        ) -> Self {
1023            Self {
1024                frame_handler: MockFrameHandler::new(content_type, multithreaded, extra_routine),
1025                address: addr.into(),
1026                keepalive: None,
1027                shutdown_timeout_secs: Duration::from_secs(30),
1028                tls: MaybeTls::Raw(()),
1029                tls_client_metadata_key: None,
1030                receive_buffer_bytes: None,
1031                max_connection_duration_secs: None,
1032                max_connections: None,
1033                permit_origin,
1034            }
1035        }
1036    }
1037
1038    impl<F: Send + Sync + Clone + FnOnce() + 'static> MockUnixFrameHandler<F> {
1039        pub fn new(content_type: String, multithreaded: bool, extra_routine: F) -> Self {
1040            Self {
1041                frame_handler: MockFrameHandler::new(content_type, multithreaded, extra_routine),
1042                socket_path: tempfile::tempdir().unwrap().keep().join("unix_test"),
1043                socket_file_mode: None,
1044                socket_receive_buffer_size: None,
1045                socket_send_buffer_size: None,
1046            }
1047        }
1048    }
1049
1050    impl<F: Send + Sync + Clone + FnOnce() + 'static> MockFrameHandler<F> {
1051        pub fn new(content_type: String, multithreaded: bool, extra_routine: F) -> Self {
1052            Self {
1053                content_type,
1054                max_frame_length: bytesize::kib(100u64) as usize,
1055                multithreaded,
1056                max_frame_handling_tasks: 0,
1057                extra_task_handling_routine: extra_routine,
1058                host_key: Some(owned_value_path!("test_framestream")),
1059                timestamp_key: Some(owned_value_path!("my_timestamp")),
1060                source_type_key: Some(owned_value_path!("source_type")),
1061                log_namespace: LogNamespace::Legacy,
1062            }
1063        }
1064    }
1065
1066    impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockFrameHandler<F> {
1067        fn content_type(&self) -> String {
1068            self.content_type.clone()
1069        }
1070        fn max_frame_length(&self) -> usize {
1071            self.max_frame_length
1072        }
1073
1074        fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1075            let mut log_event = LogEvent::from(frame);
1076
1077            log_event.insert(
1078                log_schema().source_type_key_target_path().unwrap(),
1079                "framestream",
1080            );
1081            if let Some(host) = received_from {
1082                self.log_namespace.insert_source_metadata(
1083                    "framestream",
1084                    &mut log_event,
1085                    self.host_key.as_ref().map(LegacyKey::Overwrite),
1086                    path!("host"),
1087                    host,
1088                )
1089            }
1090
1091            (self.extra_task_handling_routine.clone())();
1092
1093            Some(log_event.into())
1094        }
1095
1096        fn multithreaded(&self) -> bool {
1097            self.multithreaded
1098        }
1099        fn max_frame_handling_tasks(&self) -> usize {
1100            self.max_frame_handling_tasks
1101        }
1102
1103        fn host_key(&self) -> &Option<OwnedValuePath> {
1104            &self.host_key
1105        }
1106
1107        fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1108            self.timestamp_key.as_ref()
1109        }
1110
1111        fn source_type_key(&self) -> Option<&OwnedValuePath> {
1112            self.source_type_key.as_ref()
1113        }
1114    }
1115
1116    impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockUnixFrameHandler<F> {
1117        fn content_type(&self) -> String {
1118            self.frame_handler.content_type()
1119        }
1120
1121        fn max_frame_length(&self) -> usize {
1122            self.frame_handler.max_frame_length()
1123        }
1124
1125        fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1126            self.frame_handler.handle_event(received_from, frame)
1127        }
1128
1129        fn multithreaded(&self) -> bool {
1130            self.frame_handler.multithreaded()
1131        }
1132
1133        fn max_frame_handling_tasks(&self) -> usize {
1134            self.frame_handler.max_frame_handling_tasks()
1135        }
1136
1137        fn host_key(&self) -> &Option<OwnedValuePath> {
1138            self.frame_handler.host_key()
1139        }
1140
1141        fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1142            self.frame_handler.timestamp_key()
1143        }
1144
1145        fn source_type_key(&self) -> Option<&OwnedValuePath> {
1146            self.frame_handler.source_type_key()
1147        }
1148    }
1149
1150    impl<F: Send + Sync + Clone + FnOnce() + 'static> UnixFrameHandler for MockUnixFrameHandler<F> {
1151        fn socket_path(&self) -> PathBuf {
1152            self.socket_path.clone()
1153        }
1154
1155        fn socket_file_mode(&self) -> Option<u32> {
1156            self.socket_file_mode
1157        }
1158
1159        fn socket_receive_buffer_size(&self) -> Option<usize> {
1160            self.socket_receive_buffer_size
1161        }
1162
1163        fn socket_send_buffer_size(&self) -> Option<usize> {
1164            self.socket_send_buffer_size
1165        }
1166    }
1167
1168    impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockTcpFrameHandler<F> {
1169        fn content_type(&self) -> String {
1170            self.frame_handler.content_type()
1171        }
1172
1173        fn max_frame_length(&self) -> usize {
1174            self.frame_handler.max_frame_length()
1175        }
1176
1177        fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1178            self.frame_handler.handle_event(received_from, frame)
1179        }
1180
1181        fn multithreaded(&self) -> bool {
1182            self.frame_handler.multithreaded()
1183        }
1184
1185        fn max_frame_handling_tasks(&self) -> usize {
1186            self.frame_handler.max_frame_handling_tasks()
1187        }
1188
1189        fn host_key(&self) -> &Option<OwnedValuePath> {
1190            self.frame_handler.host_key()
1191        }
1192
1193        fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1194            self.frame_handler.timestamp_key()
1195        }
1196
1197        fn source_type_key(&self) -> Option<&OwnedValuePath> {
1198            self.frame_handler.source_type_key()
1199        }
1200    }
1201
1202    impl<F: Send + Sync + Clone + FnOnce() + 'static> TcpFrameHandler for MockTcpFrameHandler<F> {
1203        fn address(&self) -> SocketListenAddr {
1204            self.address
1205        }
1206
1207        fn keepalive(&self) -> Option<TcpKeepaliveConfig> {
1208            self.keepalive
1209        }
1210
1211        fn shutdown_timeout_secs(&self) -> Duration {
1212            self.shutdown_timeout_secs
1213        }
1214
1215        fn tls(&self) -> MaybeTlsSettings {
1216            self.tls.clone()
1217        }
1218
1219        fn tls_client_metadata_key(&self) -> Option<OwnedValuePath> {
1220            self.tls_client_metadata_key.clone()
1221        }
1222
1223        fn receive_buffer_bytes(&self) -> Option<usize> {
1224            self.receive_buffer_bytes
1225        }
1226
1227        fn max_connection_duration_secs(&self) -> Option<u64> {
1228            self.max_connection_duration_secs
1229        }
1230
1231        fn max_connections(&self) -> Option<u32> {
1232            self.max_connections
1233        }
1234
1235        fn insert_tls_client_metadata(&mut self, _: Option<CertificateMetadata>) {}
1236
1237        fn allowed_origins(&self) -> Option<&[IpNet]> {
1238            self.permit_origin.as_deref()
1239        }
1240    }
1241
1242    fn init_framestream_tcp(
1243        source_id: &str,
1244        addr: &SocketAddr,
1245        frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
1246        pipeline: SourceSender,
1247    ) -> (JoinHandle<Result<(), ()>>, SourceShutdownCoordinator) {
1248        let source_id = ComponentKey::from(source_id);
1249        let mut shutdown = SourceShutdownCoordinator::default();
1250        let (shutdown_signal, _) = shutdown.register_source(&source_id, false);
1251        let server = build_framestream_tcp_source(frame_handler, shutdown_signal, pipeline)
1252            .expect("Failed to build framestream tcp source.");
1253
1254        let join_handle = tokio::spawn(server);
1255
1256        while std::net::TcpStream::connect(addr).is_err() {
1257            thread::sleep(Duration::from_millis(2));
1258        }
1259
1260        (join_handle, shutdown)
1261    }
1262
1263    fn init_framestream_unix(
1264        source_id: &str,
1265        frame_handler: impl UnixFrameHandler + Send + Sync + Clone + 'static,
1266        pipeline: SourceSender,
1267    ) -> (
1268        PathBuf,
1269        JoinHandle<Result<(), ()>>,
1270        SourceShutdownCoordinator,
1271    ) {
1272        let source_id = ComponentKey::from(source_id);
1273        let socket_path = frame_handler.socket_path();
1274        let mut shutdown = SourceShutdownCoordinator::default();
1275        let (shutdown_signal, _) = shutdown.register_source(&source_id, false);
1276        let server = build_framestream_unix_source(frame_handler, shutdown_signal, pipeline)
1277            .expect("Failed to build framestream unix source.");
1278
1279        let join_handle = tokio::spawn(server);
1280
1281        // Wait for server to accept traffic
1282        while std::os::unix::net::UnixStream::connect(&socket_path).is_err() {
1283            thread::sleep(Duration::from_millis(2));
1284        }
1285
1286        (socket_path, join_handle, shutdown)
1287    }
1288
1289    async fn make_tcp_stream(
1290        addr: SocketAddr,
1291    ) -> Framed<TcpStream, length_delimited::LengthDelimitedCodec> {
1292        let socket = TcpStream::connect(&addr).await.unwrap();
1293        Framed::new(socket, length_delimited::Builder::new().new_codec())
1294    }
1295
1296    async fn make_unix_stream(
1297        path: PathBuf,
1298    ) -> Framed<UnixStream, length_delimited::LengthDelimitedCodec> {
1299        let socket = UnixStream::connect(&path).await.unwrap();
1300        Framed::new(socket, length_delimited::Builder::new().new_codec())
1301    }
1302
1303    async fn send_data_frames<S: Sink<Bytes, Error = std::io::Error> + Unpin>(
1304        sock_sink: &mut S,
1305        frames: Vec<Result<Bytes, std::io::Error>>,
1306    ) {
1307        let mut stream = stream::iter(frames.into_iter());
1308        //send and send_all consume the sink
1309        _ = sock_sink.send_all(&mut stream).await;
1310    }
1311
1312    async fn send_control_frame<S: Sink<Bytes, Error = std::io::Error> + Unpin>(
1313        sock_sink: &mut S,
1314        frame: Bytes,
1315    ) {
1316        send_data_frames(sock_sink, vec![Ok(Bytes::new()), Ok(frame)]).await; //send empty frame to say we are control frame
1317    }
1318
1319    fn create_control_frame(header: ControlHeader) -> Bytes {
1320        Bytes::from(header.to_u32().to_be_bytes().to_vec())
1321    }
1322
1323    fn create_control_frame_with_content(
1324        header: ControlHeader,
1325        content_types: Vec<Bytes>,
1326    ) -> Bytes {
1327        let mut frame = BytesMut::from(&header.to_u32().to_be_bytes()[..]);
1328        for content_type in content_types {
1329            frame.extend(ControlField::ContentType.to_u32().to_be_bytes());
1330            frame.extend((content_type.len() as u32).to_be_bytes());
1331            frame.extend(content_type.clone());
1332        }
1333        Bytes::from(frame)
1334    }
1335
1336    fn assert_accept_frame(frame: &mut BytesMut, expected_content_type: Bytes) {
1337        //frame should start with 4 bytes saying ACCEPT
1338
1339        assert_eq!(&frame[..4], &ControlHeader::Accept.to_u32().to_be_bytes(),);
1340        frame.advance(4);
1341        //next should be content type field
1342        assert_eq!(
1343            &frame[..4],
1344            &ControlField::ContentType.to_u32().to_be_bytes(),
1345        );
1346        frame.advance(4);
1347        //next should be length of content_type
1348        assert_eq!(
1349            &frame[..4],
1350            &(expected_content_type.len() as u32).to_be_bytes(),
1351        );
1352        frame.advance(4);
1353        //rest should be content type
1354        assert_eq!(&frame[..], &expected_content_type[..]);
1355    }
1356
1357    fn create_frame_handler(multithreaded: bool) -> impl UnixFrameHandler + Send + Sync + Clone {
1358        MockUnixFrameHandler::new("test_content".to_string(), multithreaded, move || {})
1359    }
1360
1361    fn create_tcp_frame_handler(
1362        addr: SocketAddr,
1363        multithreaded: bool,
1364        permit_origin: Option<Vec<IpNet>>,
1365    ) -> impl TcpFrameHandler + Send + Sync + Clone {
1366        MockTcpFrameHandler::new(
1367            addr,
1368            "test_content".to_string(),
1369            multithreaded,
1370            move || {},
1371            permit_origin,
1372        )
1373    }
1374
1375    async fn signal_shutdown(source_name: &str, shutdown: &mut SourceShutdownCoordinator) {
1376        // Now signal to the Source to shut down.
1377        let deadline = Instant::now() + Duration::from_secs(10);
1378        let id = ComponentKey::from(source_name);
1379        let shutdown_complete = shutdown.shutdown_source(&id, deadline);
1380        let shutdown_success = shutdown_complete.await;
1381        assert!(shutdown_success);
1382    }
1383
1384    async fn test_normal_framestream<
1385        T: Sink<Bytes, Error = std::io::Error> + Unpin,
1386        U: Stream<Item = Result<BytesMut, std::io::Error>> + Unpin,
1387        V: Stream<Item = Event> + Unpin,
1388    >(
1389        source_name: &str,
1390        mut sock_sink: T,
1391        mut sock_stream: U,
1392        rx: V,
1393        mut shutdown: SourceShutdownCoordinator,
1394        source_handle: JoinHandle<Result<(), ()>>,
1395    ) {
1396        //1 - send READY frame (with content_type)
1397        let content_type = Bytes::from(&b"test_content"[..]);
1398        let ready_msg =
1399            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
1400        send_control_frame(&mut sock_sink, ready_msg).await;
1401
1402        //2 - wait for ACCEPT frame
1403        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1404        //take second element, because first will be empty (signifying control frame)
1405        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1406        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1407
1408        //3 - send START frame
1409        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Start)).await;
1410
1411        //4 - send data
1412        send_data_frames(
1413            &mut sock_sink,
1414            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
1415        )
1416        .await;
1417        let events = collect_n(rx, 2).await;
1418
1419        //5 - send STOP frame
1420        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;
1421
1422        let message_key = log_schema().message_key().unwrap().to_string();
1423        assert!(
1424            events
1425                .iter()
1426                .any(|e| e.as_log()[&message_key] == "hello".into())
1427        );
1428        assert!(
1429            events
1430                .iter()
1431                .any(|e| e.as_log()[&message_key] == "world".into())
1432        );
1433
1434        drop(sock_stream); //explicitly drop the stream so we don't get warnings about not using it
1435
1436        // Ensure source actually shut down successfully.
1437        signal_shutdown(source_name, &mut shutdown).await;
1438        _ = source_handle.await.unwrap();
1439    }
1440
1441    async fn test_multiple_content_types<
1442        T: Sink<Bytes, Error = std::io::Error> + Unpin,
1443        U: Stream<Item = Result<BytesMut, std::io::Error>> + Unpin,
1444    >(
1445        source_name: &str,
1446        mut sock_sink: T,
1447        mut sock_stream: U,
1448        mut shutdown: SourceShutdownCoordinator,
1449        source_handle: JoinHandle<Result<(), ()>>,
1450    ) {
1451        //1 - send READY frame (with content_type)
1452        let content_type = Bytes::from(&b"test_content"[..]);
1453        let ready_msg = create_control_frame_with_content(
1454            ControlHeader::Ready,
1455            vec![Bytes::from(&b"test_content2"[..]), content_type.clone()],
1456        ); //can have multiple content types
1457        send_control_frame(&mut sock_sink, ready_msg).await;
1458
1459        //2 - wait for ACCEPT frame
1460        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1461
1462        //take second element, because first will be empty (signifying control frame)
1463        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1464        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1465
1466        drop(sock_stream); //explicitly drop the stream so we don't get warnings about not using it
1467
1468        // Ensure source actually shut down successfully.
1469        signal_shutdown(source_name, &mut shutdown).await;
1470        _ = source_handle.await.unwrap();
1471    }
1472
1473    #[tokio::test(flavor = "multi_thread")]
1474    #[should_panic]
1475    async fn blocked_framestream_tcp() {
1476        let source_name = "test_source";
1477        let (tx, rx) = SourceSender::new_test();
1478        let addr = next_addr();
1479        let (source_handle, shutdown) = init_framestream_tcp(
1480            source_name,
1481            &addr,
1482            create_tcp_frame_handler(addr, false, Some(vec![])),
1483            tx,
1484        );
1485        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1486
1487        test_normal_framestream(
1488            source_name,
1489            sock_sink,
1490            sock_stream,
1491            rx,
1492            shutdown,
1493            source_handle,
1494        )
1495        .await;
1496    }
1497
1498    #[tokio::test(flavor = "multi_thread")]
1499    async fn normal_framestream_singlethreaded_tcp() {
1500        let source_name = "test_source";
1501        let (tx, rx) = SourceSender::new_test();
1502        let addr = next_addr();
1503        let (source_handle, shutdown) = init_framestream_tcp(
1504            source_name,
1505            &addr,
1506            create_tcp_frame_handler(addr, false, None),
1507            tx,
1508        );
1509        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1510
1511        test_normal_framestream(
1512            source_name,
1513            sock_sink,
1514            sock_stream,
1515            rx,
1516            shutdown,
1517            source_handle,
1518        )
1519        .await;
1520    }
1521
1522    #[tokio::test(flavor = "multi_thread")]
1523    async fn normal_framestream_singlethreaded_unix() {
1524        let source_name = "test_source";
1525        let (tx, rx) = SourceSender::new_test();
1526        let (path, source_handle, shutdown) =
1527            init_framestream_unix(source_name, create_frame_handler(false), tx);
1528        let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1529
1530        test_normal_framestream(
1531            source_name,
1532            sock_sink,
1533            sock_stream,
1534            rx,
1535            shutdown,
1536            source_handle,
1537        )
1538        .await;
1539    }
1540
1541    #[tokio::test(flavor = "multi_thread")]
1542    async fn normal_framestream_multithreaded_tcp() {
1543        let source_name = "test_source";
1544        let (tx, rx) = SourceSender::new_test();
1545        let addr = next_addr();
1546        let (source_handle, shutdown) = init_framestream_tcp(
1547            source_name,
1548            &addr,
1549            create_tcp_frame_handler(addr, true, None),
1550            tx,
1551        );
1552        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1553
1554        test_normal_framestream(
1555            source_name,
1556            sock_sink,
1557            sock_stream,
1558            rx,
1559            shutdown,
1560            source_handle,
1561        )
1562        .await;
1563    }
1564
1565    #[tokio::test(flavor = "multi_thread")]
1566    async fn normal_framestream_multithreaded_unix() {
1567        let source_name = "test_source";
1568        let (tx, rx) = SourceSender::new_test();
1569        let (path, source_handle, shutdown) =
1570            init_framestream_unix(source_name, create_frame_handler(true), tx);
1571        let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1572
1573        test_normal_framestream(
1574            source_name,
1575            sock_sink,
1576            sock_stream,
1577            rx,
1578            shutdown,
1579            source_handle,
1580        )
1581        .await;
1582    }
1583
1584    #[tokio::test(flavor = "multi_thread")]
1585    async fn multiple_content_types_tcp() {
1586        let source_name = "test_source";
1587        let (tx, _) = SourceSender::new_test();
1588        let addr = next_addr();
1589        let (source_handle, shutdown) = init_framestream_tcp(
1590            source_name,
1591            &addr,
1592            create_tcp_frame_handler(addr, false, None),
1593            tx,
1594        );
1595        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1596
1597        test_multiple_content_types(source_name, sock_sink, sock_stream, shutdown, source_handle)
1598            .await;
1599    }
1600
1601    #[tokio::test(flavor = "multi_thread")]
1602    async fn multiple_content_types_unix() {
1603        let source_name = "test_source";
1604        let (tx, _) = SourceSender::new_test();
1605        let (path, source_handle, shutdown) =
1606            init_framestream_unix(source_name, create_frame_handler(false), tx);
1607        let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1608
1609        test_multiple_content_types(source_name, sock_sink, sock_stream, shutdown, source_handle)
1610            .await;
1611    }
1612
1613    #[tokio::test(flavor = "multi_thread")]
1614    async fn wrong_content_type() {
1615        let source_name = "test_source";
1616        let (tx, _) = SourceSender::new_test();
1617        let (path, source_handle, mut shutdown) =
1618            init_framestream_unix(source_name, create_frame_handler(false), tx);
1619        let (mut sock_sink, mut sock_stream) = make_unix_stream(path).await.split();
1620
1621        //1 - send READY frame (with WRONG content_type)
1622        let ready_msg = create_control_frame_with_content(
1623            ControlHeader::Ready,
1624            vec![Bytes::from(&b"test_content2"[..])],
1625        ); //can have multiple content types
1626        send_control_frame(&mut sock_sink, ready_msg).await;
1627
1628        //2 - send READY frame (with RIGHT content_type)
1629        let content_type = Bytes::from(&b"test_content"[..]);
1630        let ready_msg =
1631            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
1632        send_control_frame(&mut sock_sink, ready_msg).await;
1633
1634        //3 - wait for ACCEPT frame
1635        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1636
1637        //take second element, because first will be empty (signifying control frame)
1638        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1639        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1640
1641        drop(sock_stream); //explicitly drop the stream so we don't get warnings about not using it
1642
1643        // Ensure source actually shut down successfully.
1644        signal_shutdown(source_name, &mut shutdown).await;
1645        _ = source_handle.await.unwrap();
1646    }
1647
1648    #[tokio::test(flavor = "multi_thread")]
1649    async fn data_too_soon() {
1650        let source_name = "test_source";
1651        let (tx, rx) = SourceSender::new_test();
1652        let (path, source_handle, mut shutdown) =
1653            init_framestream_unix(source_name, create_frame_handler(false), tx);
1654        let (mut sock_sink, mut sock_stream) = make_unix_stream(path).await.split();
1655
1656        //1 - send data frame (too soon!)
1657        send_data_frames(
1658            &mut sock_sink,
1659            vec![Ok(Bytes::from("bad")), Ok(Bytes::from("data"))],
1660        )
1661        .await;
1662
1663        //2 - send READY frame (with content_type)
1664        let content_type = Bytes::from(&b"test_content"[..]);
1665        let ready_msg =
1666            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
1667        send_control_frame(&mut sock_sink, ready_msg).await;
1668
1669        //3 - wait for ACCEPT frame
1670        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1671
1672        //take second element, because first will be empty (signifying control frame)
1673        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1674        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1675
1676        //4 - send START frame
1677        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Start)).await;
1678
1679        //5 - send data (will go through)
1680        send_data_frames(
1681            &mut sock_sink,
1682            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
1683        )
1684        .await;
1685        let events = collect_n(rx, 2).await;
1686
1687        //6 - send STOP frame
1688        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;
1689
1690        assert_eq!(
1691            events[0].as_log()[log_schema().message_key().unwrap().to_string()],
1692            "hello".into(),
1693        );
1694        assert_eq!(
1695            events[1].as_log()[log_schema().message_key().unwrap().to_string()],
1696            "world".into(),
1697        );
1698
1699        drop(sock_stream); //explicitly drop the stream so we don't get warnings about not using it
1700
1701        // Ensure source actually shut down successfully.
1702        signal_shutdown(source_name, &mut shutdown).await;
1703        _ = source_handle.await.unwrap();
1704    }
1705
1706    #[tokio::test(flavor = "multi_thread")]
1707    async fn unidirectional_framestream() {
1708        let source_name = "test_source";
1709        let (tx, rx) = SourceSender::new_test();
1710        let (path, source_handle, mut shutdown) =
1711            init_framestream_unix(source_name, create_frame_handler(false), tx);
1712        let (mut sock_sink, _) = make_unix_stream(path).await.split();
1713
1714        //1 - send START frame (with content_type)
1715        let content_type = Bytes::from(&b"test_content"[..]);
1716        let start_msg = create_control_frame_with_content(ControlHeader::Start, vec![content_type]);
1717        send_control_frame(&mut sock_sink, start_msg).await;
1718
1719        //4 - send data
1720        send_data_frames(
1721            &mut sock_sink,
1722            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
1723        )
1724        .await;
1725        let events = collect_n(rx, 2).await;
1726
1727        //5 - send STOP frame
1728        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;
1729
1730        assert_eq!(
1731            events[0].as_log()[log_schema().message_key().unwrap().to_string()],
1732            "hello".into(),
1733        );
1734        assert_eq!(
1735            events[1].as_log()[log_schema().message_key().unwrap().to_string()],
1736            "world".into(),
1737        );
1738
1739        // Ensure source actually shut down successfully.
1740        signal_shutdown(source_name, &mut shutdown).await;
1741        _ = source_handle.await.unwrap();
1742    }
1743
1744    #[tokio::test(flavor = "multi_thread")]
1745    async fn test_spawn_event_handling_tasks() {
1746        let (out, rx) = SourceSender::new_test();
1747
1748        let max_frame_handling_tasks = 20;
1749        let active_task_nums = Arc::new(AtomicUsize::new(0));
1750        let active_task_nums_copy = Arc::clone(&active_task_nums);
1751        let max_task_nums_reached = Arc::new(AtomicUsize::new(0));
1752        let max_task_nums_reached_copy = Arc::clone(&max_task_nums_reached);
1753
1754        let mut join_handles = vec![];
1755        let active_task_nums_copy_2 = Arc::clone(&active_task_nums_copy);
1756        let extra_routine = move || {
1757            thread::sleep(Duration::from_millis(10));
1758            max_task_nums_reached_copy.fetch_max(
1759                active_task_nums_copy_2.load(Ordering::Acquire),
1760                Ordering::AcqRel,
1761            );
1762        };
1763
1764        let total_events = max_frame_handling_tasks * 10;
1765
1766        join_handles.push(tokio::spawn(async move {
1767            future::ready({
1768                let events = collect_n(rx, total_events).await;
1769                assert_eq!(total_events, events.len(), "Missed events");
1770            })
1771            .await;
1772        }));
1773
1774        for i in 0..total_events {
1775            join_handles.push(
1776                spawn_event_handling_tasks(
1777                    Bytes::from(format!("event_{i}")),
1778                    MockFrameHandler::new("test_content".to_string(), true, extra_routine.clone()),
1779                    out.clone(),
1780                    None,
1781                    Arc::clone(&active_task_nums_copy),
1782                    max_frame_handling_tasks,
1783                )
1784                .await,
1785            );
1786        }
1787
1788        future::join_all(join_handles).await;
1789
1790        let final_task_nums = active_task_nums.load(Ordering::Acquire);
1791        assert_eq!(
1792            0, final_task_nums,
1793            "There should be NO left-over tasks at the end"
1794        );
1795
1796        let max_task_nums_reached_value = max_task_nums_reached.load(Ordering::Acquire);
1797        assert!(
1798            max_task_nums_reached_value > 1,
1799            "MultiThreaded mode does NOT work"
1800        );
1801        assert!(
1802            (max_task_nums_reached_value - max_frame_handling_tasks) < 2,
1803            "Max number of tasks at any given time should NOT Exceed max_frame_handling_tasks too much"
1804        );
1805    }
1806}