vector_stream/
driver.rs

1use std::{collections::VecDeque, fmt, future::poll_fn, task::Poll};
2
3use futures::{poll, FutureExt, Stream, StreamExt, TryFutureExt};
4use tokio::{pin, select};
5use tower::Service;
6use tracing::Instrument;
7use vector_common::internal_event::emit;
8use vector_common::internal_event::{
9    register, ByteSize, BytesSent, CallError, InternalEventHandle as _, PollReadyError, Registered,
10    RegisteredEventCache, SharedString, TaggedEventsSent,
11};
12use vector_common::request_metadata::{GroupedCountByteSize, MetaDescriptive};
13use vector_core::event::{EventFinalizers, EventStatus, Finalizable};
14
15use super::FuturesUnorderedCount;
16
17pub trait DriverResponse {
18    fn event_status(&self) -> EventStatus;
19    fn events_sent(&self) -> &GroupedCountByteSize;
20
21    /// Return the number of bytes that were sent in the request that returned this response.
22    // TODO, remove the default implementation once all sinks have
23    // implemented this function.
24    fn bytes_sent(&self) -> Option<usize> {
25        None
26    }
27}
28
29/// Drives the interaction between a stream of items and a service which processes them
30/// asynchronously.
31///
32/// `Driver`, as a high-level, facilitates taking items from an arbitrary `Stream` and pushing them
33/// through a `Service`, spawning each call to the service so that work can be run concurrently,
34/// managing waiting for the service to be ready before processing more items, and so on.
35///
36/// Additionally, `Driver` handles event finalization, which triggers acknowledgements
37/// to the source or disk buffer.
38///
39/// This capability is parameterized so any implementation which can define how to interpret the
40/// response for each request, as well as define how many events a request is compromised of, can be
41/// used with `Driver`.
42pub struct Driver<St, Svc> {
43    input: St,
44    service: Svc,
45    protocol: Option<SharedString>,
46}
47
48impl<St, Svc> Driver<St, Svc> {
49    pub fn new(input: St, service: Svc) -> Self {
50        Self {
51            input,
52            service,
53            protocol: None,
54        }
55    }
56
57    /// Set the protocol name for this driver.
58    ///
59    /// If this is set, the driver will fetch and use the `bytes_sent` value from responses in a
60    /// `BytesSent` event.
61    #[must_use]
62    pub fn protocol(mut self, protocol: impl Into<SharedString>) -> Self {
63        self.protocol = Some(protocol.into());
64        self
65    }
66}
67
68impl<St, Svc> Driver<St, Svc>
69where
70    St: Stream,
71    St::Item: Finalizable + MetaDescriptive,
72    Svc: Service<St::Item>,
73    Svc::Error: fmt::Debug + 'static,
74    Svc::Future: Send + 'static,
75    Svc::Response: DriverResponse,
76{
77    /// Runs the driver until the input stream is exhausted.
78    ///
79    /// All in-flight calls to the provided `service` will also be completed before `run` returns.
80    ///
81    /// # Errors
82    ///
83    /// The return type is mostly to simplify caller code.
84    /// An error is currently only returned if a service returns an error from `poll_ready`
85    pub async fn run(self) -> Result<(), ()> {
86        let mut in_flight = FuturesUnorderedCount::new();
87        let mut next_batch: Option<VecDeque<St::Item>> = None;
88        let mut seq_num = 0usize;
89
90        let Self {
91            input,
92            mut service,
93            protocol,
94        } = self;
95
96        let batched_input = input.ready_chunks(1024);
97        pin!(batched_input);
98
99        let bytes_sent = protocol.map(|protocol| register(BytesSent { protocol }));
100        let events_sent = RegisteredEventCache::new(());
101
102        loop {
103            // Core behavior of the loop:
104            // - always check to see if we have any response futures that have completed
105            //  -- if so, handling acking as many events as we can (ordering matters)
106            // - if we have a "current" batch, try to send each request in it to the service
107            //   -- if we can't drain all requests from the batch due to lack of service readiness,
108            //   then put the batch back and try to send the rest of it when the service is ready
109            //   again
110            // - if we have no "current" batch, but there is an available batch from our input
111            //   stream, grab that batch and store it as our current batch
112            //
113            // Essentially, we bounce back and forth between "grab the new batch from the input
114            // stream" and "send all requests in the batch to our service" which _could be trivially
115            // modeled with a normal imperative loop.  However, we want to be able to interleave the
116            // acknowledgement of responses to allow buffers and sources to continue making forward
117            // progress, which necessitates a more complex weaving of logic.  Using `select!` is
118            // more code, and requires a more careful eye than blindly doing
119            // "get_next_batch().await; process_batch().await", but it does make doing the complex
120            // logic easier than if we tried to interleave it ourselves with an imperative-style loop.
121
122            select! {
123                // Using `biased` ensures we check the branches in the order they're written, since
124                // the default behavior of the `select!` macro is to randomly order branches as a
125                // means of ensuring scheduling fairness.
126                biased;
127
128                // One or more of our service calls have completed.
129                Some(_count) = in_flight.next(), if !in_flight.is_empty() => {}
130
131                // We've got an input batch to process and the service is ready to accept a request.
132                maybe_ready = poll_fn(|cx| service.poll_ready(cx)), if next_batch.is_some() => {
133                    let mut batch = next_batch.take()
134                        .unwrap_or_else(|| unreachable!("batch should be populated"));
135
136                    let mut maybe_ready = Some(maybe_ready);
137                    while !batch.is_empty() {
138                        // Make sure the service is ready to take another request.
139                        let maybe_ready = match maybe_ready.take() {
140                            Some(ready) => Poll::Ready(ready),
141                            None => poll!(poll_fn(|cx| service.poll_ready(cx))),
142                        };
143
144                        let svc = match maybe_ready {
145                            Poll::Ready(Ok(())) => &mut service,
146                            Poll::Ready(Err(error)) => {
147                                emit(PollReadyError{ error });
148                                return Err(())
149                            }
150                            Poll::Pending => {
151                                next_batch = Some(batch);
152                                break
153                            },
154                        };
155
156                        let mut req = batch.pop_front().unwrap_or_else(|| unreachable!("batch should not be empty"));
157                        seq_num += 1;
158                        let request_id = seq_num;
159
160                        trace!(
161                            message = "Submitting service request.",
162                            in_flight_requests = in_flight.len(),
163                            request_id,
164                        );
165                        let finalizers = req.take_finalizers();
166                        let bytes_sent = bytes_sent.clone();
167                        let events_sent = events_sent.clone();
168                        let event_count = req.get_metadata().event_count();
169
170                        let fut = svc.call(req)
171                            .err_into()
172                            .map(move |result| Self::handle_response(
173                                result,
174                                request_id,
175                                finalizers,
176                                event_count,
177                                bytes_sent.as_ref(),
178                                &events_sent,
179                            ))
180                            .instrument(info_span!("request", request_id).or_current());
181
182                        in_flight.push(fut);
183                    }
184                }
185
186                // We've received some items from the input stream.
187                Some(reqs) = batched_input.next(), if next_batch.is_none() => {
188                    next_batch = Some(reqs.into());
189                }
190
191                else => break
192            }
193        }
194
195        Ok(())
196    }
197
198    fn handle_response(
199        result: Result<Svc::Response, Svc::Error>,
200        request_id: usize,
201        finalizers: EventFinalizers,
202        event_count: usize,
203        bytes_sent: Option<&Registered<BytesSent>>,
204        events_sent: &RegisteredEventCache<(), TaggedEventsSent>,
205    ) {
206        match result {
207            Err(error) => {
208                Self::emit_call_error(Some(error), request_id, event_count);
209                finalizers.update_status(EventStatus::Rejected);
210            }
211            Ok(response) => {
212                trace!(message = "Service call succeeded.", request_id);
213                finalizers.update_status(response.event_status());
214                if response.event_status() == EventStatus::Delivered {
215                    if let Some(bytes_sent) = bytes_sent {
216                        if let Some(byte_size) = response.bytes_sent() {
217                            bytes_sent.emit(ByteSize(byte_size));
218                        }
219                    }
220
221                    response.events_sent().emit_event(events_sent);
222
223                // This condition occurs specifically when the `HttpBatchService::call()` is called *within* the `Service::call()`
224                } else if response.event_status() == EventStatus::Rejected {
225                    Self::emit_call_error(None, request_id, event_count);
226                    finalizers.update_status(EventStatus::Rejected);
227                }
228            }
229        }
230        drop(finalizers); // suppress "argument not consumed" warning
231    }
232
233    /// Emit the `Error` and `EventsDropped` internal events.
234    /// This scenario occurs after retries have been attempted.
235    fn emit_call_error(error: Option<Svc::Error>, request_id: usize, count: usize) {
236        emit(CallError {
237            error,
238            request_id,
239            count,
240        });
241    }
242}
243
#[cfg(test)]
mod tests {
    use std::{
        future::Future,
        pin::Pin,
        sync::{atomic::AtomicUsize, atomic::Ordering, Arc},
        task::{ready, Context, Poll},
        time::Duration,
    };

