1use std::pin::Pin;
2
3use futures::{
4 task::{Context, Poll},
5 {Stream, StreamExt},
6};
7
/// Maximum number of frames per batch used by [`ReadyFrames::new`].
const DEFAULT_CAPACITY: usize = 1 << 10;
9
10pub struct ReadyFrames<T, U, E> {
16 inner: T,
17 enqueued: Vec<U>,
18 enqueued_size: usize,
19 error_slot: Option<E>,
20 enqueued_limit: usize,
21}
22
23impl<T, U, E> ReadyFrames<T, U, E>
24where
25 T: Stream<Item = Result<(U, usize), E>> + Unpin,
26 U: Unpin,
27 E: Unpin,
28{
29 pub fn new(inner: T) -> Self {
31 Self::with_capacity(inner, DEFAULT_CAPACITY)
32 }
33
34 pub fn with_capacity(inner: T, cap: usize) -> Self {
40 Self {
41 inner,
42 enqueued: Vec::with_capacity(cap),
43 enqueued_size: 0,
44 error_slot: None,
45 enqueued_limit: cap,
46 }
47 }
48
49 pub const fn get_ref(&self) -> &T {
51 &self.inner
52 }
53
54 pub const fn get_mut(&mut self) -> &mut T {
56 &mut self.inner
57 }
58
59 fn flush(&mut self) -> (Vec<U>, usize) {
60 let frames = std::mem::take(&mut self.enqueued);
61 let size = self.enqueued_size;
62 self.enqueued_size = 0;
63 (frames, size)
64 }
65}
66
67impl<T, U, E> Stream for ReadyFrames<T, U, E>
68where
69 T: Stream<Item = Result<(U, usize), E>> + Unpin,
70 U: Unpin,
71 E: Unpin,
72{
73 type Item = Result<(Vec<U>, usize), E>;
74
75 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
76 if let Some(error) = self.error_slot.take() {
77 return Poll::Ready(Some(Err(error)));
78 }
79
80 loop {
81 match self.inner.poll_next_unpin(cx) {
82 Poll::Ready(Some(Ok((frame, size)))) => {
83 self.enqueued.push(frame);
84 self.enqueued_size += size;
85 if self.enqueued.len() >= self.enqueued_limit {
86 return Poll::Ready(Some(Ok(self.flush())));
87 }
88 }
89 Poll::Ready(Some(Err(error))) => {
90 if self.enqueued.is_empty() {
91 return Poll::Ready(Some(Err(error)));
92 } else {
93 self.error_slot = Some(error);
94 return Poll::Ready(Some(Ok(self.flush())));
95 }
96 }
97 Poll::Ready(None) => {
98 if !self.enqueued.is_empty() {
99 return Poll::Ready(Some(Ok(self.flush())));
100 } else {
101 return Poll::Ready(None);
102 }
103 }
104 Poll::Pending => {
105 if !self.enqueued.is_empty() {
106 return Poll::Ready(Some(Ok(self.flush())));
107 } else {
108 return Poll::Pending;
109 }
110 }
111 }
112 }
113 }
114}
115
#[cfg(test)]
mod test {
    use futures::{channel::mpsc, poll, task::Poll, SinkExt, StreamExt};

    use super::ReadyFrames;

    /// Item carried by the test channels: a `(frame, size)` pair or an error.
    type TestItem = Result<(&'static str, usize), &'static str>;

    /// Creates a channel whose receiving half is wrapped in a `ReadyFrames`
    /// that emits batches of at most two frames.
    fn batched_channel() -> (
        mpsc::Sender<TestItem>,
        ReadyFrames<mpsc::Receiver<TestItem>, &'static str, &'static str>,
    ) {
        let (sender, receiver) = mpsc::channel::<TestItem>(5);
        (sender, ReadyFrames::with_capacity(receiver, 2))
    }

    /// Sends every item in order, failing the test if the channel is closed.
    async fn send_all(
        sender: &mut mpsc::Sender<TestItem>,
        items: impl IntoIterator<Item = TestItem>,
    ) {
        for item in items {
            sender.send(item).await.unwrap();
        }
    }

    /// An idle source yields `Pending`; a lone ready frame is emitted by itself.
    #[tokio::test]
    async fn idle_passthrough() {
        let (mut sender, mut batches) = batched_channel();

        assert_eq!(Poll::Pending, poll!(batches.next()));

        send_all(&mut sender, [Ok(("foo", 1))]).await;

        assert_eq!(Poll::Ready(Some(Ok((vec!["foo"], 1)))), poll!(batches.next()));
        assert_eq!(Poll::Pending, poll!(batches.next()));
    }

    /// Batches never grow past the configured limit of two frames.
    #[tokio::test]
    async fn limits_to_capacity() {
        let (mut sender, mut batches) = batched_channel();

        send_all(&mut sender, [Ok(("foo", 2)), Ok(("bar", 3))]).await;

        assert_eq!(
            Poll::Ready(Some(Ok((vec!["foo", "bar"], 5)))),
            poll!(batches.next())
        );
        assert_eq!(Poll::Pending, poll!(batches.next()));

        // Three ready frames against a limit of two arrive as two batches.
        send_all(&mut sender, [Ok(("foo", 4)), Ok(("bar", 5)), Ok(("baz", 6))]).await;

        assert_eq!(
            Poll::Ready(Some(Ok((vec!["foo", "bar"], 9)))),
            poll!(batches.next())
        );
        assert_eq!(Poll::Ready(Some(Ok((vec!["baz"], 6)))), poll!(batches.next()));
        assert_eq!(Poll::Pending, poll!(batches.next()));
    }

    /// Errors pass through, after any frames that were buffered before them.
    #[tokio::test]
    async fn error_passing() {
        let (mut sender, mut batches) = batched_channel();

        // Nothing buffered: the error comes straight out.
        send_all(&mut sender, [Err("oops")]).await;

        assert_eq!(Poll::Ready(Some(Err("oops"))), poll!(batches.next()));
        assert_eq!(Poll::Pending, poll!(batches.next()));

        // A frame ahead of the error is delivered first, then the error.
        send_all(&mut sender, [Ok(("foo", 7)), Err("oops")]).await;

        assert_eq!(Poll::Ready(Some(Ok((vec!["foo"], 7)))), poll!(batches.next()));
        assert_eq!(Poll::Ready(Some(Err("oops"))), poll!(batches.next()));
        assert_eq!(Poll::Pending, poll!(batches.next()));
    }

    /// Frames still buffered when the source closes are flushed before `None`.
    #[tokio::test]
    async fn closing() {
        let (mut sender, mut batches) = batched_channel();

        send_all(&mut sender, [Ok(("foo", 8)), Ok(("bar", 9)), Ok(("baz", 10))]).await;
        drop(sender);

        assert_eq!(
            Poll::Ready(Some(Ok((vec!["foo", "bar"], 17)))),
            poll!(batches.next())
        );
        assert_eq!(Poll::Ready(Some(Ok((vec!["baz"], 10)))), poll!(batches.next()));
        assert_eq!(Poll::Ready(None), poll!(batches.next()));
    }
}