vector_stream/
concurrent_map.rs1use std::{
2 future::Future,
3 num::NonZeroUsize,
4 panic,
5 pin::Pin,
6 task::{ready, Context, Poll},
7};
8
9use futures_util::{
10 stream::{Fuse, FuturesOrdered},
11 Stream, StreamExt,
12};
13use pin_project::pin_project;
14use tokio::task::JoinHandle;
15
16#[pin_project]
17pub struct ConcurrentMap<St, T>
18where
19 St: Stream,
20 T: Send + 'static,
21{
22 #[pin]
23 stream: Fuse<St>,
24 limit: Option<NonZeroUsize>,
25 in_flight: FuturesOrdered<JoinHandle<T>>,
26 f: Box<dyn Fn(St::Item) -> Pin<Box<dyn Future<Output = T> + Send + 'static>> + Send>,
27}
28
29impl<St, T> ConcurrentMap<St, T>
30where
31 St: Stream,
32 T: Send + 'static,
33{
34 pub fn new<F>(stream: St, limit: Option<NonZeroUsize>, f: F) -> Self
35 where
36 F: Fn(St::Item) -> Pin<Box<dyn Future<Output = T> + Send + 'static>> + Send + 'static,
37 {
38 Self {
39 stream: stream.fuse(),
40 limit,
41 in_flight: FuturesOrdered::new(),
42 f: Box::new(f),
43 }
44 }
45}
46
47impl<St, T> Stream for ConcurrentMap<St, T>
48where
49 St: Stream,
50 T: Send + 'static,
51{
52 type Item = T;
53
54 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
55 let mut this = self.project();
56
57 if this.stream.is_done() && this.in_flight.is_empty() {
59 return Poll::Ready(None);
60 }
61
62 loop {
63 let can_poll_stream = match this.limit {
64 None => true,
65 Some(limit) => this.in_flight.len() < limit.get(),
66 };
67
68 if can_poll_stream {
69 match this.stream.as_mut().poll_next(cx) {
70 Poll::Pending | Poll::Ready(None) => break,
73 Poll::Ready(Some(item)) => {
74 let fut = (this.f)(item);
75 let handle = tokio::spawn(fut);
76 this.in_flight.push_back(handle);
77 }
78 }
79 } else {
80 break;
82 }
83 }
84
85 match ready!(this.in_flight.poll_next_unpin(cx)) {
86 None if this.stream.is_done() => Poll::Ready(None),
89 None => Poll::Pending,
92 Some(result) => match result {
93 Ok(item) => Poll::Ready(Some(item)),
94 Err(e) => {
95 if let Ok(reason) = e.try_into_panic() {
96 panic::resume_unwind(reason);
98 } else {
99 panic!("concurrent map task cancelled outside of our control");
102 }
103 }
104 },
105 }
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::StreamExt;

    /// An empty source stream must make the adapter finish immediately,
    /// without yielding any item.
    #[tokio::test]
    async fn test_concurrent_map_on_empty_stream() {
        let source = futures_util::stream::empty::<()>();
        // `NonZeroUsize::new` already returns the `Option` the constructor takes.
        let max_in_flight = NonZeroUsize::new(2);
        let map_fn = |()| Box::pin(async move {}) as Pin<Box<dyn Future<Output = ()> + Send>>;

        let mut mapped = ConcurrentMap::new(source, max_in_flight, map_fn);

        assert!(mapped.next().await.is_none());
    }
}