    use futures_util::stream;
    use rand::{prelude::StdRng, SeedableRng};
    use rand_distr::{Distribution, Pareto};
    use tokio::{
        sync::{OwnedSemaphorePermit, Semaphore},
        time::sleep,
    };
    use tokio_util::sync::PollSemaphore;
    use tower::Service;
    use vector_common::{
        finalization::{BatchNotifier, EventFinalizer, EventFinalizers, EventStatus, Finalizable},
        internal_event::CountByteSize,
        json_size::JsonSize,
        request_metadata::{GroupedCountByteSize, MetaDescriptive, RequestMetadata},
    };

    use super::{Driver, DriverResponse};

    /// Shared accumulator of the values of acknowledged requests.
    type Counter = Arc<AtomicUsize>;

    /// Test request: a finalizer tied to a batch notifier, plus default request metadata.
    #[derive(Debug)]
    struct DelayRequest(EventFinalizers, RequestMetadata);

    impl DelayRequest {
        /// Creates a request whose batch notification, once it resolves, adds `value` to
        /// `counter` from a spawned task.
        fn new(value: usize, counter: &Counter) -> Self {
            let (batch, receiver) = BatchNotifier::new_with_receiver();
            let counter = Arc::clone(counter);
            tokio::spawn(async move {
                receiver.await;
                counter.fetch_add(value, Ordering::Relaxed);
            });
            Self(
                EventFinalizers::new(EventFinalizer::new(batch)),
                RequestMetadata::default(),
            )
        }
    }

