1use std::{collections::VecDeque, fmt, future::poll_fn, task::Poll};
2
3use futures::{FutureExt, Stream, StreamExt, TryFutureExt, poll};
4use tokio::{pin, select};
5use tower::Service;
6use tracing::Instrument;
7use vector_common::{
8 internal_event::{
9 ByteSize, BytesSent, CallError, InternalEventHandle as _, PollReadyError, Registered,
10 RegisteredEventCache, SharedString, TaggedEventsSent, emit, register,
11 },
12 request_metadata::{GroupedCountByteSize, MetaDescriptive},
13};
14use vector_core::event::{EventFinalizers, EventStatus, Finalizable};
15
16use super::FuturesUnorderedCount;
17
18pub trait DriverResponse {
19 fn event_status(&self) -> EventStatus;
20 fn events_sent(&self) -> &GroupedCountByteSize;
21
22 fn bytes_sent(&self) -> Option<usize> {
26 None
27 }
28}
29
/// Feeds a stream of requests into a `tower::Service`, keeping track of in-flight calls and
/// updating the finalizers attached to each request once its call completes.
pub struct Driver<St, Svc> {
    /// Stream of requests to submit to `service`.
    input: St,
    /// Service every request is sent to.
    service: Svc,
    /// Protocol label for bytes-sent telemetry; when `None`, no bytes-sent events are emitted.
    protocol: Option<SharedString>,
}
48
49impl<St, Svc> Driver<St, Svc> {
50 pub fn new(input: St, service: Svc) -> Self {
51 Self {
52 input,
53 service,
54 protocol: None,
55 }
56 }
57
58 #[must_use]
63 pub fn protocol(mut self, protocol: impl Into<SharedString>) -> Self {
64 self.protocol = Some(protocol.into());
65 self
66 }
67}
68
69impl<St, Svc> Driver<St, Svc>
70where
71 St: Stream,
72 St::Item: Finalizable + MetaDescriptive,
73 Svc: Service<St::Item>,
74 Svc::Error: fmt::Debug + 'static,
75 Svc::Future: Send + 'static,
76 Svc::Response: DriverResponse,
77{
78 pub async fn run(self) -> Result<(), ()> {
87 let mut in_flight = FuturesUnorderedCount::new();
88 let mut next_batch: Option<VecDeque<St::Item>> = None;
89 let mut seq_num = 0usize;
90
91 let Self {
92 input,
93 mut service,
94 protocol,
95 } = self;
96
97 let batched_input = input.ready_chunks(1024);
98 pin!(batched_input);
99
100 let bytes_sent = protocol.map(|protocol| register(BytesSent { protocol }));
101 let events_sent = RegisteredEventCache::new(());
102
103 loop {
104 select! {
124 biased;
128
129 Some(_count) = in_flight.next(), if !in_flight.is_empty() => {}
131
132 maybe_ready = poll_fn(|cx| service.poll_ready(cx)), if next_batch.is_some() => {
134 let mut batch = next_batch.take()
135 .unwrap_or_else(|| unreachable!("batch should be populated"));
136
137 let mut maybe_ready = Some(maybe_ready);
138 while !batch.is_empty() {
139 let maybe_ready = match maybe_ready.take() {
141 Some(ready) => Poll::Ready(ready),
142 None => poll!(poll_fn(|cx| service.poll_ready(cx))),
143 };
144
145 let svc = match maybe_ready {
146 Poll::Ready(Ok(())) => &mut service,
147 Poll::Ready(Err(error)) => {
148 emit(PollReadyError{ error });
149 return Err(())
150 }
151 Poll::Pending => {
152 next_batch = Some(batch);
153 break
154 },
155 };
156
157 let mut req = batch.pop_front().unwrap_or_else(|| unreachable!("batch should not be empty"));
158 seq_num += 1;
159 let request_id = seq_num;
160
161 trace!(
162 message = "Submitting service request.",
163 in_flight_requests = in_flight.len(),
164 request_id,
165 );
166 let finalizers = req.take_finalizers();
167 let bytes_sent = bytes_sent.clone();
168 let events_sent = events_sent.clone();
169 let event_count = req.get_metadata().event_count();
170
171 let fut = svc.call(req)
172 .err_into()
173 .map(move |result| Self::handle_response(
174 result,
175 request_id,
176 finalizers,
177 event_count,
178 bytes_sent.as_ref(),
179 &events_sent,
180 ))
181 .instrument(info_span!("request", request_id).or_current());
182
183 in_flight.push(fut);
184 }
185 }
186
187 Some(reqs) = batched_input.next(), if next_batch.is_none() => {
189 next_batch = Some(reqs.into());
190 }
191
192 else => break
193 }
194 }
195
196 Ok(())
197 }
198
199 fn handle_response(
200 result: Result<Svc::Response, Svc::Error>,
201 request_id: usize,
202 finalizers: EventFinalizers,
203 event_count: usize,
204 bytes_sent: Option<&Registered<BytesSent>>,
205 events_sent: &RegisteredEventCache<(), TaggedEventsSent>,
206 ) {
207 match result {
208 Err(error) => {
209 Self::emit_call_error(Some(error), request_id, event_count);
210 finalizers.update_status(EventStatus::Rejected);
211 }
212 Ok(response) => {
213 trace!(message = "Service call succeeded.", request_id);
214 finalizers.update_status(response.event_status());
215 if response.event_status() == EventStatus::Delivered {
216 if let Some(bytes_sent) = bytes_sent
217 && let Some(byte_size) = response.bytes_sent()
218 {
219 bytes_sent.emit(ByteSize(byte_size));
220 }
221
222 response.events_sent().emit_event(events_sent);
223
224 } else if response.event_status() == EventStatus::Rejected {
226 Self::emit_call_error(None, request_id, event_count);
227 finalizers.update_status(EventStatus::Rejected);
228 }
229 }
230 }
231 drop(finalizers); }
233
234 fn emit_call_error(error: Option<Svc::Error>, request_id: usize, count: usize) {
237 emit(CallError {
238 error,
239 request_id,
240 count,
241 });
242 }
243}
244
#[cfg(test)]
mod tests {
    use std::{
        future::Future,
        pin::Pin,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        task::{Context, Poll, ready},
        time::Duration,
    };

