vector/sources/util/
wrappers.rs1use std::{
2 io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use pin_project::pin_project;
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9
10pub trait AfterReadExt {
11 fn after_read<F>(self, after_read: F) -> AfterRead<Self, F>
12 where
13 Self: Sized;
14}
15
16impl<T: AsyncRead + AsyncWrite> AfterReadExt for T {
17 fn after_read<F>(self, after_read: F) -> AfterRead<Self, F> {
18 AfterRead::new(self, after_read)
19 }
20}
21
22#[pin_project]
25pub struct AfterRead<T, F> {
26 #[pin]
27 inner: T,
28 after_read: F,
29}
30
31impl<T, F> AfterRead<T, F> {
32 pub const fn new(inner: T, after_read: F) -> Self {
33 Self { inner, after_read }
34 }
35
36 #[cfg(feature = "listenfd")]
37 pub const fn get_ref(&self) -> &T {
38 &self.inner
39 }
40
41 #[cfg(all(unix, feature = "sources-utils-net-unix"))]
42 pub const fn get_mut_ref(&mut self) -> &mut T {
43 &mut self.inner
44 }
45}
46
47impl<T: AsyncRead, F> AsyncRead for AfterRead<T, F>
48where
49 F: Fn(usize),
50{
51 fn poll_read(
52 self: Pin<&mut Self>,
53 cx: &mut Context<'_>,
54 buf: &mut ReadBuf<'_>,
55 ) -> Poll<tokio::io::Result<()>> {
56 let before = buf.filled().len();
57 let this = self.project();
58 let result = this.inner.poll_read(cx, buf);
59 if let Poll::Ready(Ok(())) = result {
60 (this.after_read)(buf.filled().len() - before);
61 }
62 result
63 }
64}
65
66impl<T: AsyncWrite, F> AsyncWrite for AfterRead<T, F> {
67 fn poll_write(
68 self: Pin<&mut Self>,
69 cx: &mut Context<'_>,
70 buf: &[u8],
71 ) -> Poll<Result<usize, io::Error>> {
72 self.project().inner.poll_write(cx, buf)
73 }
74
75 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
76 self.project().inner.poll_flush(cx)
77 }
78
79 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
80 self.project().inner.poll_shutdown(cx)
81 }
82}