vector/
async_read.rs

1#![allow(missing_docs)]
2use std::{
3    future::Future,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use pin_project::pin_project;
9use tokio::io::{AsyncRead, ReadBuf, Result as IoResult};
10
11pub trait VecAsyncReadExt: AsyncRead {
12    /// Read data from this reader until the given future resolves.
13    fn allow_read_until<F>(self, until: F) -> AllowReadUntil<Self, F>
14    where
15        Self: Sized,
16        F: Future<Output = ()>,
17    {
18        AllowReadUntil {
19            reader: self,
20            until,
21        }
22    }
23}
24
25impl<S> VecAsyncReadExt for S where S: AsyncRead {}
26
27/// A AsyncRead combinator which reads from a reader until a future resolves.
28#[pin_project]
29#[derive(Clone, Debug)]
30pub struct AllowReadUntil<S, F> {
31    #[pin]
32    reader: S,
33    #[pin]
34    until: F,
35}
36
37impl<S, F> AllowReadUntil<S, F> {
38    pub const fn get_ref(&self) -> &S {
39        &self.reader
40    }
41
42    pub const fn get_mut(&mut self) -> &mut S {
43        &mut self.reader
44    }
45}
46
47impl<S, F> AsyncRead for AllowReadUntil<S, F>
48where
49    S: AsyncRead,
50    F: Future<Output = ()>,
51{
52    fn poll_read(
53        self: Pin<&mut Self>,
54        cx: &mut Context,
55        buf: &mut ReadBuf<'_>,
56    ) -> Poll<IoResult<()>> {
57        let this = self.project();
58        match this.until.poll(cx) {
59            Poll::Ready(_) => Poll::Ready(Ok(())),
60            Poll::Pending => this.reader.poll_read(cx, buf),
61        }
62    }
63}
64
#[cfg(test)]
mod tests {
    use futures::FutureExt;
    use tokio::{
        fs::{remove_file, File},
        io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
    };

    use super::*;
    use crate::{shutdown::ShutdownSignal, test_util::temp_file};

    #[tokio::test]
    async fn test_read_line_without_shutdown() {
        let shutdown = ShutdownSignal::noop();
        let temp_path = temp_file();
        let write_file = File::create(&temp_path).await.unwrap();
        let read_file = File::open(&temp_path).await.unwrap();

        // Wrap the file so reads stop once the (never-firing) shutdown signal resolves.
        let read_file = read_file.allow_read_until(shutdown.map(|_| ()));

        let mut reader = BufReader::new(read_file);
        let mut writer = BufWriter::new(write_file);

        writer.write_all(b"First line\n").await.unwrap();
        writer.flush().await.unwrap();

        // Exercise one of the AsyncBufRead extension functions. Unwrap so that an
        // I/O error is reported as such rather than as a string mismatch.
        let mut line_one = String::new();
        reader.read_line(&mut line_one).await.unwrap();

        assert_eq!("First line\n", line_one);

        writer.write_all(b"Second line\n").await.unwrap();
        writer.flush().await.unwrap();

        let mut line_two = String::new();
        reader.read_line(&mut line_two).await.unwrap();

        assert_eq!("Second line\n", line_two);

        remove_file(temp_path).await.unwrap();
    }

    #[tokio::test]
    async fn test_read_line_with_shutdown() {
        let (trigger_shutdown, shutdown, _) = ShutdownSignal::new_wired();
        let temp_path = temp_file();
        let write_file = File::create(&temp_path).await.unwrap();
        let read_file = File::open(&temp_path).await.unwrap();

        // Wrap the file so reads stop once the shutdown signal resolves.
        let read_file = read_file.allow_read_until(shutdown.map(|_| ()));

        let mut reader = BufReader::new(read_file);
        let mut writer = BufWriter::new(write_file);

        writer.write_all(b"First line\n").await.unwrap();
        writer.flush().await.unwrap();

        // Exercise one of the AsyncBufRead extension functions.
        let mut line_one = String::new();
        reader.read_line(&mut line_one).await.unwrap();

        assert_eq!("First line\n", line_one);

        // Dropping the trigger fires the shutdown signal.
        drop(trigger_shutdown);

        writer.write_all(b"Second line\n").await.unwrap();
        writer.flush().await.unwrap();

        // Even though more data was written, the wrapper must now report EOF.
        let mut line_two = String::new();
        let read = reader.read_line(&mut line_two).await.unwrap();

        assert_eq!(0, read);
        assert_eq!("", line_two);

        remove_file(temp_path).await.unwrap();
    }
}