vector/sources/util/wrappers.rs

1use bytes::BytesMut;
2use futures::Stream;
3use std::{
4    io,
5    pin::Pin,
6    task::{Context, Poll},
7};
8use tokio_util::codec::Decoder;
9use vector_lib::codecs::DecoderFramedRead;
10
11use pin_project::pin_project;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13
14pub trait AfterReadExt {
15    fn after_read<F>(self, after_read: F) -> AfterRead<Self, F>
16    where
17        Self: Sized;
18}
19
20impl<T: AsyncRead + AsyncWrite> AfterReadExt for T {
21    fn after_read<F>(self, after_read: F) -> AfterRead<Self, F> {
22        AfterRead::new(self, after_read)
23    }
24}
25
26/// This wraps the inner socket and emits `BytesReceived` with the
27/// actual number of bytes read before handling framing.
28#[pin_project]
29pub struct AfterRead<T, F> {
30    #[pin]
31    inner: T,
32    after_read: F,
33}
34
35impl<T, F> AfterRead<T, F> {
36    pub const fn new(inner: T, after_read: F) -> Self {
37        Self { inner, after_read }
38    }
39
40    #[cfg(feature = "listenfd")]
41    pub const fn get_ref(&self) -> &T {
42        &self.inner
43    }
44
45    #[cfg(all(unix, feature = "sources-utils-net-unix"))]
46    pub const fn get_mut_ref(&mut self) -> &mut T {
47        &mut self.inner
48    }
49}
50
51impl<T: AsyncRead, F> AsyncRead for AfterRead<T, F>
52where
53    F: Fn(usize),
54{
55    fn poll_read(
56        self: Pin<&mut Self>,
57        cx: &mut Context<'_>,
58        buf: &mut ReadBuf<'_>,
59    ) -> Poll<tokio::io::Result<()>> {
60        let before = buf.filled().len();
61        let this = self.project();
62        let result = this.inner.poll_read(cx, buf);
63        if let Poll::Ready(Ok(())) = result {
64            (this.after_read)(buf.filled().len() - before);
65        }
66        result
67    }
68}
69
70impl<T: AsyncWrite, F> AsyncWrite for AfterRead<T, F> {
71    fn poll_write(
72        self: Pin<&mut Self>,
73        cx: &mut Context<'_>,
74        buf: &[u8],
75    ) -> Poll<Result<usize, io::Error>> {
76        self.project().inner.poll_write(cx, buf)
77    }
78
79    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
80        self.project().inner.poll_flush(cx)
81    }
82
83    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
84        self.project().inner.poll_shutdown(cx)
85    }
86}
87
/// Error type of [`LenientFramedReadDecoder`].
///
/// It keeps I/O failures (`IO`) separate from errors raised by the wrapped
/// decoder (`Other`). This lets a reader inspect I/O errors, such as a
/// connection reset, before collapsing everything back to the decoder's own
/// error type.
#[derive(Debug)]
pub enum DecoderError<E> {
    /// An I/O error, produced through the `From<io::Error>` conversion.
    IO(io::Error),
    /// An error returned by the wrapped decoder itself.
    Other(E),
}

impl<E> DecoderError<E>
where
    E: From<io::Error>,
{
    /// Converts back into the wrapped decoder's error type `E`.
    ///
    /// I/O errors are converted with `E::from`, and decoder errors are
    /// returned as-is.
    fn into_inner(self) -> E {
        match self {
            DecoderError::IO(e) => E::from(e),
            DecoderError::Other(e) => e,
        }
    }
}

/// Any `io::Error` becomes the `IO` variant.
impl<E> From<io::Error> for DecoderError<E> {
    fn from(e: io::Error) -> Self {
        DecoderError::IO(e)
    }
}
110
111pub struct LenientFramedReadDecoder<D> {
112    inner: D,
113}
114
115impl<D> LenientFramedReadDecoder<D>
116where
117    D: Decoder,
118{
119    pub const fn new(inner: D) -> Self {
120        Self { inner }
121    }
122}
123
124impl<D> Decoder for LenientFramedReadDecoder<D>
125where
126    D: Decoder,
127    D::Error: From<io::Error>,
128{
129    type Item = D::Item;
130    type Error = DecoderError<D::Error>;
131
132    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
133        self.inner.decode(src).map_err(DecoderError::Other)
134    }
135
136    fn decode_eof(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
137        self.inner.decode_eof(src).map_err(DecoderError::Other)
138    }
139}
140
141/// A wrapper around an `FramedRead` that silently ignores `ConnectionReset`
142/// errors if the frame buffer is empty.
143#[pin_project]
144pub struct LenientFramedRead<T, D> {
145    #[pin]
146    inner: DecoderFramedRead<T, LenientFramedReadDecoder<D>>,
147}
148
149impl<T, D> LenientFramedRead<T, D>
150where
151    T: AsyncRead,
152    D: Decoder,
153{
154    /// Creates a new `LenientFramedRead` with the given `decoder`.
155    pub fn new(inner: T, decoder: D) -> Self {
156        Self {
157            inner: DecoderFramedRead::new(inner, LenientFramedReadDecoder::new(decoder)),
158        }
159    }
160
161    /// Returns a reference to the underlying I/O stream wrapped by
162    /// `FramedRead`.
163    ///
164    /// Note that care should be taken to not tamper with the underlying stream
165    /// of data coming in as it may corrupt the stream of frames otherwise
166    /// being worked with.
167    pub fn get_ref(&self) -> &T {
168        self.inner.get_ref()
169    }
170
171    /// Returns a mutable reference to the underlying I/O stream wrapped by
172    /// `FramedRead`.
173    ///
174    /// Note that care should be taken to not tamper with the underlying stream
175    /// of data coming in as it may corrupt the stream of frames otherwise
176    /// being worked with.
177    pub fn get_mut(&mut self) -> &mut T {
178        self.inner.get_mut()
179    }
180}
181
182impl<T, D> Stream for LenientFramedRead<T, D>
183where
184    T: AsyncRead + Unpin,
185    D: Decoder,
186{
187    type Item = Result<D::Item, D::Error>;
188
189    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
190        let mut this = self.project();
191        match this.inner.as_mut().poll_next(cx) {
192            Poll::Ready(Some(Err(DecoderError::IO(e))))
193                if e.kind() == io::ErrorKind::ConnectionReset =>
194            {
195                let buffer = this.inner.read_buffer();
196
197                if buffer.is_empty() {
198                    // Clean RST - no partial data, treat as EOF
199                    Poll::Ready(None)
200                } else {
201                    // Partial frame in buffer
202                    Poll::Ready(Some(Err(D::Error::from(e))))
203                }
204            }
205            other => other.map_err(|e| e.into_inner()),
206        }
207    }
208}