    impl Finalizable for DelayRequest {
        fn take_finalizers(&mut self) -> EventFinalizers {
            std::mem::take(&mut self.0)
        }
    }

    impl MetaDescriptive for DelayRequest {
        fn get_metadata(&self) -> &RequestMetadata {
            &self.1
        }

        fn metadata_mut(&mut self) -> &mut RequestMetadata {
            &mut self.1
        }
    }

    /// Test response that always reports delivery of a single one-byte event.
    struct DelayResponse {
        events_sent: GroupedCountByteSize,
    }

    impl DelayResponse {
        fn new() -> Self {
            Self {
                events_sent: CountByteSize(1, JsonSize::new(1)).into(),
            }
        }
    }

    impl DriverResponse for DelayResponse {
        fn event_status(&self) -> EventStatus {
            EventStatus::Delivered
        }

        fn events_sent(&self) -> &GroupedCountByteSize {
            &self.events_sent
        }
    }

    /// Test service that holds each request for a pseudo-random, bounded delay before responding
    /// with a `DelayResponse`. Concurrency is limited by a semaphore: `poll_ready` acquires a
    /// permit and the response future releases it.
    struct DelayService {
        semaphore: PollSemaphore,
        permit: Option<OwnedSemaphorePermit>,
        jitter: Pareto<f64>,
        jitter_gen: StdRng,
        lower_bound_us: u64,
        upper_bound_us: u64,
    }