    use futures_util::stream;
    use rand::{SeedableRng, prelude::StdRng};
    use rand_distr::{Distribution, Pareto};
    use tokio::{
        sync::{OwnedSemaphorePermit, Semaphore},
        time::sleep,
    };
    use tokio_util::sync::PollSemaphore;
    use tower::Service;
    use vector_common::{
        finalization::{BatchNotifier, EventFinalizer, EventFinalizers, EventStatus, Finalizable},
        internal_event::CountByteSize,
        json_size::JsonSize,
        request_metadata::{GroupedCountByteSize, MetaDescriptive, RequestMetadata},
    };

    use super::{Driver, DriverResponse};

    /// Shared accumulator that each request adds its value to once it is finalized.
    type Counter = Arc<AtomicUsize>;

    /// Test request: a set of finalizers plus default request metadata.
    #[derive(Debug)]
    struct DelayRequest(EventFinalizers, RequestMetadata);

    impl DelayRequest {
        /// Creates a request whose batch notification adds `value` to `counter`.
        ///
        /// A spawned task awaits the batch receiver and performs the addition once it resolves.
        fn new(value: usize, counter: &Counter) -> Self {
            let (batch, receiver) = BatchNotifier::new_with_receiver();
            let counter = Arc::clone(counter);
            tokio::spawn(async move {
                receiver.await;
                counter.fetch_add(value, Ordering::Relaxed);
            });
            Self(
                EventFinalizers::new(EventFinalizer::new(batch)),
                RequestMetadata::default(),
            )
        }
    }

    impl Finalizable for DelayRequest {
        fn take_finalizers(&mut self) -> vector_core::event::EventFinalizers {
            std::mem::take(&mut self.0)
        }
    }

    impl MetaDescriptive for DelayRequest {
        fn get_metadata(&self) -> &RequestMetadata {
            &self.1
        }

        fn metadata_mut(&mut self) -> &mut RequestMetadata {
            &mut self.1
        }
    }

    /// Test response that always reports a single delivered event of size 1.
    struct DelayResponse {
        events_sent: GroupedCountByteSize,
    }

    impl DelayResponse {
        fn new() -> Self {
            Self {
                events_sent: CountByteSize(1, JsonSize::new(1)).into(),
            }
        }
    }

    impl DriverResponse for DelayResponse {
        fn event_status(&self) -> EventStatus {
            EventStatus::Delivered
        }

        fn events_sent(&self) -> &GroupedCountByteSize {
            &self.events_sent
        }
    }

