vector_stream/driver.rs

use std::{collections::VecDeque, fmt, future::poll_fn, task::Poll};

use futures::{FutureExt, Stream, StreamExt, TryFutureExt, poll};
use tokio::{pin, select};
use tower::Service;
use tracing::Instrument;
use vector_common::{
    internal_event::{
        ByteSize, BytesSent, CallError, InternalEventHandle as _, PollReadyError, Registered,
        RegisteredEventCache, SharedString, TaggedEventsSent, emit, register,
    },
    request_metadata::{GroupedCountByteSize, MetaDescriptive},
};
use vector_core::event::{EventFinalizers, EventStatus, Finalizable};

use super::FuturesUnorderedCount;

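/// A response from the service, as interpreted by `Driver`.
///
/// A minimal sketch of an implementation (`MyResponse` and its fields are illustrative, not part
/// of this crate):
///
/// ```ignore
/// struct MyResponse {
///     events_sent: GroupedCountByteSize,
///     body_len: usize,
/// }
///
/// impl DriverResponse for MyResponse {
///     fn event_status(&self) -> EventStatus {
///         EventStatus::Delivered
///     }
///
///     fn events_sent(&self) -> &GroupedCountByteSize {
///         &self.events_sent
///     }
///
///     // Optional: report the request size so `Driver` can emit `BytesSent`.
///     fn bytes_sent(&self) -> Option<usize> {
///         Some(self.body_len)
///     }
/// }
/// ```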
pub trait DriverResponse {
    fn event_status(&self) -> EventStatus;
    fn events_sent(&self) -> &GroupedCountByteSize;

    /// Return the number of bytes that were sent in the request that returned this response.
    // TODO: remove the default implementation once all sinks have
    // implemented this function.
    fn bytes_sent(&self) -> Option<usize> {
        None
    }
}

/// Drives the interaction between a stream of items and a service which processes them
/// asynchronously.
///
/// `Driver`, as a high-level component, facilitates taking items from an arbitrary `Stream` and
/// pushing them through a `Service`, spawning each call to the service so that work can be run
/// concurrently, managing waiting for the service to be ready before processing more items, and
/// so on.
///
/// Additionally, `Driver` handles event finalization, which triggers acknowledgements
/// to the source or disk buffer.
///
/// This capability is parameterized so any implementation which can define how to interpret the
/// response for each request, as well as define how many events a request is comprised of, can be
/// used with `Driver`.
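///
/// A minimal usage sketch (`input_stream` and `service` are placeholders for types meeting the
/// bounds on [`Driver::run`]):
///
/// ```ignore
/// let driver = Driver::new(input_stream, service).protocol("http");
/// driver.run().await?;
/// ```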
pub struct Driver<St, Svc> {
    input: St,
    service: Svc,
    protocol: Option<SharedString>,
}

impl<St, Svc> Driver<St, Svc> {
    pub fn new(input: St, service: Svc) -> Self {
        Self {
            input,
            service,
            protocol: None,
        }
    }

    /// Set the protocol name for this driver.
    ///
    /// If this is set, the driver will fetch the `bytes_sent` value from each response and emit
    /// it in a `BytesSent` event.
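    ///
    /// A hypothetical call (the protocol name is illustrative):
    ///
    /// ```ignore
    /// let driver = Driver::new(input, service).protocol("tcp");
    /// ```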
    #[must_use]
    pub fn protocol(mut self, protocol: impl Into<SharedString>) -> Self {
        self.protocol = Some(protocol.into());
        self
    }
}

impl<St, Svc> Driver<St, Svc>
where
    St: Stream,
    St::Item: Finalizable + MetaDescriptive,
    Svc: Service<St::Item>,
    Svc::Error: fmt::Debug + 'static,
    Svc::Future: Send + 'static,
    Svc::Response: DriverResponse,
{
    /// Runs the driver until the input stream is exhausted.
    ///
    /// All in-flight calls to the provided `service` will also be completed before `run` returns.
    ///
    /// # Errors
    ///
    /// The return type exists mostly to simplify caller code.
    /// An error is currently only returned if the service returns an error from `poll_ready`.
    pub async fn run(self) -> Result<(), ()> {
        let mut in_flight = FuturesUnorderedCount::new();
        let mut next_batch: Option<VecDeque<St::Item>> = None;
        let mut seq_num = 0usize;

        let Self {
            input,
            mut service,
            protocol,
        } = self;

        let batched_input = input.ready_chunks(1024);
        pin!(batched_input);

        let bytes_sent = protocol.map(|protocol| register(BytesSent { protocol }));
        let events_sent = RegisteredEventCache::new(());

        loop {
            // Core behavior of the loop:
            // - always check to see if we have any response futures that have completed
            //  -- if so, handle acking as many events as we can (ordering matters)
            // - if we have a "current" batch, try to send each request in it to the service
            //   -- if we can't drain all requests from the batch due to lack of service readiness,
            //   then put the batch back and try to send the rest of it when the service is ready
            //   again
            // - if we have no "current" batch, but there is an available batch from our input
            //   stream, grab that batch and store it as our current batch
            //
            // Essentially, we bounce back and forth between "grab the new batch from the input
            // stream" and "send all requests in the batch to our service", which _could_ be
            // trivially modeled with a normal imperative loop.  However, we want to be able to
            // interleave the acknowledgement of responses to allow buffers and sources to continue
            // making forward progress, which necessitates a more complex weaving of logic.  Using
            // `select!` is more code, and requires a more careful eye than blindly doing
            // "get_next_batch().await; process_batch().await", but it does make doing the complex
            // logic easier than if we tried to interleave it ourselves with an imperative-style
            // loop.

            select! {
                // Using `biased` ensures we check the branches in the order they're written, since
                // the default behavior of the `select!` macro is to randomly order branches as a
                // means of ensuring scheduling fairness.
                biased;

                // One or more of our service calls have completed.
                Some(_count) = in_flight.next(), if !in_flight.is_empty() => {}

                // We've got an input batch to process and the service is ready to accept a request.
                maybe_ready = poll_fn(|cx| service.poll_ready(cx)), if next_batch.is_some() => {
                    let mut batch = next_batch.take()
                        .unwrap_or_else(|| unreachable!("batch should be populated"));

                    let mut maybe_ready = Some(maybe_ready);
                    while !batch.is_empty() {
                        // Make sure the service is ready to take another request.
                        let maybe_ready = match maybe_ready.take() {
                            Some(ready) => Poll::Ready(ready),
                            None => poll!(poll_fn(|cx| service.poll_ready(cx))),
                        };

                        let svc = match maybe_ready {
                            Poll::Ready(Ok(())) => &mut service,
                            Poll::Ready(Err(error)) => {
                                emit(PollReadyError { error });
                                return Err(())
                            }
                            Poll::Pending => {
                                next_batch = Some(batch);
                                break
                            },
                        };

                        let mut req = batch.pop_front().unwrap_or_else(|| unreachable!("batch should not be empty"));
                        seq_num += 1;
                        let request_id = seq_num;

                        trace!(
                            message = "Submitting service request.",
                            in_flight_requests = in_flight.len(),
                            request_id,
                        );
                        let finalizers = req.take_finalizers();
                        let bytes_sent = bytes_sent.clone();
                        let events_sent = events_sent.clone();
                        let event_count = req.get_metadata().event_count();

                        let fut = svc.call(req)
                            .err_into()
                            .map(move |result| Self::handle_response(
                                result,
                                request_id,
                                finalizers,
                                event_count,
                                bytes_sent.as_ref(),
                                &events_sent,
                            ))
                            .instrument(info_span!("request", request_id).or_current());

                        in_flight.push(fut);
                    }
                }

                // We've received some items from the input stream.
                Some(reqs) = batched_input.next(), if next_batch.is_none() => {
                    next_batch = Some(reqs.into());
                }

                else => break
            }
        }

        Ok(())
    }

    fn handle_response(
        result: Result<Svc::Response, Svc::Error>,
        request_id: usize,
        finalizers: EventFinalizers,
        event_count: usize,
        bytes_sent: Option<&Registered<BytesSent>>,
        events_sent: &RegisteredEventCache<(), TaggedEventsSent>,
    ) {
        match result {
            Err(error) => {
                Self::emit_call_error(Some(error), request_id, event_count);
                finalizers.update_status(EventStatus::Rejected);
            }
            Ok(response) => {
                trace!(message = "Service call succeeded.", request_id);
                finalizers.update_status(response.event_status());
                if response.event_status() == EventStatus::Delivered {
                    if let Some(bytes_sent) = bytes_sent
                        && let Some(byte_size) = response.bytes_sent()
                    {
                        bytes_sent.emit(ByteSize(byte_size));
                    }

                    response.events_sent().emit_event(events_sent);

                // This condition occurs specifically when `HttpBatchService::call()` is invoked
                // *within* the `Service::call()` implementation.
                } else if response.event_status() == EventStatus::Rejected {
                    Self::emit_call_error(None, request_id, event_count);
                    finalizers.update_status(EventStatus::Rejected);
                }
            }
        }
        drop(finalizers); // suppress "argument not consumed" warning
    }

    /// Emit the `Error` and `EventsDropped` internal events.
    /// This scenario occurs after retries have been attempted.
    fn emit_call_error(error: Option<Svc::Error>, request_id: usize, count: usize) {
        emit(CallError {
            error,
            request_id,
            count,
        });
    }
}

#[cfg(test)]
mod tests {
    use std::{
        future::Future,
        pin::Pin,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        task::{Context, Poll, ready},
        time::Duration,
    };

    use futures_util::stream;
    use rand::{SeedableRng, prelude::StdRng};
    use rand_distr::{Distribution, Pareto};
    use tokio::{
        sync::{OwnedSemaphorePermit, Semaphore},
        time::sleep,
    };
    use tokio_util::sync::PollSemaphore;
    use tower::Service;
    use vector_common::{
        finalization::{BatchNotifier, EventFinalizer, EventFinalizers, EventStatus, Finalizable},
        internal_event::CountByteSize,
        json_size::JsonSize,
        request_metadata::{GroupedCountByteSize, MetaDescriptive, RequestMetadata},
    };

    use super::{Driver, DriverResponse};

    type Counter = Arc<AtomicUsize>;

    #[derive(Debug)]
    struct DelayRequest(EventFinalizers, RequestMetadata);

    impl DelayRequest {
        fn new(value: usize, counter: &Counter) -> Self {
            let (batch, receiver) = BatchNotifier::new_with_receiver();
            let counter = Arc::clone(counter);
            tokio::spawn(async move {
                receiver.await;
                counter.fetch_add(value, Ordering::Relaxed);
            });
            Self(
                EventFinalizers::new(EventFinalizer::new(batch)),
                RequestMetadata::default(),
            )
        }
    }

    impl Finalizable for DelayRequest {
        fn take_finalizers(&mut self) -> vector_core::event::EventFinalizers {
            std::mem::take(&mut self.0)
        }
    }

    impl MetaDescriptive for DelayRequest {
        fn get_metadata(&self) -> &RequestMetadata {
            &self.1
        }

        fn metadata_mut(&mut self) -> &mut RequestMetadata {
            &mut self.1
        }
    }

    struct DelayResponse {
        events_sent: GroupedCountByteSize,
    }

    impl DelayResponse {
        fn new() -> Self {
            Self {
                events_sent: CountByteSize(1, JsonSize::new(1)).into(),
            }
        }
    }

    impl DriverResponse for DelayResponse {
        fn event_status(&self) -> EventStatus {
            EventStatus::Delivered
        }

        fn events_sent(&self) -> &GroupedCountByteSize {
            &self.events_sent
        }
    }

    // Generic service that takes a usize and applies an arbitrary delay to returning it.
    struct DelayService {
        semaphore: PollSemaphore,
        permit: Option<OwnedSemaphorePermit>,
        jitter: Pareto<f64>,
        jitter_gen: StdRng,
        lower_bound_us: u64,
        upper_bound_us: u64,
    }

    // Clippy is unhappy about all of our f64/u64 shuffling.  We don't actually care about losing
    // the fractional part of 20,459.13142 or whatever.  It just doesn't matter for this test.
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_precision_loss)]
    impl DelayService {
        pub(crate) fn new(permits: usize, lower_bound: Duration, upper_bound: Duration) -> Self {
            assert!(upper_bound > lower_bound);
            Self {
                semaphore: PollSemaphore::new(Arc::new(Semaphore::new(permits))),
                permit: None,
                jitter: Pareto::new(1.0, 1.0).expect("distribution should be valid"),
                jitter_gen: StdRng::from_seed([
                    3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4, 6, 2, 6, 4, 3, 3,
                    8, 3, 2, 7, 9, 5,
                ]),
                lower_bound_us: lower_bound.as_micros().max(10_000) as u64,
                upper_bound_us: upper_bound.as_micros().max(10_000) as u64,
            }
        }

        pub(crate) fn get_sleep_dur(&mut self) -> Duration {
            let lower = self.lower_bound_us;
            let upper = self.upper_bound_us;
            // Generate a value between `lower` and `upper`, with a long tail shape to the distribution.
            #[allow(clippy::cast_sign_loss)] // Value will be positive anyway
            self.jitter
                .sample_iter(&mut self.jitter_gen)
                .map(|n| n * lower as f64)
                .map(|n| n as u64)
                .filter(|n| *n > lower && *n < upper)
                .map(Duration::from_micros)
                .next()
                .expect("jitter iter should be endless")
        }
    }

    impl Service<DelayRequest> for DelayService {
        type Response = DelayResponse;
        type Error = ();
        type Future =
            Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + Sync>>;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            assert!(
                self.permit.is_none(),
                "should not call poll_ready again after a successful call"
            );

            match ready!(self.semaphore.poll_acquire(cx)) {
                None => panic!("semaphore should not be closed!"),
                Some(permit) => assert!(self.permit.replace(permit).is_none()),
            }

            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: DelayRequest) -> Self::Future {
            let permit = self
                .permit
                .take()
                .expect("calling `call` without successful `poll_ready` is invalid");
            let sleep_dur = self.get_sleep_dur();

            Box::pin(async move {
                sleep(sleep_dur).await;

                // Manually drop our permit here so that we take ownership and then actually
                // release the slot back to the semaphore.
                drop(permit);
                drop(req);

                Ok(DelayResponse::new())
            })
        }
    }

    #[tokio::test]
    async fn driver_simple() {
        // This test uses a service which creates response futures that sleep for a variable, but
        // bounded, amount of time, giving the impression of work being completed.  Completion of
        // all requests/responses is asserted by checking that the shared counter matches the
        // expected ack amount.  The delays themselves are deterministic based on a fixed-seed
        // RNG, so the test should always run in a fairly constant time between runs.
        //
        // TODO: Given the use of a deterministic RNG, we could likely transition this test to be
        // driven via `proptest`, to also allow driving the input requests.  The main thing that we
        // do not control is the arrival of requests in the input stream itself, which means that
        // the generated batches will almost always be the biggest possible size, since the stream
        // is always immediately available.
        //
        // It might be possible to spawn a background task to drive a true MPSC channel with
        // requests based on input provided from `proptest`, controlling not only the value (which
        // determines ack size) but also the delay between messages, simulating bursts of messages
        // similar to real sources.

        let counter = Counter::default();

        // Set up our driver input stream, service, etc.
        let input_requests = (1..=2048).collect::<Vec<_>>();
        let input_total: usize = input_requests.iter().sum();
        let input_stream = stream::iter(
            input_requests
                .into_iter()
                .map(|i| DelayRequest::new(i, &counter)),
        );
        let service = DelayService::new(10, Duration::from_millis(5), Duration::from_millis(150));
        let driver = Driver::new(input_stream, service);

        // Now actually run the driver, consuming all of the input.
        assert_eq!(driver.run().await, Ok(()));
        // Make sure the final finalizer task runs.
        tokio::task::yield_now().await;
        assert_eq!(input_total, counter.load(Ordering::SeqCst));
    }
}