    // Clippy is unhappy about all of our f64/u64 shuffling.  We don't actually care about losing
    // the fractional part of 20,459.13142 or whatever.  It just doesn't matter for this test.
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_precision_loss)]
    impl DelayService {
        /// Creates a service allowing `permits` concurrent calls. Each call sleeps for a duration
        /// strictly between `lower_bound` and `upper_bound`; both bounds are clamped to at least
        /// 10ms.
        ///
        /// # Panics
        ///
        /// Panics unless the clamped bounds leave at least one whole microsecond strictly between
        /// them; otherwise `get_sleep_dur` could never produce a value and would loop forever.
        pub(crate) fn new(permits: usize, lower_bound: Duration, upper_bound: Duration) -> Self {
            let lower_bound_us = lower_bound.as_micros().max(10_000) as u64;
            let upper_bound_us = upper_bound.as_micros().max(10_000) as u64;
            // Validate the clamped values: checking the raw durations would accept e.g.
            // (1ms, 5ms), which both clamp to 10ms and leave no valid sample range.
            assert!(
                upper_bound_us.saturating_sub(lower_bound_us) > 1,
                "upper bound must exceed lower bound by more than 1us after clamping to 10ms"
            );
            Self {
                semaphore: PollSemaphore::new(Arc::new(Semaphore::new(permits))),
                permit: None,
                jitter: Pareto::new(1.0, 1.0).expect("distribution should be valid"),
                jitter_gen: StdRng::from_seed([
                    3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4, 6, 2, 6, 4, 3, 3,
                    8, 3, 2, 7, 9, 5,
                ]),
                lower_bound_us,
                upper_bound_us,
            }
        }

        /// Samples a sleep duration strictly between the configured bounds, drawn from a
        /// long-tailed Pareto distribution scaled by the lower bound.
        pub(crate) fn get_sleep_dur(&mut self) -> Duration {
            let lower = self.lower_bound_us;
            let upper = self.upper_bound_us;
            // Generate a value between `lower` and `upper`, with a long tail shape to the distribution.
            #[allow(clippy::cast_sign_loss)] // Value will be positive anyway
            self.jitter
                .sample_iter(&mut self.jitter_gen)
                .map(|n| n * lower as f64)
                .map(|n| n as u64)
                .filter(|n| *n > lower && *n < upper)
                .map(Duration::from_micros)
                .next()
                .expect("jitter iter should be endless")
        }
    }

    impl Service<DelayRequest> for DelayService {
        type Response = DelayResponse;
        type Error = ();
        type Future =
            Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + Sync>>;

        /// Acquires a semaphore permit and stores it for the next `call`.
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            assert!(
                self.permit.is_none(),
                "should not call poll_ready again after a successful call"
            );

            match ready!(self.semaphore.poll_acquire(cx)) {
                None => panic!("semaphore should not be closed!"),
                Some(permit) => assert!(self.permit.replace(permit).is_none()),
            }

            Poll::Ready(Ok(()))
        }

        /// Takes the permit acquired by `poll_ready`. The returned future sleeps, then releases
        /// the permit, drops the request, and responds.
        fn call(&mut self, req: DelayRequest) -> Self::Future {
            let permit = self
                .permit
                .take()
                .expect("calling `call` without successful `poll_ready` is invalid");
            let sleep_dur = self.get_sleep_dur();

            Box::pin(async move {
                sleep(sleep_dur).await;

                // Manually drop our permit here so that we take ownership and then actually
                // release the slot back to the semaphore.
                drop(permit);
                drop(req);

                Ok(DelayResponse::new())
            })
        }
    }

    #[tokio::test]
    async fn driver_simple() {
        // This test uses a service which creates response futures that sleep for a variable, but
        // bounded, amount of time, giving the impression of work being completed.  Completion of
        // all requests/responses is asserted by checking that the shared counter matches the
        // expected ack amount.  The delays themselves are deterministic based on a fixed-seed
        // RNG, so the test should always run in a fairly constant time between runs.
        //
        // TODO: Given the use of a deterministic RNG, we could likely transition this test to be
        // driven via `proptest`, to also allow driving the input requests.  The main thing that we
        // do not control is the arrival of requests in the input stream itself, which means that
        // the generated batches will almost always be the biggest possible size, since the stream
        // is always immediately available.
        //
        // It might be possible to spawn a background task to drive a true MPSC channel with
        // requests based on input provided from `proptest` to control not only the value (which
        // determines ack size) but the delay between messages, as well... simulating delays between
        // bursts of messages, similar to real sources.

        let counter = Counter::default();

        // Set up our driver input stream, service, etc.
        let input_requests = (1..=2048).collect::<Vec<_>>();
        let input_total: usize = input_requests.iter().sum();
        let input_stream = stream::iter(
            input_requests
                .into_iter()
                .map(|i| DelayRequest::new(i, &counter)),
        );
        let service = DelayService::new(10, Duration::from_millis(5), Duration::from_millis(150));
        let driver = Driver::new(input_stream, service);

        // Now actually run the driver, consuming all of the input.
        assert_eq!(driver.run().await, Ok(()));
        // Make sure the final finalizer task runs.
        tokio::task::yield_now().await;
        assert_eq!(input_total, counter.load(Ordering::SeqCst));
    }
}