1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
use async_stream::stream;
use futures::{Stream, StreamExt};
use std::time::Duration;

/// Collects the output items produced by one invocation of a
/// [`map_with_expiration`] callback.
///
/// Items are stored in the order they were emitted.
#[derive(Debug)]
pub struct Emitter<T> {
    values: Vec<T>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add a
// needless `T: Default` bound, although an empty emitter never needs a `T`.
impl<T> Default for Emitter<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Emitter<T> {
    /// Creates an empty emitter.
    #[must_use]
    pub fn new() -> Self {
        Self { values: Vec::new() }
    }

    /// Appends `value` to the items collected by this emitter.
    pub fn emit(&mut self, value: T) {
        self.values.push(value);
    }
}

/// Similar to `stream.filter_map(..).flatten(..)` but also allows checking for expired events
/// and flushing when the input stream ends.
///
/// All three callbacks share one mutable state, seeded from `initial_state`. Each callback
/// pushes zero or more output items into the [`Emitter`] it is handed. Those items are
/// yielded downstream, in emission order, once the callback returns.
///
/// * `map_fn` is called once for every item of `input`.
/// * `expiration_fn` is called every `expiration_interval`. The first call happens one full
///   interval after the returned stream is first polled, not immediately.
/// * `flush_fn` is called once, when `input` ends. The returned stream ends right after
///   yielding whatever `flush_fn` emitted.
///
/// The timer is created lazily, on first poll, with `tokio::time::interval`. The returned
/// stream must therefore be polled inside a Tokio runtime that has the time driver enabled.
///
/// # Panics
///
/// Panics immediately if `expiration_interval` is zero.
pub fn map_with_expiration<S, T, M, E, F>(
    initial_state: S,
    input: impl Stream<Item = T> + 'static,
    expiration_interval: Duration,
    // called for each event
    mut map_fn: M,
    // called periodically to allow expiring internal state
    mut expiration_fn: E,
    // called once at the end of the input stream
    mut flush_fn: F,
) -> impl Stream<Item = T>
where
    M: FnMut(&mut S, T, &mut Emitter<T>),
    E: FnMut(&mut S, &mut Emitter<T>),
    F: FnMut(&mut S, &mut Emitter<T>),
{
    // Fail fast, at the call site, like the eager `interval` construction used to.
    assert!(
        !expiration_interval.is_zero(),
        "`expiration_interval` must be non-zero"
    );

    Box::pin(stream! {
        let mut state = initial_state;
        futures_util::pin_mut!(input);

        // Created on first poll rather than when the function is called. Otherwise a stream
        // that sits unpolled for a while would replay a burst of missed ticks on its first
        // poll.
        let mut expiration_timer = tokio::time::interval(expiration_interval);
        // An `Interval`'s first tick completes immediately. Consume it so `expiration_fn`
        // first runs one full interval from now, instead of racing the first input item
        // inside `select!`.
        expiration_timer.tick().await;

        loop {
            let mut emitter = Emitter::<T>::new();
            // Both branches are cancel-safe (`Interval::tick`, `StreamExt::next`), so
            // dropping the losing future loses neither a tick nor an item.
            let done = tokio::select! {
                _ = expiration_timer.tick() => {
                    expiration_fn(&mut state, &mut emitter);
                    false
                }
                maybe_event = input.next() => {
                    match maybe_event {
                        Some(event) => {
                            map_fn(&mut state, event, &mut emitter);
                            false
                        }
                        None => {
                            flush_fn(&mut state, &mut emitter);
                            true
                        }
                    }
                }
            };
            yield futures::stream::iter(emitter.values);
            if done {
                break;
            }
        }
    })
    .flatten()
}

#[cfg(test)]
mod test {
    use super::*;

    /// Adds each event to the running total and emits the new total.
    fn running_total(total: &mut i32, event: i32, emitter: &mut Emitter<i32>) {
        *total += event;
        emitter.emit(*total);
    }

    /// Emits the current total without changing it.
    fn emit_total(total: &mut i32, emitter: &mut Emitter<i32>) {
        emitter.emit(*total);
    }

    /// Emits nothing and leaves the total untouched.
    fn emit_nothing(_total: &mut i32, _emitter: &mut Emitter<i32>) {}

    #[tokio::test]
    async fn test_simple() {
        // Finite input: each event is mapped, then the flush re-emits the final total.
        let collected: Vec<i32> = map_with_expiration(
            0_i32,
            futures::stream::iter([1, 2, 3]),
            Duration::from_secs(100),
            running_total,
            emit_nothing,
            emit_total,
        )
        .take(4)
        .collect()
        .await;

        assert_eq!(vec![1, 3, 6, 6], collected);
    }

    #[tokio::test]
    async fn test_expiration() {
        // The input never ends, so `flush_fn` never runs. Any item beyond the mapped
        // ones has to come from `expiration_fn`.
        let endless = futures::stream::iter([1, 2, 3]).chain(futures::stream::pending());

        let collected: Vec<i32> = map_with_expiration(
            0_i32,
            endless,
            Duration::from_secs(1),
            running_total,
            emit_total,
            emit_nothing,
        )
        .take(4)
        .collect()
        .await;

        assert_eq!(vec![1, 3, 6, 6], collected);
    }
}