vector_stream/
concurrent_map.rs

1use std::{
2    future::Future,
3    num::NonZeroUsize,
4    panic,
5    pin::Pin,
6    task::{ready, Context, Poll},
7};
8
9use futures_util::{
10    stream::{Fuse, FuturesOrdered},
11    Stream, StreamExt,
12};
13use pin_project::pin_project;
14use tokio::task::JoinHandle;
15
16#[pin_project]
17pub struct ConcurrentMap<St, T>
18where
19    St: Stream,
20    T: Send + 'static,
21{
22    #[pin]
23    stream: Fuse<St>,
24    limit: Option<NonZeroUsize>,
25    in_flight: FuturesOrdered<JoinHandle<T>>,
26    f: Box<dyn Fn(St::Item) -> Pin<Box<dyn Future<Output = T> + Send + 'static>> + Send>,
27}
28
29impl<St, T> ConcurrentMap<St, T>
30where
31    St: Stream,
32    T: Send + 'static,
33{
34    pub fn new<F>(stream: St, limit: Option<NonZeroUsize>, f: F) -> Self
35    where
36        F: Fn(St::Item) -> Pin<Box<dyn Future<Output = T> + Send + 'static>> + Send + 'static,
37    {
38        Self {
39            stream: stream.fuse(),
40            limit,
41            in_flight: FuturesOrdered::new(),
42            f: Box::new(f),
43        }
44    }
45}
46
47impl<St, T> Stream for ConcurrentMap<St, T>
48where
49    St: Stream,
50    T: Send + 'static,
51{
52    type Item = T;
53
54    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
55        let mut this = self.project();
56
57        // The underlying stream is done, and we have no more in-flight futures.
58        if this.stream.is_done() && this.in_flight.is_empty() {
59            return Poll::Ready(None);
60        }
61
62        loop {
63            let can_poll_stream = match this.limit {
64                None => true,
65                Some(limit) => this.in_flight.len() < limit.get(),
66            };
67
68            if can_poll_stream {
69                match this.stream.as_mut().poll_next(cx) {
70                    // Even if there's no items from the underlying stream, we still have the in-flight
71                    // futures to check, so we don't return just yet.
72                    Poll::Pending | Poll::Ready(None) => break,
73                    Poll::Ready(Some(item)) => {
74                        let fut = (this.f)(item);
75                        let handle = tokio::spawn(fut);
76                        this.in_flight.push_back(handle);
77                    }
78                }
79            } else {
80                // We're at our in-flight limit, so stop generating tasks for the moment.
81                break;
82            }
83        }
84
85        match ready!(this.in_flight.poll_next_unpin(cx)) {
86            // If the stream is done and there is no futures managed by FuturesOrdered,
87            // we must end the stream by returning Poll::Ready(None).
88            None if this.stream.is_done() => Poll::Ready(None),
89            // If there are no in-flight futures managed by FuturesOrdered but the underlying
90            // stream is not done, then we must keep polling that stream.
91            None => Poll::Pending,
92            Some(result) => match result {
93                Ok(item) => Poll::Ready(Some(item)),
94                Err(e) => {
95                    if let Ok(reason) = e.try_into_panic() {
96                        // Resume the panic here on the calling task.
97                        panic::resume_unwind(reason);
98                    } else {
99                        // The task was cancelled, which makes no sense, because _we_ hold the join
100                        // handle. Only sensible thing to do is panic, because this is a bug.
101                        panic!("concurrent map task cancelled outside of our control");
102                    }
103                }
104            },
105        }
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::StreamExt;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    #[tokio::test]
    async fn test_concurrent_map_on_empty_stream() {
        let stream = futures_util::stream::empty::<()>();
        let limit = Some(NonZeroUsize::new(2).unwrap());
        // The `as _` is required to construct a `dyn Future`
        let f = |()| Box::pin(async move {}) as _;
        let mut concurrent_map = ConcurrentMap::new(stream, limit, f);

        // Assert that the stream does not hang
        assert_eq!(concurrent_map.next().await, None);
    }

    #[tokio::test]
    async fn test_concurrent_map_preserves_input_order() {
        for limit in [None, Some(NonZeroUsize::new(3).unwrap())] {
            let stream = futures_util::stream::iter(0..10u32);
            // Earlier items yield more times before finishing, so tasks tend to complete
            // out of order; the output must still come back in input order.
            let f = |x: u32| {
                Box::pin(async move {
                    for _ in x..10 {
                        tokio::task::yield_now().await;
                    }
                    x * 2
                }) as _
            };
            let results: Vec<u32> = ConcurrentMap::new(stream, limit, f).collect().await;
            assert_eq!(results, (0..10).map(|x| x * 2).collect::<Vec<_>>());
        }
    }

    #[tokio::test]
    async fn test_concurrent_map_respects_limit() {
        let running = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let stream = futures_util::stream::iter(0..20usize);
        let limit = Some(NonZeroUsize::new(2).unwrap());
        let f = {
            let running = Arc::clone(&running);
            let peak = Arc::clone(&peak);
            move |x: usize| {
                let running = Arc::clone(&running);
                let peak = Arc::clone(&peak);
                Box::pin(async move {
                    let now = running.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now, Ordering::SeqCst);
                    tokio::task::yield_now().await;
                    running.fetch_sub(1, Ordering::SeqCst);
                    x
                }) as _
            }
        };

        let results: Vec<usize> = ConcurrentMap::new(stream, limit, f).collect().await;
        assert_eq!(results, (0..20).collect::<Vec<_>>());
        // Never more tasks alive at once than the configured limit.
        assert!(peak.load(Ordering::SeqCst) <= 2);
    }

    #[tokio::test]
    #[should_panic(expected = "boom")]
    async fn test_concurrent_map_resumes_task_panic() {
        let stream = futures_util::stream::iter(0..3u32);
        let f = |x: u32| {
            Box::pin(async move {
                if x == 1 {
                    panic!("boom");
                }
                x
            }) as _
        };
        // The panic inside the spawned task must surface on the task consuming the stream.
        let _results: Vec<u32> = ConcurrentMap::new(stream, None, f).collect().await;
    }
}