vector/sources/util/
wrappers.rs1use bytes::BytesMut;
2use futures::Stream;
3use std::{
4 io,
5 pin::Pin,
6 task::{Context, Poll},
7};
8use tokio_util::codec::Decoder;
9use vector_lib::codecs::DecoderFramedRead;
10
11use pin_project::pin_project;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13
14pub trait AfterReadExt {
15 fn after_read<F>(self, after_read: F) -> AfterRead<Self, F>
16 where
17 Self: Sized;
18}
19
20impl<T: AsyncRead + AsyncWrite> AfterReadExt for T {
21 fn after_read<F>(self, after_read: F) -> AfterRead<Self, F> {
22 AfterRead::new(self, after_read)
23 }
24}
25
26#[pin_project]
29pub struct AfterRead<T, F> {
30 #[pin]
31 inner: T,
32 after_read: F,
33}
34
35impl<T, F> AfterRead<T, F> {
36 pub const fn new(inner: T, after_read: F) -> Self {
37 Self { inner, after_read }
38 }
39
40 #[cfg(feature = "listenfd")]
41 pub const fn get_ref(&self) -> &T {
42 &self.inner
43 }
44
45 #[cfg(all(unix, feature = "sources-utils-net-unix"))]
46 pub const fn get_mut_ref(&mut self) -> &mut T {
47 &mut self.inner
48 }
49}
50
51impl<T: AsyncRead, F> AsyncRead for AfterRead<T, F>
52where
53 F: Fn(usize),
54{
55 fn poll_read(
56 self: Pin<&mut Self>,
57 cx: &mut Context<'_>,
58 buf: &mut ReadBuf<'_>,
59 ) -> Poll<tokio::io::Result<()>> {
60 let before = buf.filled().len();
61 let this = self.project();
62 let result = this.inner.poll_read(cx, buf);
63 if let Poll::Ready(Ok(())) = result {
64 (this.after_read)(buf.filled().len() - before);
65 }
66 result
67 }
68}
69
70impl<T: AsyncWrite, F> AsyncWrite for AfterRead<T, F> {
71 fn poll_write(
72 self: Pin<&mut Self>,
73 cx: &mut Context<'_>,
74 buf: &[u8],
75 ) -> Poll<Result<usize, io::Error>> {
76 self.project().inner.poll_write(cx, buf)
77 }
78
79 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
80 self.project().inner.poll_flush(cx)
81 }
82
83 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
84 self.project().inner.poll_shutdown(cx)
85 }
86}
87
/// Error type of [`LenientFramedReadDecoder`]. It keeps I/O failures apart
/// from errors raised by the wrapped decoder.
///
/// Holding I/O errors in their own variant lets callers inspect the
/// [`io::ErrorKind`] before the error is collapsed into the wrapped decoder's
/// error type.
#[derive(Debug)]
pub enum DecoderError<E> {
    /// An I/O error, created through the `From<io::Error>` conversion.
    IO(io::Error),
    /// An error returned by the wrapped decoder.
    Other(E),
}
92
93impl<E> DecoderError<E>
94where
95 E: From<io::Error>,
96{
97 fn into_inner(self) -> E {
98 match self {
99 DecoderError::IO(e) => E::from(e),
100 DecoderError::Other(e) => e,
101 }
102 }
103}
104
105impl<E> From<io::Error> for DecoderError<E> {
106 fn from(e: io::Error) -> Self {
107 DecoderError::IO(e)
108 }
109}
110
/// Decoder adapter whose errors are wrapped in [`DecoderError`], so I/O
/// failures can be told apart from the wrapped decoder's own failures.
pub struct LenientFramedReadDecoder<D> {
    /// The decoder that performs the actual framing.
    inner: D,
}
114
115impl<D> LenientFramedReadDecoder<D>
116where
117 D: Decoder,
118{
119 pub const fn new(inner: D) -> Self {
120 Self { inner }
121 }
122}
123
124impl<D> Decoder for LenientFramedReadDecoder<D>
125where
126 D: Decoder,
127 D::Error: From<io::Error>,
128{
129 type Item = D::Item;
130 type Error = DecoderError<D::Error>;
131
132 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
133 self.inner.decode(src).map_err(DecoderError::Other)
134 }
135
136 fn decode_eof(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
137 self.inner.decode_eof(src).map_err(DecoderError::Other)
138 }
139}
140
141#[pin_project]
144pub struct LenientFramedRead<T, D> {
145 #[pin]
146 inner: DecoderFramedRead<T, LenientFramedReadDecoder<D>>,
147}
148
149impl<T, D> LenientFramedRead<T, D>
150where
151 T: AsyncRead,
152 D: Decoder,
153{
154 pub fn new(inner: T, decoder: D) -> Self {
156 Self {
157 inner: DecoderFramedRead::new(inner, LenientFramedReadDecoder::new(decoder)),
158 }
159 }
160
161 pub fn get_ref(&self) -> &T {
168 self.inner.get_ref()
169 }
170
171 pub fn get_mut(&mut self) -> &mut T {
178 self.inner.get_mut()
179 }
180}
181
182impl<T, D> Stream for LenientFramedRead<T, D>
183where
184 T: AsyncRead + Unpin,
185 D: Decoder,
186{
187 type Item = Result<D::Item, D::Error>;
188
189 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
190 let mut this = self.project();
191 match this.inner.as_mut().poll_next(cx) {
192 Poll::Ready(Some(Err(DecoderError::IO(e))))
193 if e.kind() == io::ErrorKind::ConnectionReset =>
194 {
195 let buffer = this.inner.read_buffer();
196
197 if buffer.is_empty() {
198 Poll::Ready(None)
200 } else {
201 Poll::Ready(Some(Err(D::Error::from(e))))
203 }
204 }
205 other => other.map_err(|e| e.into_inner()),
206 }
207 }
208}