    /// Test service that bounds concurrency with a semaphore and completes each call after a
    /// seeded, randomized delay.
    struct DelayService {
        semaphore: PollSemaphore,
        /// Permit acquired by `poll_ready` and consumed by the following `call`.
        permit: Option<OwnedSemaphorePermit>,
        jitter: Pareto<f64>,
        jitter_gen: StdRng,
        lower_bound_us: u64,
        upper_bound_us: u64,
    }

    // Microsecond bounds are narrowed from `u128` to `u64` and converted to `f64` for sampling;
    // the values involved are small test durations, so neither conversion loses anything
    // meaningful.
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_precision_loss)]
    impl DelayService {
        /// Creates a service admitting `permits` concurrent calls, each sleeping for a duration
        /// strictly between `lower_bound` and `upper_bound`.
        ///
        /// Both bounds are clamped to at least 10ms.
        ///
        /// # Panics
        ///
        /// Panics if `upper_bound <= lower_bound`, or if the clamped bounds leave no whole
        /// microsecond strictly between them.
        pub(crate) fn new(permits: usize, lower_bound: Duration, upper_bound: Duration) -> Self {
            assert!(upper_bound > lower_bound);
            // Validate the clamped values as well. If they collapse (e.g. both inputs are below
            // 10ms) or differ by a single microsecond, `get_sleep_dur` can never accept a
            // sample and would spin forever on its endless iterator.
            let lower_bound_us = lower_bound.as_micros().max(10_000) as u64;
            let upper_bound_us = upper_bound.as_micros().max(10_000) as u64;
            assert!(
                upper_bound_us.saturating_sub(lower_bound_us) > 1,
                "clamped bounds must leave room for a sleep duration strictly between them"
            );
            Self {
                semaphore: PollSemaphore::new(Arc::new(Semaphore::new(permits))),
                permit: None,
                jitter: Pareto::new(1.0, 1.0).expect("distribution should be valid"),
                jitter_gen: StdRng::from_seed([
                    3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4, 6, 2, 6, 4, 3, 3,
                    8, 3, 2, 7, 9, 5,
                ]),
                lower_bound_us,
                upper_bound_us,
            }
        }

        /// Samples a sleep duration strictly between the lower and upper bounds.
        ///
        /// Draws from the Pareto distribution scaled by the lower bound and discards samples
        /// that fall outside `(lower, upper)`.
        pub(crate) fn get_sleep_dur(&mut self) -> Duration {
            let lower = self.lower_bound_us;
            let upper = self.upper_bound_us;
            // Pareto samples with scale 1.0 are always >= 1.0, so casting to an unsigned
            // integer cannot lose a sign.
            #[allow(clippy::cast_sign_loss)]
            self.jitter
                .sample_iter(&mut self.jitter_gen)
                .map(|n| n * lower as f64)
                .map(|n| n as u64)
                .filter(|n| *n > lower && *n < upper)
                .map(Duration::from_micros)
                .next()
                .expect("jitter iter should be endless")
        }
    }

    impl Service<DelayRequest> for DelayService {
        type Response = DelayResponse;
        type Error = ();
        type Future =
            Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + Sync>>;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            assert!(
                self.permit.is_none(),
                "should not call poll_ready again after a successful call"
            );

            match ready!(self.semaphore.poll_acquire(cx)) {
                None => panic!("semaphore should not be closed!"),
                Some(permit) => assert!(self.permit.replace(permit).is_none()),
            }

            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: DelayRequest) -> Self::Future {
            let permit = self
                .permit
                .take()
                .expect("calling `call` without successful `poll_ready` is invalid");
            let sleep_dur = self.get_sleep_dur();

            Box::pin(async move {
                sleep(sleep_dur).await;

                // Release the concurrency slot before the request (and its finalizers) drops.
                drop(permit);
                drop(req);

                Ok(DelayResponse::new())
            })
        }
    }

    #[tokio::test]
    async fn driver_simple() {
        let counter = Counter::default();

        let input_requests = (1..=2048).collect::<Vec<_>>();
        let input_total: usize = input_requests.iter().sum();
        let input_stream = stream::iter(
            input_requests
                .into_iter()
                .map(|i| DelayRequest::new(i, &counter)),
        );
        let service = DelayService::new(10, Duration::from_millis(5), Duration::from_millis(150));
        let driver = Driver::new(input_stream, service);

        assert_eq!(driver.run().await, Ok(()));
        // Let the spawned counter tasks run before checking the accumulated total.
        tokio::task::yield_now().await;
        assert_eq!(input_total, counter.load(Ordering::SeqCst));
    }
}