vector/sources/util/
wrappers.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use pin_project::pin_project;
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9
10pub trait AfterReadExt {
11    fn after_read<F>(self, after_read: F) -> AfterRead<Self, F>
12    where
13        Self: Sized;
14}
15
16impl<T: AsyncRead + AsyncWrite> AfterReadExt for T {
17    fn after_read<F>(self, after_read: F) -> AfterRead<Self, F> {
18        AfterRead::new(self, after_read)
19    }
20}
21
22/// This wraps the inner socket and emits `BytesReceived` with the
23/// actual number of bytes read before handling framing.
24#[pin_project]
25pub struct AfterRead<T, F> {
26    #[pin]
27    inner: T,
28    after_read: F,
29}
30
31impl<T, F> AfterRead<T, F> {
32    pub const fn new(inner: T, after_read: F) -> Self {
33        Self { inner, after_read }
34    }
35
36    #[cfg(feature = "listenfd")]
37    pub const fn get_ref(&self) -> &T {
38        &self.inner
39    }
40
41    #[cfg(all(unix, feature = "sources-utils-net-unix"))]
42    pub const fn get_mut_ref(&mut self) -> &mut T {
43        &mut self.inner
44    }
45}
46
47impl<T: AsyncRead, F> AsyncRead for AfterRead<T, F>
48where
49    F: Fn(usize),
50{
51    fn poll_read(
52        self: Pin<&mut Self>,
53        cx: &mut Context<'_>,
54        buf: &mut ReadBuf<'_>,
55    ) -> Poll<tokio::io::Result<()>> {
56        let before = buf.filled().len();
57        let this = self.project();
58        let result = this.inner.poll_read(cx, buf);
59        if let Poll::Ready(Ok(())) = result {
60            (this.after_read)(buf.filled().len() - before);
61        }
62        result
63    }
64}
65
66impl<T: AsyncWrite, F> AsyncWrite for AfterRead<T, F> {
67    fn poll_write(
68        self: Pin<&mut Self>,
69        cx: &mut Context<'_>,
70        buf: &[u8],
71    ) -> Poll<Result<usize, io::Error>> {
72        self.project().inner.poll_write(cx, buf)
73    }
74
75    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
76        self.project().inner.poll_flush(cx)
77    }
78
79    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
80        self.project().inner.poll_shutdown(cx)
81    }
82}