1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#![allow(missing_docs)]
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

use pin_project::pin_project;
use tokio::io::{AsyncRead, ReadBuf, Result as IoResult};

/// Extension trait that adds [`allow_read_until`](VecAsyncReadExt::allow_read_until) to every
/// [`AsyncRead`] type.
pub trait VecAsyncReadExt: AsyncRead {
    /// Read data from this reader until the given future resolves.
    ///
    /// While `until` is pending, reads are forwarded to `self`. Once `until` completes, the
    /// returned reader reports end-of-file (a successful read of zero bytes) on that read and on
    /// every subsequent one, and `until` is never polled again.
    fn allow_read_until<F>(self, until: F) -> AllowReadUntil<Self, F>
    where
        Self: Sized,
        F: Future<Output = ()>,
    {
        AllowReadUntil {
            reader: self,
            until,
            until_done: false,
        }
    }
}

impl<S> VecAsyncReadExt for S where S: AsyncRead {}

/// An [`AsyncRead`] combinator which reads from a reader until a future resolves.
///
/// Reads are forwarded to the inner reader while `until` is pending. After `until` resolves,
/// every read reports end-of-file, even if the inner reader still has data available.
#[pin_project]
#[derive(Clone, Debug)]
pub struct AllowReadUntil<S, F> {
    #[pin]
    reader: S,
    #[pin]
    until: F,
    /// Set once `until` has returned `Poll::Ready`. Polling a future again after it completed
    /// may panic or misbehave (see the `Future::poll` contract), so it is latched here instead.
    until_done: bool,
}

impl<S, F> AllowReadUntil<S, F> {
    /// Returns a shared reference to the wrapped reader.
    pub const fn get_ref(&self) -> &S {
        &self.reader
    }

    /// Returns a mutable reference to the wrapped reader.
    ///
    /// Reading through this reference bypasses the `until` check.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.reader
    }
}

impl<S, F> AsyncRead for AllowReadUntil<S, F>
where
    S: AsyncRead,
    F: Future<Output = ()>,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<IoResult<()>> {
        let this = self.project();

        // `until` already completed on an earlier read: keep reporting EOF without polling
        // the finished future again.
        if *this.until_done {
            return Poll::Ready(Ok(()));
        }

        match this.until.poll(cx) {
            // Returning `Ok(())` without filling `buf` is how `AsyncRead` signals EOF.
            Poll::Ready(()) => {
                *this.until_done = true;
                Poll::Ready(Ok(()))
            }
            Poll::Pending => this.reader.poll_read(cx, buf),
        }
    }
}

#[cfg(test)]
mod tests {
    use futures::FutureExt;
    use tokio::{
        fs::{remove_file, File},
        io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
    };

    use super::*;
    use crate::{shutdown::ShutdownSignal, test_util::temp_file};

    /// With a no-op shutdown signal, lines appended to the file keep flowing through the
    /// wrapped reader.
    #[tokio::test]
    async fn test_read_line_without_shutdown() {
        let shutdown = ShutdownSignal::noop();
        let temp_path = temp_file();
        let write_file = File::create(temp_path.clone()).await.unwrap();
        let read_file = File::open(temp_path.clone()).await.unwrap();

        // Wrapper AsyncRead
        let read_file = read_file.allow_read_until(shutdown.clone().map(|_| ()));

        let mut reader = BufReader::new(read_file);
        let mut writer = BufWriter::new(write_file);

        writer.write_all(b"First line\n").await.unwrap();
        writer.flush().await.unwrap();

        // Test one of the AsyncBufRead extension functions. The result is unwrapped so that an
        // I/O error fails the test instead of being silently ignored.
        let mut line_one = String::new();
        reader.read_line(&mut line_one).await.unwrap();

        assert_eq!("First line\n", line_one);

        writer.write_all(b"Second line\n").await.unwrap();
        writer.flush().await.unwrap();

        let mut line_two = String::new();
        reader.read_line(&mut line_two).await.unwrap();

        assert_eq!("Second line\n", line_two);

        remove_file(temp_path).await.unwrap();
    }

    /// Once the shutdown trigger is dropped, the wrapped reader is expected to report EOF even
    /// though more data has been written to the underlying file.
    #[tokio::test]
    async fn test_read_line_with_shutdown() {
        let (trigger_shutdown, shutdown, _) = ShutdownSignal::new_wired();
        let temp_path = temp_file();
        let write_file = File::create(temp_path.clone()).await.unwrap();
        let read_file = File::open(temp_path.clone()).await.unwrap();

        // Wrapper AsyncRead
        let read_file = read_file.allow_read_until(shutdown.clone().map(|_| ()));

        let mut reader = BufReader::new(read_file);
        let mut writer = BufWriter::new(write_file);

        writer.write_all(b"First line\n").await.unwrap();
        writer.flush().await.unwrap();

        // Test one of the AsyncBufRead extension functions. The result is unwrapped so that an
        // I/O error fails the test instead of being silently ignored.
        let mut line_one = String::new();
        reader.read_line(&mut line_one).await.unwrap();

        assert_eq!("First line\n", line_one);

        drop(trigger_shutdown);

        writer.write_all(b"Second line\n").await.unwrap();
        writer.flush().await.unwrap();

        let mut line_two = String::new();
        let bytes_read = reader.read_line(&mut line_two).await.unwrap();

        // `read_line` returning `Ok(0)` is how EOF is reported.
        assert_eq!(0, bytes_read);
        assert_eq!("", line_two);

        remove_file(temp_path).await.unwrap();
    }
}