vector/sources/util/
framestream.rs

1#[cfg(unix)]
2use std::os::unix::{fs::PermissionsExt, io::AsRawFd};
3use std::{
4    convert::TryInto,
5    fs,
6    marker::{Send, Sync},
7    net::SocketAddr,
8    path::PathBuf,
9    sync::{
10        Arc, Mutex,
11        atomic::{AtomicUsize, Ordering},
12    },
13    time::Duration,
14};
15
16use bytes::{Buf, Bytes, BytesMut};
17use futures::{
18    executor::block_on,
19    future::{self, OptionFuture},
20    sink::{Sink, SinkExt},
21    stream::{self, StreamExt, TryStreamExt},
22};
23use futures_util::{Future, FutureExt, future::BoxFuture};
24use ipnet::IpNet;
25use listenfd::ListenFd;
26use tokio::{
27    self,
28    io::{AsyncRead, AsyncWrite},
29    net::{TcpStream, UnixListener},
30    task::JoinHandle,
31    time::sleep,
32};
33use tokio_stream::wrappers::UnixListenerStream;
34use tokio_util::codec::{Framed, length_delimited};
35use tracing::{Instrument, Span, field};
36use vector_lib::{
37    lookup::OwnedValuePath,
38    tcp::TcpKeepaliveConfig,
39    tls::{CertificateMetadata, MaybeTlsIncomingStream, MaybeTlsSettings},
40};
41
42use super::net::{RequestLimiter, SocketListenAddr};
43use crate::{
44    SourceSender,
45    event::Event,
46    internal_events::{
47        ConnectionOpen, OpenGauge, SocketBindError, SocketMode, SocketReceiveError,
48        TcpBytesReceived, TcpSocketError, TcpSocketTlsConnectionError, UnixSocketError,
49        UnixSocketFileDeleteError,
50    },
51    shutdown::ShutdownSignal,
52    sources::{
53        Source,
54        util::{
55            AfterReadExt,
56            net::{MAX_IN_FLIGHT_EVENTS_TARGET, try_bind_tcp_listener},
57        },
58    },
59};
60
/// Upper bound, in bytes, that `FrameStreamReader::handle_control_frame`
/// checks an incoming control frame's length against.
const FSTRM_CONTROL_FRAME_LENGTH_MAX: usize = 512;
/// Upper bound, in bytes, on the length of a single content-type string
/// carried in a control frame field.
const FSTRM_CONTROL_FIELD_CONTENT_TYPE_LENGTH_MAX: usize = 256;

/// If a connection does not receive any data during this short timeout,
/// it should release its permit (and try to obtain a new one), allowing other connections to read.
/// It is very short because any incoming data will avoid this timeout,
/// so it mainly prevents holding permits without consuming any data.
const PERMIT_HOLD_TIMEOUT_MS: u64 = 10;

