1use std::{collections::VecDeque, fmt, future::poll_fn, task::Poll};
2
3use futures::{poll, FutureExt, Stream, StreamExt, TryFutureExt};
4use tokio::{pin, select};
5use tower::Service;
6use tracing::Instrument;
7use vector_common::internal_event::emit;
8use vector_common::internal_event::{
9 register, ByteSize, BytesSent, CallError, InternalEventHandle as _, PollReadyError, Registered,
10 RegisteredEventCache, SharedString, TaggedEventsSent,
11};
12use vector_common::request_metadata::{GroupedCountByteSize, MetaDescriptive};
13use vector_core::event::{EventFinalizers, EventStatus, Finalizable};
14
15use super::FuturesUnorderedCount;
16
/// A response returned by a service call, carrying the information the
/// driver needs to finalize the request's events and emit telemetry.
pub trait DriverResponse {
    /// Status to propagate to the request's event finalizers.
    fn event_status(&self) -> EventStatus;

    /// Count and byte size of the events handled by this response,
    /// grouped for the `EventsSent`/`TaggedEventsSent` telemetry.
    fn events_sent(&self) -> &GroupedCountByteSize;

    /// Number of bytes sent over the wire, if known. Returning `None`
    /// (the default) suppresses the `BytesSent` telemetry for this call.
    fn bytes_sent(&self) -> Option<usize> {
        None
    }
}
28
/// Drives the interaction between a stream of items and a service,
/// dispatching items once the service is ready and tracking the
/// resulting in-flight requests until completion.
pub struct Driver<St, Svc> {
    // Stream of items to be submitted to the service.
    input: St,
    // Service every item is dispatched to.
    service: Svc,
    // Optional protocol name; when set, `BytesSent` telemetry is registered.
    protocol: Option<SharedString>,
}
47
48impl<St, Svc> Driver<St, Svc> {
49 pub fn new(input: St, service: Svc) -> Self {
50 Self {
51 input,
52 service,
53 protocol: None,
54 }
55 }
56
57 #[must_use]
62 pub fn protocol(mut self, protocol: impl Into<SharedString>) -> Self {
63 self.protocol = Some(protocol.into());
64 self
65 }
66}
67
impl<St, Svc> Driver<St, Svc>
where
    St: Stream,
    St::Item: Finalizable + MetaDescriptive,
    Svc: Service<St::Item>,
    Svc::Error: fmt::Debug + 'static,
    Svc::Future: Send + 'static,
    Svc::Response: DriverResponse,
{
    /// Runs the driver: pulls items from `input` in chunks, dispatches each
    /// item to the service once it reports readiness, and finalizes the
    /// item's events according to the service's response.
    ///
    /// Completes with `Ok(())` once the input stream is exhausted and every
    /// in-flight request has finished.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if `service.poll_ready` resolves to an error.
    pub async fn run(self) -> Result<(), ()> {
        // Futures for requests that were dispatched but have not completed.
        let mut in_flight = FuturesUnorderedCount::new();
        // A chunk of input items waiting for the service to become ready.
        let mut next_batch: Option<VecDeque<St::Item>> = None;
        // Monotonically increasing id correlating log and telemetry events.
        let mut seq_num = 0usize;

        let Self {
            input,
            mut service,
            protocol,
        } = self;

        // Group input into chunks of up to 1024 already-available items.
        let batched_input = input.ready_chunks(1024);
        pin!(batched_input);

        // `BytesSent` is only registered when a protocol name was configured.
        let bytes_sent = protocol.map(|protocol| register(BytesSent { protocol }));
        let events_sent = RegisteredEventCache::new(());

        loop {
            // `biased` fixes the polling order: reap completed requests
            // first, then dispatch the pending batch, and only pull a new
            // chunk from the input once no batch is waiting.
            select! {
                biased;

                // Drive in-flight requests toward completion.
                Some(_count) = in_flight.next(), if !in_flight.is_empty() => {}

                // `poll_ready` resolved while a batch is waiting: dispatch
                // as many items as the service will accept.
                maybe_ready = poll_fn(|cx| service.poll_ready(cx)), if next_batch.is_some() => {
                    let mut batch = next_batch.take()
                        .unwrap_or_else(|| unreachable!("batch should be populated"));

                    // The first readiness result came from the select arm;
                    // for subsequent items the service is re-polled inline.
                    let mut maybe_ready = Some(maybe_ready);
                    while !batch.is_empty() {
                        let maybe_ready = match maybe_ready.take() {
                            Some(ready) => Poll::Ready(ready),
                            None => poll!(poll_fn(|cx| service.poll_ready(cx))),
                        };

                        let svc = match maybe_ready {
                            Poll::Ready(Ok(())) => &mut service,
                            Poll::Ready(Err(error)) => {
                                emit(PollReadyError{ error });
                                return Err(())
                            }
                            // Service can't take more right now: stash the
                            // remaining items and go back to the select.
                            Poll::Pending => {
                                next_batch = Some(batch);
                                break
                            },
                        };

                        let mut req = batch.pop_front().unwrap_or_else(|| unreachable!("batch should not be empty"));
                        seq_num += 1;
                        let request_id = seq_num;

                        trace!(
                            message = "Submitting service request.",
                            in_flight_requests = in_flight.len(),
                            request_id,
                        );
                        // Detach the finalizers so event status can be
                        // reported when the response arrives.
                        let finalizers = req.take_finalizers();
                        let bytes_sent = bytes_sent.clone();
                        let events_sent = events_sent.clone();
                        let event_count = req.get_metadata().event_count();

                        let fut = svc.call(req)
                            .err_into()
                            .map(move |result| Self::handle_response(
                                result,
                                request_id,
                                finalizers,
                                event_count,
                                bytes_sent.as_ref(),
                                &events_sent,
                            ))
                            .instrument(info_span!("request", request_id).or_current());

                        in_flight.push(fut);
                    }
                }

                // Pull the next chunk of input, but only once the previous
                // batch has been fully dispatched.
                Some(reqs) = batched_input.next(), if next_batch.is_none() => {
                    next_batch = Some(reqs.into());
                }

                // All arms disabled: input exhausted, nothing in flight,
                // no batch pending — we are done.
                else => break
            }
        }

        Ok(())
    }

    /// Updates event finalizers and emits telemetry for one completed
    /// service call.
    ///
    /// A call-level error rejects the events and emits a `CallError`. On
    /// success, the response's status is propagated to the finalizers;
    /// delivered events additionally emit `BytesSent` (when a protocol is
    /// configured and the response reports a byte size) and the grouped
    /// events-sent telemetry, while rejected events emit a `CallError`
    /// without an attached service error.
    fn handle_response(
        result: Result<Svc::Response, Svc::Error>,
        request_id: usize,
        finalizers: EventFinalizers,
        event_count: usize,
        bytes_sent: Option<&Registered<BytesSent>>,
        events_sent: &RegisteredEventCache<(), TaggedEventsSent>,
    ) {
        match result {
            Err(error) => {
                // The call itself failed: report the error and mark the
                // events rejected.
                Self::emit_call_error(Some(error), request_id, event_count);
                finalizers.update_status(EventStatus::Rejected);
            }
            Ok(response) => {
                trace!(message = "Service call succeeded.", request_id);
                finalizers.update_status(response.event_status());
                if response.event_status() == EventStatus::Delivered {
                    if let Some(bytes_sent) = bytes_sent {
                        if let Some(byte_size) = response.bytes_sent() {
                            bytes_sent.emit(ByteSize(byte_size));
                        }
                    }

                    response.events_sent().emit_event(events_sent);

                } else if response.event_status() == EventStatus::Rejected {
                    // Service responded but rejected the events; there is
                    // no service error to attach to the `CallError`.
                    Self::emit_call_error(None, request_id, event_count);
                    finalizers.update_status(EventStatus::Rejected);
                }
            }
        }
        // Explicit drop marking the end of the finalizers' lifetime.
        drop(finalizers); }

    /// Emits a `CallError`; `error` is `None` when the service responded
    /// but rejected the events.
    fn emit_call_error(error: Option<Svc::Error>, request_id: usize, count: usize) {
        emit(CallError {
            error,
            request_id,
            count,
        });
    }
}
243
244#[cfg(test)]
245mod tests {
246 use std::{
247 future::Future,
248 pin::Pin,
249 sync::{atomic::AtomicUsize, atomic::Ordering, Arc},
250 task::{ready, Context, Poll},
251 time::Duration,
252 };
253
254 use futures_util::stream;
255 use rand::{prelude::StdRng, SeedableRng};
256 use rand_distr::{Distribution, Pareto};
257 use tokio::{
258 sync::{OwnedSemaphorePermit, Semaphore},
259 time::sleep,
260 };
261 use tokio_util::sync::PollSemaphore;
262 use tower::Service;
263 use vector_common::{
264 finalization::{BatchNotifier, EventFinalizer, EventFinalizers, EventStatus, Finalizable},
265 json_size::JsonSize,
266 request_metadata::{GroupedCountByteSize, RequestMetadata},
267 };
268 use vector_common::{internal_event::CountByteSize, request_metadata::MetaDescriptive};
269
270 use super::{Driver, DriverResponse};
271
272 type Counter = Arc<AtomicUsize>;
273
    /// Test request carrying event finalizers and request metadata.
    #[derive(Debug)]
    struct DelayRequest(EventFinalizers, RequestMetadata);
276
277 impl DelayRequest {
278 fn new(value: usize, counter: &Counter) -> Self {
279 let (batch, receiver) = BatchNotifier::new_with_receiver();
280 let counter = Arc::clone(counter);
281 tokio::spawn(async move {
282 receiver.await;
283 counter.fetch_add(value, Ordering::Relaxed);
284 });
285 Self(
286 EventFinalizers::new(EventFinalizer::new(batch)),
287 RequestMetadata::default(),
288 )
289 }
290 }
291
292 impl Finalizable for DelayRequest {
293 fn take_finalizers(&mut self) -> vector_core::event::EventFinalizers {
294 std::mem::take(&mut self.0)
295 }
296 }
297
    // Accessors the driver uses to read per-request metadata (event counts).
    impl MetaDescriptive for DelayRequest {
        fn get_metadata(&self) -> &RequestMetadata {
            &self.1
        }

        fn metadata_mut(&mut self) -> &mut RequestMetadata {
            &mut self.1
        }
    }
307
    /// Test response carrying the grouped event counts to report as sent.
    struct DelayResponse {
        events_sent: GroupedCountByteSize,
    }
311
312 impl DelayResponse {
313 fn new() -> Self {
314 Self {
315 events_sent: CountByteSize(1, JsonSize::new(1)).into(),
316 }
317 }
318 }
319
    impl DriverResponse for DelayResponse {
        // Every delayed call reports success, so the driver finalizes all
        // events as delivered.
        fn event_status(&self) -> EventStatus {
            EventStatus::Delivered
        }

        fn events_sent(&self) -> &GroupedCountByteSize {
            &self.events_sent
        }
    }
329
    /// Test service that delays each call by a jittered, bounded duration
    /// while limiting concurrency with a semaphore.
    struct DelayService {
        // Limits how many calls may be in flight at once.
        semaphore: PollSemaphore,
        // Permit acquired in `poll_ready`, consumed by the next `call`.
        permit: Option<OwnedSemaphorePermit>,
        // Pareto distribution supplying the delay jitter.
        jitter: Pareto<f64>,
        // Seeded RNG, so the jitter sequence is deterministic across runs.
        jitter_gen: StdRng,
        // Exclusive bounds (in microseconds) on accepted delay samples.
        lower_bound_us: u64,
        upper_bound_us: u64,
    }
339
340 #[allow(clippy::cast_possible_truncation)]
343 #[allow(clippy::cast_precision_loss)]
344 impl DelayService {
345 pub(crate) fn new(permits: usize, lower_bound: Duration, upper_bound: Duration) -> Self {
346 assert!(upper_bound > lower_bound);
347 Self {
348 semaphore: PollSemaphore::new(Arc::new(Semaphore::new(permits))),
349 permit: None,
350 jitter: Pareto::new(1.0, 1.0).expect("distribution should be valid"),
351 jitter_gen: StdRng::from_seed([
352 3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4, 6, 2, 6, 4, 3, 3,
353 8, 3, 2, 7, 9, 5,
354 ]),
355 lower_bound_us: lower_bound.as_micros().max(10_000) as u64,
356 upper_bound_us: upper_bound.as_micros().max(10_000) as u64,
357 }
358 }
359
360 pub(crate) fn get_sleep_dur(&mut self) -> Duration {
361 let lower = self.lower_bound_us;
362 let upper = self.upper_bound_us;
363 #[allow(clippy::cast_sign_loss)] self.jitter
366 .sample_iter(&mut self.jitter_gen)
367 .map(|n| n * lower as f64)
368 .map(|n| n as u64)
369 .filter(|n| *n > lower && *n < upper)
370 .map(Duration::from_micros)
371 .next()
372 .expect("jitter iter should be endless")
373 }
374 }
375
    impl Service<DelayRequest> for DelayService {
        type Response = DelayResponse;
        type Error = ();
        type Future =
            Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + Sync>>;

        /// Becomes ready only once a semaphore permit is acquired; the
        /// permit is stored for the matching `call`.
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            // Enforce the driver's contract: readiness must not be polled
            // again before the acquired permit is consumed by `call`.
            assert!(
                self.permit.is_none(),
                "should not call poll_ready again after a successful call"
            );

            match ready!(self.semaphore.poll_acquire(cx)) {
                None => panic!("semaphore should not be closed!"),
                Some(permit) => assert!(self.permit.replace(permit).is_none()),
            }

            Poll::Ready(Ok(()))
        }

        /// Consumes the stored permit and returns a future that sleeps for
        /// a jittered duration before reporting success.
        fn call(&mut self, req: DelayRequest) -> Self::Future {
            let permit = self
                .permit
                .take()
                .expect("calling `call` without successful `poll_ready` is invalid");
            let sleep_dur = self.get_sleep_dur();

            Box::pin(async move {
                sleep(sleep_dur).await;

                // The permit is held across the sleep, so at most `permits`
                // calls run concurrently; the request is released afterward.
                drop(permit);
                drop(req);

                Ok(DelayResponse::new())
            })
        }
    }
415
416 #[tokio::test]
417 async fn driver_simple() {
418 let counter = Counter::default();
436
437 let input_requests = (1..=2048).collect::<Vec<_>>();
439 let input_total: usize = input_requests.iter().sum();
440 let input_stream = stream::iter(
441 input_requests
442 .into_iter()
443 .map(|i| DelayRequest::new(i, &counter)),
444 );
445 let service = DelayService::new(10, Duration::from_millis(5), Duration::from_millis(150));
446 let driver = Driver::new(input_stream, service);
447
448 assert_eq!(driver.run().await, Ok(()));
450 tokio::task::yield_now().await;
452 assert_eq!(input_total, counter.load(Ordering::SeqCst));
453 }
454}