1#![allow(missing_docs)]
2use std::{
3 future::Future,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use pin_project::pin_project;
9use tokio::io::{AsyncRead, ReadBuf, Result as IoResult};
10
11pub trait VecAsyncReadExt: AsyncRead {
12 fn allow_read_until<F>(self, until: F) -> AllowReadUntil<Self, F>
14 where
15 Self: Sized,
16 F: Future<Output = ()>,
17 {
18 AllowReadUntil {
19 reader: self,
20 until,
21 }
22 }
23}
24
25impl<S> VecAsyncReadExt for S where S: AsyncRead {}
26
27#[pin_project]
29#[derive(Clone, Debug)]
30pub struct AllowReadUntil<S, F> {
31 #[pin]
32 reader: S,
33 #[pin]
34 until: F,
35}
36
37impl<S, F> AllowReadUntil<S, F> {
38 pub const fn get_ref(&self) -> &S {
39 &self.reader
40 }
41
42 pub const fn get_mut(&mut self) -> &mut S {
43 &mut self.reader
44 }
45}
46
47impl<S, F> AsyncRead for AllowReadUntil<S, F>
48where
49 S: AsyncRead,
50 F: Future<Output = ()>,
51{
52 fn poll_read(
53 self: Pin<&mut Self>,
54 cx: &mut Context,
55 buf: &mut ReadBuf<'_>,
56 ) -> Poll<IoResult<()>> {
57 let this = self.project();
58 match this.until.poll(cx) {
59 Poll::Ready(_) => Poll::Ready(Ok(())),
60 Poll::Pending => this.reader.poll_read(cx, buf),
61 }
62 }
63}
64
#[cfg(test)]
mod tests {
    use futures::FutureExt;
    use tokio::{
        fs::{remove_file, File},
        io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
    };

    use super::*;
    use crate::{shutdown::ShutdownSignal, test_util::temp_file};

    /// With a signal that never fires, every line written to the file is read back unchanged.
    #[tokio::test]
    async fn test_read_line_without_shutdown() {
        let shutdown = ShutdownSignal::noop();
        let temp_path = temp_file();
        let write_file = File::create(&temp_path).await.unwrap();
        let read_file = File::open(&temp_path).await.unwrap();

        let read_file = read_file.allow_read_until(shutdown.clone().map(|_| ()));

        let mut reader = BufReader::new(read_file);
        let mut writer = BufWriter::new(write_file);

        writer.write_all(b"First line\n").await.unwrap();
        writer.flush().await.unwrap();

        // Unwrap so an I/O error fails the test loudly instead of showing up as a content mismatch.
        let mut line_one = String::new();
        let read = reader.read_line(&mut line_one).await.unwrap();

        assert_eq!("First line\n".len(), read);
        assert_eq!("First line\n", line_one);

        writer.write_all(b"Second line\n").await.unwrap();
        writer.flush().await.unwrap();

        let mut line_two = String::new();
        let read = reader.read_line(&mut line_two).await.unwrap();

        assert_eq!("Second line\n".len(), read);
        assert_eq!("Second line\n", line_two);

        remove_file(temp_path).await.unwrap();
    }

    /// Once the shutdown trigger is dropped, reads report end-of-file even though more data
    /// has been written to the file.
    #[tokio::test]
    async fn test_read_line_with_shutdown() {
        let (trigger_shutdown, shutdown, _) = ShutdownSignal::new_wired();
        let temp_path = temp_file();
        let write_file = File::create(&temp_path).await.unwrap();
        let read_file = File::open(&temp_path).await.unwrap();

        let read_file = read_file.allow_read_until(shutdown.clone().map(|_| ()));

        let mut reader = BufReader::new(read_file);
        let mut writer = BufWriter::new(write_file);

        writer.write_all(b"First line\n").await.unwrap();
        writer.flush().await.unwrap();

        let mut line_one = String::new();
        let read = reader.read_line(&mut line_one).await.unwrap();

        assert_eq!("First line\n".len(), read);
        assert_eq!("First line\n", line_one);

        drop(trigger_shutdown);

        writer.write_all(b"Second line\n").await.unwrap();
        writer.flush().await.unwrap();

        let mut line_two = String::new();
        let read = reader.read_line(&mut line_two).await.unwrap();

        assert_eq!(0, read, "reads after shutdown should report end-of-file");
        assert_eq!("", line_two);

        remove_file(temp_path).await.unwrap();
    }
}