/// Boxed sink of raw `Bytes` frames that control-frame responses are written to.
pub type FrameStreamSink = Box<dyn Sink<Bytes, Error = std::io::Error> + Send + Unpin>;
71
/// Per-connection reader that tracks the control-frame handshake and
/// separates control frames from data frames.
pub struct FrameStreamReader {
    /// Sink that control-frame responses (ACCEPT, FINISH) are sent through.
    response_sink: Mutex<FrameStreamSink>,
    /// Content type that a peer's READY/START control frame must offer.
    expected_content_type: String,
    /// Current position in the control-frame state machine.
    state: FrameStreamState,
}
77
78struct FrameStreamState {
79    expect_control_frame: bool,
80    control_state: ControlState,
81    is_bidirectional: bool,
82}
83impl FrameStreamState {
84    const fn new() -> Self {
85        FrameStreamState {
86            expect_control_frame: false,
87            //first control frame should be READY (if bidirectional -- if unidirectional first will be START)
88            control_state: ControlState::Initial,
89            is_bidirectional: true, //assume
90        }
91    }
92}
93
/// Phases of the control-frame handshake for a single connection.
#[derive(PartialEq, Debug)]
enum ControlState {
    /// No control frame has been accepted yet; READY or START is expected.
    Initial,
    /// READY was received and ACCEPT was sent; START is expected next.
    GotReady,
    /// START was received; data frames are passed through until STOP.
    ReadingData,
    /// STOP was received; no further control frames are expected.
    Stopped,
}
101
102#[derive(Copy, Clone)]
103enum ControlHeader {
104    Accept,
105    Start,
106    Stop,
107    Ready,
108    Finish,
109}
110
111impl ControlHeader {
112    fn from_u32(val: u32) -> Result<Self, ()> {
113        match val {
114            0x01 => Ok(ControlHeader::Accept),
115            0x02 => Ok(ControlHeader::Start),
116            0x03 => Ok(ControlHeader::Stop),
117            0x04 => Ok(ControlHeader::Ready),
118            0x05 => Ok(ControlHeader::Finish),
119            _ => {
120                error!("Don't know header value {} (expected 0x01 - 0x05).", val);
121                Err(())
122            }
123        }
124    }
125
126    const fn to_u32(self) -> u32 {
127        match self {
128            ControlHeader::Accept => 0x01,
129            ControlHeader::Start => 0x02,
130            ControlHeader::Stop => 0x03,
131            ControlHeader::Ready => 0x04,
132            ControlHeader::Finish => 0x05,
133        }
134    }
135}
136
137enum ControlField {
138    ContentType,
139}
140
141impl ControlField {
142    fn from_u32(val: u32) -> Result<Self, ()> {
143        match val {
144            0x01 => Ok(ControlField::ContentType),
145            _ => {
146                error!("Don't know field type {} (expected 0x01).", val);
147                Err(())
148            }
149        }
150    }
151    const fn to_u32(&self) -> u32 {
152        match self {
153            ControlField::ContentType => 0x01,
154        }
155    }
156}
157
158fn advance_u32(b: &mut Bytes) -> Result<u32, ()> {
159    if b.len() < 4 {
160        error!("Malformed frame.");
161        return Err(());
162    }
163    let a = b.split_to(4);
164    Ok(u32::from_be_bytes(a[..].try_into().unwrap()))
165}
166
167impl FrameStreamReader {
168    pub fn new(response_sink: FrameStreamSink, expected_content_type: String) -> Self {
169        FrameStreamReader {
170            response_sink: Mutex::new(response_sink),
171            expected_content_type,
172            state: FrameStreamState::new(),
173        }
174    }
175
176    pub fn handle_frame(&mut self, frame: Bytes) -> Option<Bytes> {
177        if frame.is_empty() {
178            //frame length of zero means the next frame is a control frame
179            self.state.expect_control_frame = true;
180            None
181        } else if self.state.expect_control_frame {
182            self.state.expect_control_frame = false;
183            _ = self.handle_control_frame(frame);
184            None
185        } else {
186            //data frame
187            if self.state.control_state == ControlState::ReadingData {
188                Some(frame) //return data frame
189            } else {
190                error!(
191                    "Received a data frame while in state {:?}.",
192                    self.state.control_state
193                );
194                None
195            }
196        }
197    }
198
199    fn handle_control_frame(&mut self, mut frame: Bytes) -> Result<(), ()> {
200        //enforce maximum control frame size
201        if frame.len() > FSTRM_CONTROL_FRAME_LENGTH_MAX {
202            error!("Control frame is too long.");
203        }
204
205        let header = ControlHeader::from_u32(advance_u32(&mut frame)?)?;
206
207        //match current state to received header
208        match self.state.control_state {
209            ControlState::Initial => {
210                match header {
211                    ControlHeader::Ready => {
212                        let content_type = self.process_fields(header, &mut frame)?.unwrap();
213
214                        self.send_control_frame(Self::make_frame(
215                            ControlHeader::Accept,
216                            Some(content_type),
217                        ));
218                        self.state.control_state = ControlState::GotReady; //waiting for a START control frame
219                    }
220                    ControlHeader::Start => {
221                        //check for content type
222                        _ = self.process_fields(header, &mut frame)?;
223                        //if didn't error, then we are ok to change state
224                        self.state.control_state = ControlState::ReadingData;
225                        self.state.is_bidirectional = false; //if first message was START then we are unidirectional (no responses)
226                    }
227                    _ => error!("Got wrong control frame, expected READY."),
228                }
229            }
230            ControlState::GotReady => {
231                match header {
232                    ControlHeader::Start => {
233                        //check for content type
234                        _ = self.process_fields(header, &mut frame)?;
235                        //if didn't error, then we are ok to change state
236                        self.state.control_state = ControlState::ReadingData;
237                    }
238                    _ => error!("Got wrong control frame, expected START."),
239                }
240            }
241            ControlState::ReadingData => {
242                match header {
243                    ControlHeader::Stop => {
244                        //check there aren't any fields
245                        _ = self.process_fields(header, &mut frame)?;
246                        if self.state.is_bidirectional {
247                            //send FINISH frame -- but only if we are bidirectional
248                            self.send_control_frame(Self::make_frame(ControlHeader::Finish, None));
249                        }
250                        self.state.control_state = ControlState::Stopped; //stream is now done
251                    }
252                    _ => error!("Got wrong control frame, expected STOP."),
253                }
254            }
255            ControlState::Stopped => error!("Unexpected control frame, current state is STOPPED."),
256        };
257        Ok(())
258    }
259
260    fn process_fields(
261        &mut self,
262        header: ControlHeader,
263        frame: &mut Bytes,
264    ) -> Result<Option<String>, ()> {
265        match header {
266            ControlHeader::Ready => {
267                //should provide 1+ content types
268                //should match expected content type
269                let is_start_frame = false;
270                let content_type = self.process_content_type(frame, is_start_frame)?;
271                Ok(Some(content_type))
272            }
273            ControlHeader::Start => {
274                //can take one or zero content types
275                if frame.is_empty() {
276                    Ok(None)
277                } else {
278                    //should match expected content type
279                    let is_start_frame = true;
280                    let content_type = self.process_content_type(frame, is_start_frame)?;
281                    Ok(Some(content_type))
282                }
283            }
284            ControlHeader::Stop => {
285                //check that there are no fields
286                if !frame.is_empty() {
287                    error!("Unexpected fields in STOP header.");
288                    Err(())
289                } else {
290                    Ok(None)
291                }
292            }
293            _ => {
294                error!("Unexpected control header value {:?}.", header.to_u32());
295                Err(())
296            }
297        }
298    }
299
300    fn process_content_type(&self, frame: &mut Bytes, is_start_frame: bool) -> Result<String, ()> {
301        if frame.is_empty() {
302            error!("No fields in control frame.");
303            return Err(());
304        }
305
306        let mut content_types = vec![];
307        while !frame.is_empty() {
308            //4 bytes of ControlField
309            let field_val = advance_u32(frame)?;
310            let field_type = ControlField::from_u32(field_val)?;
311            match field_type {
312                ControlField::ContentType => {
313                    //4 bytes giving length of content type
314                    let field_len = advance_u32(frame)? as usize;
315
316                    //enforce limit on content type string
317                    if field_len > FSTRM_CONTROL_FIELD_CONTENT_TYPE_LENGTH_MAX {
318                        error!("Content-Type string is too long.");
319                        return Err(());
320                    }
321
322                    let content_type = std::str::from_utf8(&frame[..field_len]).unwrap();
323                    content_types.push(content_type.to_string());
324                    frame.advance(field_len);
325                }
326            }
327        }
328
329        if is_start_frame && content_types.len() > 1 {
330            error!(
331                "START control frame can only have one content-type provided (got {}).",
332                content_types.len()
333            );
334            return Err(());
335        }
336
337        for content_type in &content_types {
338            if *content_type == self.expected_content_type {
339                return Ok(content_type.clone());
340            }
341        }
342
343        error!(
344            "Content types did not match up. Expected {} got {:?}.",
345            self.expected_content_type, content_types
346        );
347        Err(())
348    }
349
350    fn make_frame(header: ControlHeader, content_type: Option<String>) -> Bytes {
351        let mut frame = BytesMut::new();
352        frame.extend(header.to_u32().to_be_bytes());
353        if let Some(s) = content_type {
354            frame.extend(ControlField::ContentType.to_u32().to_be_bytes()); //field type: ContentType
355            frame.extend((s.len() as u32).to_be_bytes()); //length of type
356            frame.extend(s.as_bytes());
357        }
358        Bytes::from(frame)
359    }
360
361    fn send_control_frame(&mut self, frame: Bytes) {
362        let empty_frame = Bytes::from(&b""[..]); //send empty frame to say we are control frame
363        let mut stream = stream::iter(vec![Ok(empty_frame), Ok(frame)]);
364
365        if let Err(e) = block_on(self.response_sink.lock().unwrap().send_all(&mut stream)) {
366            error!("Encountered error '{:#?}' while sending control frame.", e);
367        }
368    }
369}
370
/// Configuration and per-frame behavior shared by the framestream sources.
pub trait FrameHandler {
    /// Content type a peer must offer during the control handshake.
    fn content_type(&self) -> String;
    /// Maximum frame length given to the length-delimited codec.
    fn max_frame_length(&self) -> usize;
    /// Turns one data frame into an event; `None` means the frame yields no event.
    /// `received_from` is the peer address (TCP) or peer socket path (Unix) when known.
    fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event>;
    /// Whether frames are handed to `spawn_event_handling_tasks` instead of
    /// being handled inline.
    fn multithreaded(&self) -> bool;
    /// Task limit passed to `spawn_event_handling_tasks` (and, for TCP, to the
    /// request limiter).
    fn max_frame_handling_tasks(&self) -> usize;
    // NOTE(review): the three key accessors below are not called in the visible
    // part of this file; presumably implementors use them when building events.
    fn host_key(&self) -> &Option<OwnedValuePath>;
    fn timestamp_key(&self) -> Option<&OwnedValuePath>;
    fn source_type_key(&self) -> Option<&OwnedValuePath>;
}
381
/// Additional settings needed by `build_framestream_unix_source`.
pub trait UnixFrameHandler: FrameHandler {
    /// Filesystem path the Unix listener is bound to.
    fn socket_path(&self) -> PathBuf;
    /// Optional file mode applied to the socket file; must lie in `0o700..=0o777`.
    fn socket_file_mode(&self) -> Option<u32>;
    /// Optional receive buffer size set on the listener socket (`RcvBuf`).
    fn socket_receive_buffer_size(&self) -> Option<usize>;
    /// Optional send buffer size set on the listener socket (`SndBuf`).
    fn socket_send_buffer_size(&self) -> Option<usize>;
}
388
/// Additional settings needed by `build_framestream_tcp_source`.
pub trait TcpFrameHandler: FrameHandler {
    /// Address the TCP listener is bound to.
    fn address(&self) -> SocketListenAddr;
    /// Optional keepalive configuration applied to each accepted socket.
    fn keepalive(&self) -> Option<TcpKeepaliveConfig>;
    /// How long open connections are left running after shutdown is signalled
    /// before they are reset.
    fn shutdown_timeout_secs(&self) -> Duration;
    /// TLS settings used when binding the listener.
    fn tls(&self) -> MaybeTlsSettings;
    // NOTE(review): not called in the visible part of this file; presumably the
    // event path under which client certificate metadata is stored.
    fn tls_client_metadata_key(&self) -> Option<OwnedValuePath>;
    /// Optional receive buffer size applied to each accepted socket.
    fn receive_buffer_bytes(&self) -> Option<usize>;
    /// Optional maximum lifetime of a connection, in seconds.
    fn max_connection_duration_secs(&self) -> Option<u64>;
    /// Optional limit passed to `accept_stream_limited`.
    fn max_connections(&self) -> Option<u32>;
    /// Optional list of networks passed to `try_bind_tcp_listener`.
    fn allowed_origins(&self) -> Option<&[IpNet]>;
    /// Called once per connection, after the TLS handshake, with the peer
    /// certificate metadata (if any).
    fn insert_tls_client_metadata(&mut self, metadata: Option<CertificateMetadata>);
}
401
/// Builds a framestream source that listens on a TCP socket.
///
/// Based off of the `build_framestream_unix_source` function.
/// Functions similarly, just uses a TCP socket instead of a Unix socket.
///
/// The returned future binds the listener, accepts connections until
/// `shutdown` fires, and spawns one `handle_stream` task per connection.
/// A bind failure is emitted as `SocketBindError` and ends the future.
pub fn build_framestream_tcp_source(
    frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
    shutdown: ShutdownSignal,
    out: SourceSender,
) -> crate::Result<Source> {
    let addr = frame_handler.address();
    let tls = frame_handler.tls();
    let shutdown = shutdown.clone();
    let out = out.clone();

    Ok(Box::pin(async move {
        let listenfd = ListenFd::from_env();
        let listener = try_bind_tcp_listener(
            addr,
            listenfd,
            &tls,
            frame_handler
                .allowed_origins()
                .map(|origins| origins.to_vec()),
        )
        .await
        .map_err(|error| {
            emit!(SocketBindError {
                mode: SocketMode::Tcp,
                error: &error,
            })
        })?;

        // Log the actual local address when available, else the configured one.
        info!(
            message = "Listening.",
            addr = %listener
                .local_addr()
                .map(SocketListenAddr::SocketAddr)
                .unwrap_or(addr)
        );

        // `tripwire` resolves `shutdown_timeout_secs` after shutdown is
        // signalled; it is shared so every connection can await it.
        let tripwire = shutdown.clone();
        let shutdown_timeout_secs = frame_handler.shutdown_timeout_secs();
        let tripwire = async move {
            _ = tripwire.await;
            sleep(shutdown_timeout_secs).await;
        }
        .shared();

        let connection_gauge = OpenGauge::new();
        let shutdown_clone = shutdown.clone();

        let request_limiter = RequestLimiter::new(
            MAX_IN_FLIGHT_EVENTS_TARGET,
            frame_handler.max_frame_handling_tasks(),
        );

        listener
            .accept_stream_limited(frame_handler.max_connections())
            .take_until(shutdown_clone)
            .for_each(move |(connection, tcp_connection_permit)| {
                // Per-connection clones of everything the handler task needs.
                let shutdown_signal = shutdown.clone();
                let tripwire = tripwire.clone();
                let out = out.clone();
                let connection_gauge = connection_gauge.clone();
                let request_limiter = request_limiter.clone();
                let frame_handler_clone = frame_handler.clone();

                async move {
                    let socket = match connection {
                        Ok(socket) => socket,
                        Err(error) => {
                            emit!(SocketReceiveError {
                                mode: SocketMode::Tcp,
                                error: &error
                            });
                            return;
                        }
                    };

                    let peer_addr = socket.peer_addr();
                    let span = info_span!("connection", %peer_addr);

                    // Log when the connection is reset because it outlived the
                    // shutdown timeout.
                    let tripwire = tripwire
                        .map(move |_| {
                            info!(
                                message = "Resetting connection (still open after seconds).",
                                seconds = ?shutdown_timeout_secs
                            );
                        })
                        .boxed();

                    span.clone().in_scope(|| {
                        debug!(message = "Accepted a new connection.", peer_addr = %peer_addr);

                        let open_token =
                            connection_gauge.open(|count| emit!(ConnectionOpen { count }));

                        let fut = handle_stream(
                            frame_handler_clone,
                            shutdown_signal,
                            socket,
                            tripwire,
                            peer_addr,
                            out,
                            request_limiter,
                        );

                        // The gauge token and connection permit are held until
                        // the handler future completes.
                        tokio::spawn(
                            fut.map(move |()| {
                                drop(open_token);
                                drop(tcp_connection_permit);
                            })
                            .instrument(span.or_current()),
                        );
                    });
                }
            })
            .map(Ok)
            .await
    }))
}
523
/// Drives a single accepted TCP connection until it closes or is shut down.
///
/// Performs the TLS handshake (raced against `shutdown_signal`), applies the
/// socket options from `frame_handler`, then reads length-delimited frames
/// through a `FrameStreamReader` and forwards each data frame to
/// `handle_tcp_frame`. The loop ends when `tripwire` or `shutdown_signal`
/// resolves, the optional connection duration elapses, or the peer closes
/// the connection.
#[allow(clippy::too_many_arguments)]
async fn handle_stream(
    mut frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
    mut shutdown_signal: ShutdownSignal,
    mut socket: MaybeTlsIncomingStream<TcpStream>,
    mut tripwire: BoxFuture<'static, ()>,
    peer_addr: SocketAddr,
    out: SourceSender,
    request_limiter: RequestLimiter,
) {
    // Abandon the connection if shutdown is signalled before the handshake completes.
    tokio::select! {
        result = socket.handshake() => {
            if let Err(error) = result {
                emit!(TcpSocketTlsConnectionError { error });
                return;
            }
        },
        _ = &mut shutdown_signal => {
            return;
        }
    };

    // Socket option failures are only warned about; the connection continues.
    if let Some(keepalive) = frame_handler.keepalive()
        && let Err(error) = socket.set_keepalive(keepalive)
    {
        warn!(message = "Failed configuring TCP keepalive.", %error);
    }

    if let Some(receive_buffer_bytes) = frame_handler.receive_buffer_bytes()
        && let Err(error) = socket.set_receive_buffer_bytes(receive_buffer_bytes)
    {
        warn!(message = "Failed configuring receive buffer size on TCP socket.", %error);
    }

    // Report the size of every read from the socket.
    let socket = socket.after_read(move |byte_size| {
        emit!(TcpBytesReceived {
            byte_size,
            peer_addr
        });
    });

    let certificate_metadata = socket
        .get_ref()
        .ssl_stream()
        .and_then(|stream| stream.ssl().peer_certificate())
        .map(CertificateMetadata::from);

    frame_handler.insert_tls_client_metadata(certificate_metadata);

    // NOTE(review): this span is created and recorded on, but never entered or
    // attached to a future within this function -- confirm it is still needed.
    let span = info_span!("connection");
    span.record("peer_addr", field::debug(&peer_addr));
    let received_from: Option<Bytes> = Some(peer_addr.to_string().into());

    // Resolves to `Some(())` after the configured maximum connection duration,
    // or to `None` when no maximum is configured.
    let connection_close_timeout = OptionFuture::from(
        frame_handler
            .max_connection_duration_secs()
            .map(|timeout_secs| tokio::time::sleep(Duration::from_secs(timeout_secs))),
    );
    tokio::pin!(connection_close_timeout);

    let content_type = frame_handler.content_type();
    let mut event_sink = out.clone();
    let (sock_sink, sock_stream) = Framed::new(
        socket,
        length_delimited::Builder::new()
            .max_frame_length(frame_handler.max_frame_length())
            .new_codec(),
    )
    .split();
    let mut reader = FrameStreamReader::new(Box::new(sock_sink), content_type);
    // Socket errors are emitted and dropped; only data frames accepted by the
    // reader are yielded.
    let mut frames = sock_stream
        .map_err(move |error| {
            emit!(TcpSocketError {
                error: &error,
                peer_addr,
            });
        })
        .filter_map(move |frame| {
            future::ready(match frame {
                Ok(f) => reader.handle_frame(Bytes::from(f)),
                Err(_) => None,
            })
        });

    let active_parsing_task_nums = Arc::new(AtomicUsize::new(0));
    loop {
        // Wait for a permit from the request limiter before reading, unless the
        // connection should end first.
        let mut permit = tokio::select! {
            _ = &mut tripwire => break,
            Some(_) = &mut connection_close_timeout  => {
                break;
            },
            _ = &mut shutdown_signal => {
                break;
            },
            permit = request_limiter.acquire() => {
                Some(permit)
            }
            else => break,
        };

        let timeout = tokio::time::sleep(Duration::from_millis(PERMIT_HOLD_TIMEOUT_MS));
        tokio::pin!(timeout);

        tokio::select! {
            _ = &mut tripwire => break,
            _ = &mut shutdown_signal => break,
            _ = &mut timeout => {
                // This connection is currently holding a permit, but has not received data for some time. Release
                // the permit to let another connection try
                continue;
            }
            res = frames.next() => {
                match res {
                    Some(frame) => {
                        if let Some(permit) = &mut permit {
                            // Note that this is intentionally not the "number of events in a single request", but rather
                            // the "number of events currently available". This may contain events from multiple requests,
                            // but it should always contain all events from each request.
                            permit.decoding_finished(1);
                        };
                        handle_tcp_frame(&mut frame_handler, frame, &mut event_sink, received_from.clone(), Arc::clone(&active_parsing_task_nums)).await;
                    }
                    None => {
                        debug!("Connection closed.");
                        break
                    },
                }
            }
            else => break,
        }

        drop(permit);
    }
}
658
659async fn handle_tcp_frame<T>(
660    frame_handler: &mut T,
661    frame: Bytes,
662    event_sink: &mut SourceSender,
663    received_from: Option<Bytes>,
664    active_parsing_task_nums: Arc<AtomicUsize>,
665) where
666    T: TcpFrameHandler + Send + Sync + Clone + 'static,
667{
668    if frame_handler.multithreaded() {
669        spawn_event_handling_tasks(
670            frame,
671            frame_handler.clone(),
672            event_sink.clone(),
673            received_from,
674            active_parsing_task_nums,
675            frame_handler.max_frame_handling_tasks(),
676        )
677        .await;
678    } else if let Some(event) = frame_handler.handle_event(received_from, frame)
679        && let Err(e) = event_sink.send_event(event).await
680    {
681        error!(
682            internal_log_rate_limit = true,
683            "Error sending event: {e:?}."
684        );
685    }
686}
687
688/**
689 * Based off of the build_unix_source function.
690 * Functions similarly, but uses the FrameStreamReader to deal with
691 * framestream control packets, and responds appropriately.
692 **/
693pub fn build_framestream_unix_source(
694    frame_handler: impl UnixFrameHandler + Send + Sync + Clone + 'static,
695    shutdown: ShutdownSignal,
696    out: SourceSender,
697) -> crate::Result<Source> {
698    let path = frame_handler.socket_path();
699
700    //check if the path already exists (and try to delete it)
701    match fs::metadata(&path) {
702        Ok(_) => {
703            //exists, so try to delete it
704            info!(message = "Deleting file.", ?path);
705            fs::remove_file(&path)?;
706        }
707        Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {} //doesn't exist, do nothing
708        Err(e) => {
709            error!("Unable to get socket information; error = {:?}.", e);
710            return Err(Box::new(e));
711        }
712    };
713
714    let listener = UnixListener::bind(&path)?;
715
716    // system's 'net.core.rmem_max' might have to be changed if socket receive buffer is not updated properly
717    if let Some(socket_receive_buffer_size) = frame_handler.socket_receive_buffer_size() {
718        _ = nix::sys::socket::setsockopt(
719            listener.as_raw_fd(),
720            nix::sys::socket::sockopt::RcvBuf,
721            &(socket_receive_buffer_size),
722        );
723        let rcv_buf_size =
724            nix::sys::socket::getsockopt(listener.as_raw_fd(), nix::sys::socket::sockopt::RcvBuf);
725        info!(
726            "Unix socket receive buffer size modified to {}.",
727            rcv_buf_size.unwrap()
728        );
729    }
730
731    // system's 'net.core.wmem_max' might have to be changed if socket send buffer is not updated properly
732    if let Some(socket_send_buffer_size) = frame_handler.socket_send_buffer_size() {
733        _ = nix::sys::socket::setsockopt(
734            listener.as_raw_fd(),
735            nix::sys::socket::sockopt::SndBuf,
736            &(socket_send_buffer_size),
737        );
738        let snd_buf_size =
739            nix::sys::socket::getsockopt(listener.as_raw_fd(), nix::sys::socket::sockopt::SndBuf);
740        info!(
741            "Unix socket buffer send size modified to {}.",
742            snd_buf_size.unwrap()
743        );
744    }
745
746    // the permissions to unix socket are restricted from 0o700 to 0o777, which are 448 and 511 in decimal
747    if let Some(socket_permission) = frame_handler.socket_file_mode() {
748        if !(448..=511).contains(&socket_permission) {
749            return Err(format!(
750                "Invalid Socket permission {socket_permission:#o}. Must between 0o700 and 0o777."
751            )
752            .into());
753        }
754        match fs::set_permissions(&path, fs::Permissions::from_mode(socket_permission)) {
755            Ok(_) => {
756                info!("Socket permissions updated to {:#o}.", socket_permission);
757            }
758            Err(e) => {
759                error!(
760                    "Failed to update listener socket permissions; error = {:?}.",
761                    e
762                );
763                return Err(Box::new(e));
764            }
765        };
766    };
767
768    let fut = async move {
769        let active_parsing_task_nums = Arc::new(AtomicUsize::new(0));
770
771        info!(message = "Listening...", ?path, r#type = "unix");
772
773        let mut stream = UnixListenerStream::new(listener).take_until(shutdown.clone());
774        while let Some(socket) = stream.next().await {
775            let socket = match socket {
776                Err(e) => {
777                    error!("Failed to accept socket; error = {:?}.", e);
778                    continue;
779                }
780                Ok(s) => s,
781            };
782            let peer_addr = socket.peer_addr().ok();
783            let listen_path = path.clone();
784            let active_task_nums_ = Arc::clone(&active_parsing_task_nums);
785
786            let span = info_span!("connection");
787            let path = if let Some(addr) = peer_addr {
788                if let Some(path) = addr.as_pathname().map(|e| e.to_owned()) {
789                    span.record("peer_path", field::debug(&path));
790                    Some(path)
791                } else {
792                    None
793                }
794            } else {
795                None
796            };
797            let received_from: Option<Bytes> =
798                path.map(|p| p.to_string_lossy().into_owned().into());
799
800            build_framestream_source(
801                frame_handler.clone(),
802                socket,
803                received_from,
804                out.clone(),
805                shutdown.clone(),
806                span,
807                active_task_nums_,
808                move |error| {
809                    emit!(UnixSocketError {
810                        error: &error,
811                        path: &listen_path,
812                    });
813                },
814            );
815        }
816
817        // Cleanup
818        drop(stream);
819
820        // Delete socket file
821        if let Err(error) = fs::remove_file(&path) {
822            emit!(UnixSocketFileDeleteError { path: &path, error });
823        }
824
825        Ok(())
826    };
827
828    Ok(Box::pin(fut))
829}
830
831#[allow(clippy::too_many_arguments)]
832fn build_framestream_source<T: Send + 'static>(
833    frame_handler: impl FrameHandler + Send + Sync + Clone + 'static,
834    socket: impl AsyncRead + AsyncWrite + Send + 'static,
835    received_from: Option<Bytes>,
836    out: SourceSender,
837    shutdown: impl Future<Output = T> + Unpin + Send + 'static,
838    span: Span,
839    active_task_nums: Arc<AtomicUsize>,
840    error_mapper: impl FnMut(std::io::Error) + Send + 'static,
841) {
842    let content_type = frame_handler.content_type();
843    let mut event_sink = out.clone();
844    let (sock_sink, sock_stream) = Framed::new(
845        socket,
846        length_delimited::Builder::new()
847            .max_frame_length(frame_handler.max_frame_length())
848            .new_codec(),
849    )
850    .split();
851    let mut fs_reader = FrameStreamReader::new(Box::new(sock_sink), content_type);
852    let frame_handler_copy = frame_handler.clone();
853    let frames = sock_stream
854        .take_until(shutdown)
855        .map_err(error_mapper)
856        .filter_map(move |frame| {
857            future::ready(match frame {
858                Ok(f) => fs_reader.handle_frame(Bytes::from(f)),
859                Err(_) => None,
860            })
861        });
862    if !frame_handler.multithreaded() {
863        let mut events = frames.filter_map(move |f| {
864            future::ready(frame_handler_copy.handle_event(received_from.clone(), f))
865        });
866
867        let handler = async move {
868            if let Err(e) = event_sink.send_event_stream(&mut events).await {
869                error!(
870                    internal_log_rate_limit = true,
871                    "Error sending event: {:?}.", e
872                );
873            }
874
875            info!("Finished sending.");
876        };
877        tokio::spawn(handler.instrument(span.or_current()));
878    } else {
879        let handler = async move {
880            frames
881                .for_each(move |f| {
882                    let max_frame_handling_tasks = frame_handler_copy.max_frame_handling_tasks();
883                    let f_handler = frame_handler_copy.clone();
884                    let received_from_copy = received_from.clone();
885                    let event_sink_copy = event_sink.clone();
886                    let active_task_nums_copy = Arc::clone(&active_task_nums);
887
888                    async move {
889                        spawn_event_handling_tasks(
890                            f,
891                            f_handler,
892                            event_sink_copy,
893                            received_from_copy,
894                            active_task_nums_copy,
895                            max_frame_handling_tasks,
896                        )
897                        .await;
898                    }
899                })
900                .await;
901            info!("Finished sending.");
902        };
903        tokio::spawn(handler.instrument(span.or_current()));
904    }
905}
906
907async fn spawn_event_handling_tasks(
908    event_data: Bytes,
909    event_handler: impl FrameHandler + Send + Sync + 'static,
910    mut event_sink: SourceSender,
911    received_from: Option<Bytes>,
912    active_task_nums: Arc<AtomicUsize>,
913    max_frame_handling_tasks: usize,
914) -> JoinHandle<()> {
915    wait_for_task_quota(&active_task_nums, max_frame_handling_tasks).await;
916
917    tokio::spawn(async move {
918        future::ready({
919            if let Some(evt) = event_handler.handle_event(received_from, event_data)
920                && event_sink.send_event(evt).await.is_err()
921            {
922                error!("Encountered error while sending event.");
923            }
924            active_task_nums.fetch_sub(1, Ordering::AcqRel);
925        })
926        .await;
927    })
928}
929
930async fn wait_for_task_quota(active_task_nums: &Arc<AtomicUsize>, max_tasks: usize) {
931    while max_tasks > 0 && max_tasks < active_task_nums.load(Ordering::Acquire) {
932        tokio::time::sleep(Duration::from_millis(3)).await;
933    }
934    active_task_nums.fetch_add(1, Ordering::AcqRel);
935}
936
937#[cfg(test)]
938mod test {
939    use std::net::SocketAddr;
940    #[cfg(unix)]
941    use std::{
942        path::PathBuf,
943        sync::{
944            Arc,
945            atomic::{AtomicUsize, Ordering},
946        },
947        thread,
948    };
949
950    use bytes::{Bytes, BytesMut, buf::Buf};
951    use futures::{
952        future,
953        sink::{Sink, SinkExt},
954        stream::{self, StreamExt},
955    };
956    use futures_util::Stream;
957    use ipnet::IpNet;
958    use tokio::{
959        self,
960        net::{TcpStream, UnixStream},
961        task::JoinHandle,
962        time::{Duration, Instant},
963    };
964    use tokio_util::codec::{Framed, length_delimited};
965    use vector_lib::{
966        config::{LegacyKey, LogNamespace},
967        lookup::{OwnedValuePath, owned_value_path, path},
968        tcp::TcpKeepaliveConfig,
969        tls::{CertificateMetadata, MaybeTls, MaybeTlsSettings},
970    };
971
972    use super::{
973        ControlField, ControlHeader, FrameHandler, TcpFrameHandler, UnixFrameHandler,
974        build_framestream_tcp_source, build_framestream_unix_source, spawn_event_handling_tasks,
975    };
976    use crate::{
977        SourceSender,
978        config::{ComponentKey, log_schema},
979        event::{Event, LogEvent},
980        shutdown::SourceShutdownCoordinator,
981        sources::util::net::SocketListenAddr,
982        test_util::{collect_n, collect_n_stream, next_addr},
983    };
984
985    #[derive(Clone)]
986    struct MockFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
987        content_type: String,
988        max_frame_length: usize,
989        multithreaded: bool,
990        max_frame_handling_tasks: usize,
991        extra_task_handling_routine: F,
992        host_key: Option<OwnedValuePath>,
993        timestamp_key: Option<OwnedValuePath>,
994        source_type_key: Option<OwnedValuePath>,
995        log_namespace: LogNamespace,
996    }
997
998    #[derive(Clone)]
999    struct MockUnixFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
1000        frame_handler: MockFrameHandler<F>,
1001        socket_path: PathBuf,
1002        socket_file_mode: Option<u32>,
1003        socket_receive_buffer_size: Option<usize>,
1004        socket_send_buffer_size: Option<usize>,
1005    }
1006
1007    #[derive(Clone)]
1008    struct MockTcpFrameHandler<F: Send + Sync + Clone + FnOnce() + 'static> {
1009        frame_handler: MockFrameHandler<F>,
1010        address: SocketListenAddr,
1011        keepalive: Option<TcpKeepaliveConfig>,
1012        shutdown_timeout_secs: Duration,
1013        tls: MaybeTlsSettings,
1014        tls_client_metadata_key: Option<OwnedValuePath>,
1015        receive_buffer_bytes: Option<usize>,
1016        max_connection_duration_secs: Option<u64>,
1017        max_connections: Option<u32>,
1018        permit_origin: Option<Vec<IpNet>>,
1019    }
1020
1021    impl<F: Send + Sync + Clone + FnOnce() + 'static> MockTcpFrameHandler<F> {
1022        pub fn new(
1023            addr: SocketAddr,
1024            content_type: String,
1025            multithreaded: bool,
1026            extra_routine: F,
1027            permit_origin: Option<Vec<IpNet>>,
1028        ) -> Self {
1029            Self {
1030                frame_handler: MockFrameHandler::new(content_type, multithreaded, extra_routine),
1031                address: addr.into(),
1032                keepalive: None,
1033                shutdown_timeout_secs: Duration::from_secs(30),
1034                tls: MaybeTls::Raw(()),
1035                tls_client_metadata_key: None,
1036                receive_buffer_bytes: None,
1037                max_connection_duration_secs: None,
1038                max_connections: None,
1039                permit_origin,
1040            }
1041        }
1042    }
1043
1044    impl<F: Send + Sync + Clone + FnOnce() + 'static> MockUnixFrameHandler<F> {
1045        pub fn new(content_type: String, multithreaded: bool, extra_routine: F) -> Self {
1046            Self {
1047                frame_handler: MockFrameHandler::new(content_type, multithreaded, extra_routine),
1048                socket_path: tempfile::tempdir().unwrap().keep().join("unix_test"),
1049                socket_file_mode: None,
1050                socket_receive_buffer_size: None,
1051                socket_send_buffer_size: None,
1052            }
1053        }
1054    }
1055
1056    impl<F: Send + Sync + Clone + FnOnce() + 'static> MockFrameHandler<F> {
1057        pub fn new(content_type: String, multithreaded: bool, extra_routine: F) -> Self {
1058            Self {
1059                content_type,
1060                max_frame_length: bytesize::kib(100u64) as usize,
1061                multithreaded,
1062                max_frame_handling_tasks: 0,
1063                extra_task_handling_routine: extra_routine,
1064                host_key: Some(owned_value_path!("test_framestream")),
1065                timestamp_key: Some(owned_value_path!("my_timestamp")),
1066                source_type_key: Some(owned_value_path!("source_type")),
1067                log_namespace: LogNamespace::Legacy,
1068            }
1069        }
1070    }
1071
1072    impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockFrameHandler<F> {
1073        fn content_type(&self) -> String {
1074            self.content_type.clone()
1075        }
1076        fn max_frame_length(&self) -> usize {
1077            self.max_frame_length
1078        }
1079
1080        fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1081            let mut log_event = LogEvent::from(frame);
1082
1083            log_event.insert(
1084                log_schema().source_type_key_target_path().unwrap(),
1085                "framestream",
1086            );
1087            if let Some(host) = received_from {
1088                self.log_namespace.insert_source_metadata(
1089                    "framestream",
1090                    &mut log_event,
1091                    self.host_key.as_ref().map(LegacyKey::Overwrite),
1092                    path!("host"),
1093                    host,
1094                )
1095            }
1096
1097            (self.extra_task_handling_routine.clone())();
1098
1099            Some(log_event.into())
1100        }
1101
1102        fn multithreaded(&self) -> bool {
1103            self.multithreaded
1104        }
1105        fn max_frame_handling_tasks(&self) -> usize {
1106            self.max_frame_handling_tasks
1107        }
1108
1109        fn host_key(&self) -> &Option<OwnedValuePath> {
1110            &self.host_key
1111        }
1112
1113        fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1114            self.timestamp_key.as_ref()
1115        }
1116
1117        fn source_type_key(&self) -> Option<&OwnedValuePath> {
1118            self.source_type_key.as_ref()
1119        }
1120    }
1121
1122    impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockUnixFrameHandler<F> {
1123        fn content_type(&self) -> String {
1124            self.frame_handler.content_type()
1125        }
1126
1127        fn max_frame_length(&self) -> usize {
1128            self.frame_handler.max_frame_length()
1129        }
1130
1131        fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1132            self.frame_handler.handle_event(received_from, frame)
1133        }
1134
1135        fn multithreaded(&self) -> bool {
1136            self.frame_handler.multithreaded()
1137        }
1138
1139        fn max_frame_handling_tasks(&self) -> usize {
1140            self.frame_handler.max_frame_handling_tasks()
1141        }
1142
1143        fn host_key(&self) -> &Option<OwnedValuePath> {
1144            self.frame_handler.host_key()
1145        }
1146
1147        fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1148            self.frame_handler.timestamp_key()
1149        }
1150
1151        fn source_type_key(&self) -> Option<&OwnedValuePath> {
1152            self.frame_handler.source_type_key()
1153        }
1154    }
1155
1156    impl<F: Send + Sync + Clone + FnOnce() + 'static> UnixFrameHandler for MockUnixFrameHandler<F> {
1157        fn socket_path(&self) -> PathBuf {
1158            self.socket_path.clone()
1159        }
1160
1161        fn socket_file_mode(&self) -> Option<u32> {
1162            self.socket_file_mode
1163        }
1164
1165        fn socket_receive_buffer_size(&self) -> Option<usize> {
1166            self.socket_receive_buffer_size
1167        }
1168
1169        fn socket_send_buffer_size(&self) -> Option<usize> {
1170            self.socket_send_buffer_size
1171        }
1172    }
1173
1174    impl<F: Send + Sync + Clone + FnOnce() + 'static> FrameHandler for MockTcpFrameHandler<F> {
1175        fn content_type(&self) -> String {
1176            self.frame_handler.content_type()
1177        }
1178
1179        fn max_frame_length(&self) -> usize {
1180            self.frame_handler.max_frame_length()
1181        }
1182
1183        fn handle_event(&self, received_from: Option<Bytes>, frame: Bytes) -> Option<Event> {
1184            self.frame_handler.handle_event(received_from, frame)
1185        }
1186
1187        fn multithreaded(&self) -> bool {
1188            self.frame_handler.multithreaded()
1189        }
1190
1191        fn max_frame_handling_tasks(&self) -> usize {
1192            self.frame_handler.max_frame_handling_tasks()
1193        }
1194
1195        fn host_key(&self) -> &Option<OwnedValuePath> {
1196            self.frame_handler.host_key()
1197        }
1198
1199        fn timestamp_key(&self) -> Option<&OwnedValuePath> {
1200            self.frame_handler.timestamp_key()
1201        }
1202
1203        fn source_type_key(&self) -> Option<&OwnedValuePath> {
1204            self.frame_handler.source_type_key()
1205        }
1206    }
1207
1208    impl<F: Send + Sync + Clone + FnOnce() + 'static> TcpFrameHandler for MockTcpFrameHandler<F> {
1209        fn address(&self) -> SocketListenAddr {
1210            self.address
1211        }
1212
1213        fn keepalive(&self) -> Option<TcpKeepaliveConfig> {
1214            self.keepalive
1215        }
1216
1217        fn shutdown_timeout_secs(&self) -> Duration {
1218            self.shutdown_timeout_secs
1219        }
1220
1221        fn tls(&self) -> MaybeTlsSettings {
1222            self.tls.clone()
1223        }
1224
1225        fn tls_client_metadata_key(&self) -> Option<OwnedValuePath> {
1226            self.tls_client_metadata_key.clone()
1227        }
1228
1229        fn receive_buffer_bytes(&self) -> Option<usize> {
1230            self.receive_buffer_bytes
1231        }
1232
1233        fn max_connection_duration_secs(&self) -> Option<u64> {
1234            self.max_connection_duration_secs
1235        }
1236
1237        fn max_connections(&self) -> Option<u32> {
1238            self.max_connections
1239        }
1240
1241        fn insert_tls_client_metadata(&mut self, _: Option<CertificateMetadata>) {}
1242
1243        fn allowed_origins(&self) -> Option<&[IpNet]> {
1244            self.permit_origin.as_deref()
1245        }
1246    }
1247
1248    fn init_framestream_tcp(
1249        source_id: &str,
1250        addr: &SocketAddr,
1251        frame_handler: impl TcpFrameHandler + Send + Sync + Clone + 'static,
1252        pipeline: SourceSender,
1253    ) -> (JoinHandle<Result<(), ()>>, SourceShutdownCoordinator) {
1254        let source_id = ComponentKey::from(source_id);
1255        let mut shutdown = SourceShutdownCoordinator::default();
1256        let (shutdown_signal, _) = shutdown.register_source(&source_id, false);
1257        let server = build_framestream_tcp_source(frame_handler, shutdown_signal, pipeline)
1258            .expect("Failed to build framestream tcp source.");
1259
1260        let join_handle = tokio::spawn(server);
1261
1262        while std::net::TcpStream::connect(addr).is_err() {
1263            thread::sleep(Duration::from_millis(2));
1264        }
1265
1266        (join_handle, shutdown)
1267    }
1268
1269    fn init_framestream_unix(
1270        source_id: &str,
1271        frame_handler: impl UnixFrameHandler + Send + Sync + Clone + 'static,
1272        pipeline: SourceSender,
1273    ) -> (
1274        PathBuf,
1275        JoinHandle<Result<(), ()>>,
1276        SourceShutdownCoordinator,
1277    ) {
1278        let source_id = ComponentKey::from(source_id);
1279        let socket_path = frame_handler.socket_path();
1280        let mut shutdown = SourceShutdownCoordinator::default();
1281        let (shutdown_signal, _) = shutdown.register_source(&source_id, false);
1282        let server = build_framestream_unix_source(frame_handler, shutdown_signal, pipeline)
1283            .expect("Failed to build framestream unix source.");
1284
1285        let join_handle = tokio::spawn(server);
1286
1287        // Wait for server to accept traffic
1288        while std::os::unix::net::UnixStream::connect(&socket_path).is_err() {
1289            thread::sleep(Duration::from_millis(2));
1290        }
1291
1292        (socket_path, join_handle, shutdown)
1293    }
1294
1295    async fn make_tcp_stream(
1296        addr: SocketAddr,
1297    ) -> Framed<TcpStream, length_delimited::LengthDelimitedCodec> {
1298        let socket = TcpStream::connect(&addr).await.unwrap();
1299        Framed::new(socket, length_delimited::Builder::new().new_codec())
1300    }
1301
1302    async fn make_unix_stream(
1303        path: PathBuf,
1304    ) -> Framed<UnixStream, length_delimited::LengthDelimitedCodec> {
1305        let socket = UnixStream::connect(&path).await.unwrap();
1306        Framed::new(socket, length_delimited::Builder::new().new_codec())
1307    }
1308
1309    async fn send_data_frames<S: Sink<Bytes, Error = std::io::Error> + Unpin>(
1310        sock_sink: &mut S,
1311        frames: Vec<Result<Bytes, std::io::Error>>,
1312    ) {
1313        let mut stream = stream::iter(frames.into_iter());
1314        //send and send_all consume the sink
1315        _ = sock_sink.send_all(&mut stream).await;
1316    }
1317
1318    async fn send_control_frame<S: Sink<Bytes, Error = std::io::Error> + Unpin>(
1319        sock_sink: &mut S,
1320        frame: Bytes,
1321    ) {
1322        send_data_frames(sock_sink, vec![Ok(Bytes::new()), Ok(frame)]).await; //send empty frame to say we are control frame
1323    }
1324
1325    fn create_control_frame(header: ControlHeader) -> Bytes {
1326        Bytes::from(header.to_u32().to_be_bytes().to_vec())
1327    }
1328
1329    fn create_control_frame_with_content(
1330        header: ControlHeader,
1331        content_types: Vec<Bytes>,
1332    ) -> Bytes {
1333        let mut frame = BytesMut::from(&header.to_u32().to_be_bytes()[..]);
1334        for content_type in content_types {
1335            frame.extend(ControlField::ContentType.to_u32().to_be_bytes());
1336            frame.extend((content_type.len() as u32).to_be_bytes());
1337            frame.extend(content_type.clone());
1338        }
1339        Bytes::from(frame)
1340    }
1341
1342    fn assert_accept_frame(frame: &mut BytesMut, expected_content_type: Bytes) {
1343        //frame should start with 4 bytes saying ACCEPT
1344
1345        assert_eq!(&frame[..4], &ControlHeader::Accept.to_u32().to_be_bytes(),);
1346        frame.advance(4);
1347        //next should be content type field
1348        assert_eq!(
1349            &frame[..4],
1350            &ControlField::ContentType.to_u32().to_be_bytes(),
1351        );
1352        frame.advance(4);
1353        //next should be length of content_type
1354        assert_eq!(
1355            &frame[..4],
1356            &(expected_content_type.len() as u32).to_be_bytes(),
1357        );
1358        frame.advance(4);
1359        //rest should be content type
1360        assert_eq!(&frame[..], &expected_content_type[..]);
1361    }
1362
1363    fn create_frame_handler(multithreaded: bool) -> impl UnixFrameHandler + Send + Sync + Clone {
1364        MockUnixFrameHandler::new("test_content".to_string(), multithreaded, move || {})
1365    }
1366
1367    fn create_tcp_frame_handler(
1368        addr: SocketAddr,
1369        multithreaded: bool,
1370        permit_origin: Option<Vec<IpNet>>,
1371    ) -> impl TcpFrameHandler + Send + Sync + Clone {
1372        MockTcpFrameHandler::new(
1373            addr,
1374            "test_content".to_string(),
1375            multithreaded,
1376            move || {},
1377            permit_origin,
1378        )
1379    }
1380
1381    async fn signal_shutdown(source_name: &str, shutdown: &mut SourceShutdownCoordinator) {
1382        // Now signal to the Source to shut down.
1383        let deadline = Instant::now() + Duration::from_secs(10);
1384        let id = ComponentKey::from(source_name);
1385        let shutdown_complete = shutdown.shutdown_source(&id, deadline);
1386        let shutdown_success = shutdown_complete.await;
1387        assert!(shutdown_success);
1388    }
1389
1390    async fn test_normal_framestream<
1391        T: Sink<Bytes, Error = std::io::Error> + Unpin,
1392        U: Stream<Item = Result<BytesMut, std::io::Error>> + Unpin,
1393        V: Stream<Item = Event> + Unpin,
1394    >(
1395        source_name: &str,
1396        mut sock_sink: T,
1397        mut sock_stream: U,
1398        rx: V,
1399        mut shutdown: SourceShutdownCoordinator,
1400        source_handle: JoinHandle<Result<(), ()>>,
1401    ) {
1402        //1 - send READY frame (with content_type)
1403        let content_type = Bytes::from(&b"test_content"[..]);
1404        let ready_msg =
1405            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
1406        send_control_frame(&mut sock_sink, ready_msg).await;
1407
1408        //2 - wait for ACCEPT frame
1409        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1410        //take second element, because first will be empty (signifying control frame)
1411        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1412        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1413
1414        //3 - send START frame
1415        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Start)).await;
1416
1417        //4 - send data
1418        send_data_frames(
1419            &mut sock_sink,
1420            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
1421        )
1422        .await;
1423        let events = collect_n(rx, 2).await;
1424
1425        //5 - send STOP frame
1426        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;
1427
1428        let message_key = log_schema().message_key().unwrap().to_string();
1429        assert!(
1430            events
1431                .iter()
1432                .any(|e| e.as_log()[&message_key] == "hello".into())
1433        );
1434        assert!(
1435            events
1436                .iter()
1437                .any(|e| e.as_log()[&message_key] == "world".into())
1438        );
1439
1440        drop(sock_stream); //explicitly drop the stream so we don't get warnings about not using it
1441
1442        // Ensure source actually shut down successfully.
1443        signal_shutdown(source_name, &mut shutdown).await;
1444        _ = source_handle.await.unwrap();
1445    }
1446
1447    async fn test_multiple_content_types<
1448        T: Sink<Bytes, Error = std::io::Error> + Unpin,
1449        U: Stream<Item = Result<BytesMut, std::io::Error>> + Unpin,
1450    >(
1451        source_name: &str,
1452        mut sock_sink: T,
1453        mut sock_stream: U,
1454        mut shutdown: SourceShutdownCoordinator,
1455        source_handle: JoinHandle<Result<(), ()>>,
1456    ) {
1457        //1 - send READY frame (with content_type)
1458        let content_type = Bytes::from(&b"test_content"[..]);
1459        let ready_msg = create_control_frame_with_content(
1460            ControlHeader::Ready,
1461            vec![Bytes::from(&b"test_content2"[..]), content_type.clone()],
1462        ); //can have multiple content types
1463        send_control_frame(&mut sock_sink, ready_msg).await;
1464
1465        //2 - wait for ACCEPT frame
1466        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1467
1468        //take second element, because first will be empty (signifying control frame)
1469        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1470        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1471
1472        drop(sock_stream); //explicitly drop the stream so we don't get warnings about not using it
1473
1474        // Ensure source actually shut down successfully.
1475        signal_shutdown(source_name, &mut shutdown).await;
1476        _ = source_handle.await.unwrap();
1477    }
1478
1479    #[tokio::test(flavor = "multi_thread")]
1480    #[should_panic]
1481    async fn blocked_framestream_tcp() {
1482        let source_name = "test_source";
1483        let (tx, rx) = SourceSender::new_test();
1484        let addr = next_addr();
1485        let (source_handle, shutdown) = init_framestream_tcp(
1486            source_name,
1487            &addr,
1488            create_tcp_frame_handler(addr, false, Some(vec![])),
1489            tx,
1490        );
1491        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1492
1493        test_normal_framestream(
1494            source_name,
1495            sock_sink,
1496            sock_stream,
1497            rx,
1498            shutdown,
1499            source_handle,
1500        )
1501        .await;
1502    }
1503
1504    #[tokio::test(flavor = "multi_thread")]
1505    async fn normal_framestream_singlethreaded_tcp() {
1506        let source_name = "test_source";
1507        let (tx, rx) = SourceSender::new_test();
1508        let addr = next_addr();
1509        let (source_handle, shutdown) = init_framestream_tcp(
1510            source_name,
1511            &addr,
1512            create_tcp_frame_handler(addr, false, None),
1513            tx,
1514        );
1515        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1516
1517        test_normal_framestream(
1518            source_name,
1519            sock_sink,
1520            sock_stream,
1521            rx,
1522            shutdown,
1523            source_handle,
1524        )
1525        .await;
1526    }
1527
1528    #[tokio::test(flavor = "multi_thread")]
1529    async fn normal_framestream_singlethreaded_unix() {
1530        let source_name = "test_source";
1531        let (tx, rx) = SourceSender::new_test();
1532        let (path, source_handle, shutdown) =
1533            init_framestream_unix(source_name, create_frame_handler(false), tx);
1534        let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1535
1536        test_normal_framestream(
1537            source_name,
1538            sock_sink,
1539            sock_stream,
1540            rx,
1541            shutdown,
1542            source_handle,
1543        )
1544        .await;
1545    }
1546
1547    #[tokio::test(flavor = "multi_thread")]
1548    async fn normal_framestream_multithreaded_tcp() {
1549        let source_name = "test_source";
1550        let (tx, rx) = SourceSender::new_test();
1551        let addr = next_addr();
1552        let (source_handle, shutdown) = init_framestream_tcp(
1553            source_name,
1554            &addr,
1555            create_tcp_frame_handler(addr, true, None),
1556            tx,
1557        );
1558        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1559
1560        test_normal_framestream(
1561            source_name,
1562            sock_sink,
1563            sock_stream,
1564            rx,
1565            shutdown,
1566            source_handle,
1567        )
1568        .await;
1569    }
1570
1571    #[tokio::test(flavor = "multi_thread")]
1572    async fn normal_framestream_multithreaded_unix() {
1573        let source_name = "test_source";
1574        let (tx, rx) = SourceSender::new_test();
1575        let (path, source_handle, shutdown) =
1576            init_framestream_unix(source_name, create_frame_handler(true), tx);
1577        let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1578
1579        test_normal_framestream(
1580            source_name,
1581            sock_sink,
1582            sock_stream,
1583            rx,
1584            shutdown,
1585            source_handle,
1586        )
1587        .await;
1588    }
1589
1590    #[tokio::test(flavor = "multi_thread")]
1591    async fn multiple_content_types_tcp() {
1592        let source_name = "test_source";
1593        let (tx, _) = SourceSender::new_test();
1594        let addr = next_addr();
1595        let (source_handle, shutdown) = init_framestream_tcp(
1596            source_name,
1597            &addr,
1598            create_tcp_frame_handler(addr, false, None),
1599            tx,
1600        );
1601        let (sock_sink, sock_stream) = make_tcp_stream(addr).await.split();
1602
1603        test_multiple_content_types(source_name, sock_sink, sock_stream, shutdown, source_handle)
1604            .await;
1605    }
1606
1607    #[tokio::test(flavor = "multi_thread")]
1608    async fn multiple_content_types_unix() {
1609        let source_name = "test_source";
1610        let (tx, _) = SourceSender::new_test();
1611        let (path, source_handle, shutdown) =
1612            init_framestream_unix(source_name, create_frame_handler(false), tx);
1613        let (sock_sink, sock_stream) = make_unix_stream(path).await.split();
1614
1615        test_multiple_content_types(source_name, sock_sink, sock_stream, shutdown, source_handle)
1616            .await;
1617    }
1618
1619    #[tokio::test(flavor = "multi_thread")]
1620    async fn wrong_content_type() {
1621        let source_name = "test_source";
1622        let (tx, _) = SourceSender::new_test();
1623        let (path, source_handle, mut shutdown) =
1624            init_framestream_unix(source_name, create_frame_handler(false), tx);
1625        let (mut sock_sink, mut sock_stream) = make_unix_stream(path).await.split();
1626
1627        //1 - send READY frame (with WRONG content_type)
1628        let ready_msg = create_control_frame_with_content(
1629            ControlHeader::Ready,
1630            vec![Bytes::from(&b"test_content2"[..])],
1631        ); //can have multiple content types
1632        send_control_frame(&mut sock_sink, ready_msg).await;
1633
1634        //2 - send READY frame (with RIGHT content_type)
1635        let content_type = Bytes::from(&b"test_content"[..]);
1636        let ready_msg =
1637            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
1638        send_control_frame(&mut sock_sink, ready_msg).await;
1639
1640        //3 - wait for ACCEPT frame
1641        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1642
1643        //take second element, because first will be empty (signifying control frame)
1644        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1645        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1646
1647        drop(sock_stream); //explicitly drop the stream so we don't get warnings about not using it
1648
1649        // Ensure source actually shut down successfully.
1650        signal_shutdown(source_name, &mut shutdown).await;
1651        _ = source_handle.await.unwrap();
1652    }
1653
1654    #[tokio::test(flavor = "multi_thread")]
1655    async fn data_too_soon() {
1656        let source_name = "test_source";
1657        let (tx, rx) = SourceSender::new_test();
1658        let (path, source_handle, mut shutdown) =
1659            init_framestream_unix(source_name, create_frame_handler(false), tx);
1660        let (mut sock_sink, mut sock_stream) = make_unix_stream(path).await.split();
1661
1662        //1 - send data frame (too soon!)
1663        send_data_frames(
1664            &mut sock_sink,
1665            vec![Ok(Bytes::from("bad")), Ok(Bytes::from("data"))],
1666        )
1667        .await;
1668
1669        //2 - send READY frame (with content_type)
1670        let content_type = Bytes::from(&b"test_content"[..]);
1671        let ready_msg =
1672            create_control_frame_with_content(ControlHeader::Ready, vec![content_type.clone()]);
1673        send_control_frame(&mut sock_sink, ready_msg).await;
1674
1675        //3 - wait for ACCEPT frame
1676        let mut frame_vec = collect_n_stream(&mut sock_stream, 2).await;
1677
1678        //take second element, because first will be empty (signifying control frame)
1679        assert_eq!(frame_vec[0].as_ref().unwrap().len(), 0);
1680        assert_accept_frame(frame_vec[1].as_mut().unwrap(), content_type);
1681
1682        //4 - send START frame
1683        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Start)).await;
1684
1685        //5 - send data (will go through)
1686        send_data_frames(
1687            &mut sock_sink,
1688            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
1689        )
1690        .await;
1691        let events = collect_n(rx, 2).await;
1692
1693        //6 - send STOP frame
1694        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;
1695
1696        assert_eq!(
1697            events[0].as_log()[log_schema().message_key().unwrap().to_string()],
1698            "hello".into(),
1699        );
1700        assert_eq!(
1701            events[1].as_log()[log_schema().message_key().unwrap().to_string()],
1702            "world".into(),
1703        );
1704
1705        drop(sock_stream); //explicitly drop the stream so we don't get warnings about not using it
1706
1707        // Ensure source actually shut down successfully.
1708        signal_shutdown(source_name, &mut shutdown).await;
1709        _ = source_handle.await.unwrap();
1710    }
1711
1712    #[tokio::test(flavor = "multi_thread")]
1713    async fn unidirectional_framestream() {
1714        let source_name = "test_source";
1715        let (tx, rx) = SourceSender::new_test();
1716        let (path, source_handle, mut shutdown) =
1717            init_framestream_unix(source_name, create_frame_handler(false), tx);
1718        let (mut sock_sink, _) = make_unix_stream(path).await.split();
1719
1720        //1 - send START frame (with content_type)
1721        let content_type = Bytes::from(&b"test_content"[..]);
1722        let start_msg = create_control_frame_with_content(ControlHeader::Start, vec![content_type]);
1723        send_control_frame(&mut sock_sink, start_msg).await;
1724
1725        //4 - send data
1726        send_data_frames(
1727            &mut sock_sink,
1728            vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))],
1729        )
1730        .await;
1731        let events = collect_n(rx, 2).await;
1732
1733        //5 - send STOP frame
1734        send_control_frame(&mut sock_sink, create_control_frame(ControlHeader::Stop)).await;
1735
1736        assert_eq!(
1737            events[0].as_log()[log_schema().message_key().unwrap().to_string()],
1738            "hello".into(),
1739        );
1740        assert_eq!(
1741            events[1].as_log()[log_schema().message_key().unwrap().to_string()],
1742            "world".into(),
1743        );
1744
1745        // Ensure source actually shut down successfully.
1746        signal_shutdown(source_name, &mut shutdown).await;
1747        _ = source_handle.await.unwrap();
1748    }
1749
1750    #[tokio::test(flavor = "multi_thread")]
1751    async fn test_spawn_event_handling_tasks() {
1752        let (out, rx) = SourceSender::new_test();
1753
1754        let max_frame_handling_tasks = 20;
1755        let active_task_nums = Arc::new(AtomicUsize::new(0));
1756        let active_task_nums_copy = Arc::clone(&active_task_nums);
1757        let max_task_nums_reached = Arc::new(AtomicUsize::new(0));
1758        let max_task_nums_reached_copy = Arc::clone(&max_task_nums_reached);
1759
1760        let mut join_handles = vec![];
1761        let active_task_nums_copy_2 = Arc::clone(&active_task_nums_copy);
1762        let extra_routine = move || {
1763            thread::sleep(Duration::from_millis(10));
1764            max_task_nums_reached_copy.fetch_max(
1765                active_task_nums_copy_2.load(Ordering::Acquire),
1766                Ordering::AcqRel,
1767            );
1768        };
1769
1770        let total_events = max_frame_handling_tasks * 10;
1771
1772        join_handles.push(tokio::spawn(async move {
1773            future::ready({
1774                let events = collect_n(rx, total_events).await;
1775                assert_eq!(total_events, events.len(), "Missed events");
1776            })
1777            .await;
1778        }));
1779
1780        for i in 0..total_events {
1781            join_handles.push(
1782                spawn_event_handling_tasks(
1783                    Bytes::from(format!("event_{i}")),
1784                    MockFrameHandler::new("test_content".to_string(), true, extra_routine.clone()),
1785                    out.clone(),
1786                    None,
1787                    Arc::clone(&active_task_nums_copy),
1788                    max_frame_handling_tasks,
1789                )
1790                .await,
1791            );
1792        }
1793
1794        future::join_all(join_handles).await;
1795
1796        let final_task_nums = active_task_nums.load(Ordering::Acquire);
1797        assert_eq!(
1798            0, final_task_nums,
1799            "There should be NO left-over tasks at the end"
1800        );
1801
1802        let max_task_nums_reached_value = max_task_nums_reached.load(Ordering::Acquire);
1803        assert!(
1804            max_task_nums_reached_value > 1,
1805            "MultiThreaded mode does NOT work"
1806        );
1807        assert!(
1808            (max_task_nums_reached_value - max_frame_handling_tasks) < 2,
1809            "Max number of tasks at any given time should NOT Exceed max_frame_handling_tasks too much"
1810        );
1811    }
1812}