1#![deny(warnings)]
2
3use std::fmt;
4
5use dashmap::DashMap;
6use tracing_core::{
7 callsite::Identifier,
8 field::{display, Field, Value, Visit},
9 span,
10 subscriber::Interest,
11 Event, Metadata, Subscriber,
12};
13use tracing_subscriber::layer::{Context, Layer};
14
15#[cfg(test)]
16#[macro_use]
17extern crate tracing;
18
19#[cfg(not(test))]
20use std::time::Instant;
21
22#[cfg(test)]
23use mock_instant::global::Instant;
24
/// Boolean event field through which a callsite explicitly opts in to (or out of) rate limiting.
const RATE_LIMIT_FIELD: &str = "internal_log_rate_limit";
/// Integer event field overriding the suppression window, in seconds. Recording it also turns
/// rate limiting on (see `LimitVisitor` for how it interacts with `RATE_LIMIT_FIELD`).
const RATE_LIMIT_SECS_FIELD: &str = "internal_log_rate_secs";
/// Field read by `MessageVisitor` to label suppression notices.
const MESSAGE_FIELD: &str = "message";

// Grouping keys: events from the same callsite are rate limited independently when the values
// of these fields (on the event itself or on an enclosing span) differ.
const COMPONENT_ID_FIELD: &str = "component_id";
const VRL_POSITION: &str = "vrl_position";
33
34#[derive(Eq, PartialEq, Hash, Clone)]
35struct RateKeyIdentifier {
36 callsite: Identifier,
37 rate_limit_key_values: RateLimitedSpanKeys,
38}
39
40pub struct RateLimitedLayer<S, L>
41where
42 L: Layer<S> + Sized,
43 S: Subscriber,
44{
45 events: DashMap<RateKeyIdentifier, State>,
46 inner: L,
47 internal_log_rate_limit: u64,
48 _subscriber: std::marker::PhantomData<S>,
49}
50
51impl<S, L> RateLimitedLayer<S, L>
52where
53 L: Layer<S> + Sized,
54 S: Subscriber,
55{
56 pub fn new(layer: L) -> Self {
57 RateLimitedLayer {
58 events: Default::default(),
59 internal_log_rate_limit: 10,
60 inner: layer,
61 _subscriber: std::marker::PhantomData,
62 }
63 }
64
65 pub fn with_default_limit(mut self, internal_log_rate_limit: u64) -> Self {
66 self.internal_log_rate_limit = internal_log_rate_limit;
67 self
68 }
69}
70
/// Rate-limiting behaviour: span callbacks record grouping keys, `on_event` applies the limit,
/// and every other callback is delegated to the wrapped layer unchanged.
impl<S, L> Layer<S> for RateLimitedLayer<S, L>
where
    L: Layer<S>,
    S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
{
    /// Delegates interest decisions entirely to the wrapped layer.
    #[inline]
    fn register_callsite(&self, metadata: &'static Metadata<'static>) -> Interest {
        self.inner.register_callsite(metadata)
    }

    /// Delegates filtering entirely to the wrapped layer.
    #[inline]
    fn enabled(&self, metadata: &Metadata<'_>, ctx: Context<'_, S>) -> bool {
        self.inner.enabled(metadata, ctx)
    }

    /// Captures the span's grouping-key fields into its extensions, then notifies the inner layer.
    fn on_new_span(&self, attrs: &span::Attributes<'_>, id: &span::Id, ctx: Context<'_, S>) {
        // Scoped so the extensions guard is dropped before the inner layer runs (presumably so
        // the inner layer can take its own extensions lock).
        {
            let span = ctx.span(id).expect("Span not found, this is a bug");
            let mut extensions = span.extensions_mut();

            if extensions.get_mut::<RateLimitedSpanKeys>().is_none() {
                let mut fields = RateLimitedSpanKeys::default();
                attrs.record(&mut fields);
                extensions.insert(fields);
            };
        }
        self.inner.on_new_span(attrs, id, ctx);
    }

    /// Updates the span's stored grouping keys with values recorded after creation.
    fn on_record(&self, id: &span::Id, values: &span::Record<'_>, ctx: Context<'_, S>) {
        // Same guard scoping as in `on_new_span`.
        {
            let span = ctx.span(id).expect("Span not found, this is a bug");
            let mut extensions = span.extensions_mut();

            match extensions.get_mut::<RateLimitedSpanKeys>() {
                Some(fields) => {
                    values.record(fields);
                }
                None => {
                    let mut fields = RateLimitedSpanKeys::default();
                    values.record(&mut fields);
                    extensions.insert(fields);
                }
            };
        }
        self.inner.on_record(id, values, ctx);
    }

    #[inline]
    fn on_follows_from(&self, span: &span::Id, follows: &span::Id, ctx: Context<'_, S>) {
        self.inner.on_follows_from(span, follows, ctx);
    }

    /// Applies rate limiting to opted-in events.
    ///
    /// Within one window for a given key:
    /// * the first occurrence is forwarded;
    /// * the second is replaced by an "is being suppressed" notice;
    /// * later occurrences are dropped.
    ///
    /// The first occurrence after the window expires is preceded by a "has been suppressed N
    /// times" summary (when anything was suppressed), is forwarded, and starts a new window.
    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
        let mut limit_visitor = LimitVisitor::default();
        event.record(&mut limit_visitor);

        // Events that did not opt in bypass all bookkeeping.
        let limit_exists = limit_visitor.limit.unwrap_or(false);
        if !limit_exists {
            return self.inner.on_event(event, ctx);
        }

        // A per-event window override wins over the layer default.
        let limit = match limit_visitor.limit_secs {
            Some(limit_secs) => limit_secs,
            None => self.internal_log_rate_limit,
        };

        // Grouping keys start from the event's own fields, then spans from the root inward are
        // merged on top.
        // NOTE(review): `merge` overwrites with any value a span sets, so a span's component_id /
        // vrl_position overrides the same field on the event itself — confirm this is intended.
        let rate_limit_key_values = {
            let mut keys = RateLimitedSpanKeys::default();
            event.record(&mut keys);

            ctx.lookup_current()
                .into_iter()
                .flat_map(|span| span.scope().from_root())
                .fold(keys, |mut keys, span| {
                    let extensions = span.extensions();
                    if let Some(span_keys) = extensions.get::<RateLimitedSpanKeys>() {
                        keys.merge(span_keys);
                    }
                    keys
                })
        };

        let metadata = event.metadata();
        let id = RateKeyIdentifier {
            callsite: metadata.callsite(),
            rate_limit_key_values,
        };

        // The message label and window length are captured only when the key is first seen;
        // later events for the same key reuse them.
        // NOTE(review): `state` is a DashMap entry guard that stays alive while the inner layer
        // is called below; if the inner layer ever logs a rate-limited event that maps to the
        // same shard, this could deadlock — confirm inner layers never re-enter.
        let mut state = self.events.entry(id).or_insert_with(|| {
            let mut message_visitor = MessageVisitor::default();
            event.record(&mut message_visitor);

            // Fall back to the callsite name when the event carries no message.
            let message = message_visitor
                .message
                .unwrap_or_else(|| metadata.name().into());

            State::new(message, limit)
        });

        let previous_count = state.increment_count();
        if state.should_limit() {
            match previous_count {
                // First occurrence in the window: pass it through.
                0 => self.inner.on_event(event, ctx),
                // Second occurrence: announce suppression instead of emitting the event.
                1 => {
                    let message = format!(
                        "Internal log [{}] is being suppressed to avoid flooding.",
                        state.message
                    );
                    self.create_event(&ctx, metadata, message, state.limit);
                }
                // Everything else in the window is dropped.
                _ => {}
            }
        } else {
            // Window expired. Occurrences numbered 1..previous_count were not emitted as
            // themselves, so report that many when at least one was suppressed.
            if previous_count > 1 {
                let message = format!(
                    "Internal log [{}] has been suppressed {} times.",
                    state.message,
                    previous_count - 1
                );

                self.create_event(&ctx, metadata, message, state.limit);
            }

            self.inner.on_event(event, ctx);

            // This event opens the next window and counts as its first occurrence.
            state.reset();
        }
    }

    #[inline]
    fn on_enter(&self, id: &span::Id, ctx: Context<'_, S>) {
        self.inner.on_enter(id, ctx);
    }

    #[inline]
    fn on_exit(&self, id: &span::Id, ctx: Context<'_, S>) {
        self.inner.on_exit(id, ctx);
    }

    #[inline]
    fn on_close(&self, id: span::Id, ctx: Context<'_, S>) {
        self.inner.on_close(id, ctx);
    }

    #[inline]
    fn on_id_change(&self, old: &span::Id, new: &span::Id, ctx: Context<'_, S>) {
        self.inner.on_id_change(old, new, ctx);
    }

    #[inline]
    fn on_layer(&mut self, subscriber: &mut S) {
        self.inner.on_layer(subscriber);
    }
}
244
245impl<S, L> RateLimitedLayer<S, L>
246where
247 S: Subscriber,
248 L: Layer<S>,
249{
250 fn create_event(
251 &self,
252 ctx: &Context<S>,
253 metadata: &'static Metadata<'static>,
254 message: String,
255 rate_limit: u64,
256 ) {
257 let fields = metadata.fields();
258
259 let message = display(message);
260
261 if let Some(message_field) = fields.field("message") {
262 let values = [(&message_field, Some(&message as &dyn Value))];
263
264 let valueset = fields.value_set(&values);
265 let event = Event::new(metadata, &valueset);
266 self.inner.on_event(&event, ctx.clone());
267 } else {
268 let values = [(
269 &fields.field(RATE_LIMIT_FIELD).unwrap(),
270 Some(&rate_limit as &dyn Value),
271 )];
272
273 let valueset = fields.value_set(&values);
274 let event = Event::new(metadata, &valueset);
275 self.inner.on_event(&event, ctx.clone());
276 }
277 }
278}
279
/// Suppression bookkeeping for a single rate-limit key.
#[derive(Debug)]
struct State {
    /// Start of the current window.
    start: Instant,
    /// Occurrences counted since `start`. `reset` sets it to 1 because the event that opens a
    /// new window counts as that window's first occurrence.
    count: u64,
    /// Window length in whole seconds.
    limit: u64,
    /// Label used to identify the event in suppression notices.
    message: String,
}

impl State {
    /// Creates state for a key seen for the first time; its window begins now.
    fn new(message: String, limit: u64) -> Self {
        let start = Instant::now();
        Self {
            message,
            limit,
            count: 0,
            start,
        }
    }

    /// Begins a new window now, already holding one occurrence.
    fn reset(&mut self) {
        self.count = 1;
        self.start = Instant::now();
    }

    /// Bumps the occurrence counter and hands back the value it had beforehand.
    fn increment_count(&mut self) -> u64 {
        let next = self.count + 1;
        std::mem::replace(&mut self.count, next)
    }

    /// True while fewer than `limit` whole seconds have passed since `start`.
    fn should_limit(&self) -> bool {
        let elapsed_secs = self.start.elapsed().as_secs();
        self.limit > elapsed_secs
    }
}
313
/// An owned, hashable copy of a grouping-key value.
///
/// One variant per primitive `record_*` callback; strings and `Debug`-formatted values are both
/// stored as `String`.
#[derive(Clone, Hash, PartialEq, Eq)]
enum TraceValue {
    String(String),
    Int(i64),
    Uint(u64),
    Bool(bool),
}

impl From<bool> for TraceValue {
    fn from(flag: bool) -> Self {
        Self::Bool(flag)
    }
}

impl From<i64> for TraceValue {
    fn from(signed: i64) -> Self {
        Self::Int(signed)
    }
}

impl From<u64> for TraceValue {
    fn from(unsigned: u64) -> Self {
        Self::Uint(unsigned)
    }
}

impl From<String> for TraceValue {
    fn from(text: String) -> Self {
        Self::String(text)
    }
}
345
346#[derive(Default, Eq, PartialEq, Hash, Clone)]
350struct RateLimitedSpanKeys {
351 component_id: Option<TraceValue>,
352 vrl_position: Option<TraceValue>,
353}
354
355impl RateLimitedSpanKeys {
356 fn record(&mut self, field: &Field, value: TraceValue) {
357 match field.name() {
358 COMPONENT_ID_FIELD => self.component_id = Some(value),
359 VRL_POSITION => self.vrl_position = Some(value),
360 _ => {}
361 }
362 }
363
364 fn merge(&mut self, other: &Self) {
365 if let Some(component_id) = &other.component_id {
366 self.component_id = Some(component_id.clone());
367 }
368 if let Some(vrl_position) = &other.vrl_position {
369 self.vrl_position = Some(vrl_position.clone());
370 }
371 }
372}
373
374impl Visit for RateLimitedSpanKeys {
375 fn record_i64(&mut self, field: &Field, value: i64) {
376 self.record(field, value.into());
377 }
378
379 fn record_u64(&mut self, field: &Field, value: u64) {
380 self.record(field, value.into());
381 }
382
383 fn record_bool(&mut self, field: &Field, value: bool) {
384 self.record(field, value.into());
385 }
386
387 fn record_str(&mut self, field: &Field, value: &str) {
388 self.record(field, value.to_owned().into());
389 }
390
391 fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) {
392 self.record(field, format!("{value:?}").into());
393 }
394}
395
396#[derive(Default)]
397struct LimitVisitor {
398 pub limit: Option<bool>,
399 pub limit_secs: Option<u64>,
400}
401
402impl Visit for LimitVisitor {
403 fn record_bool(&mut self, field: &Field, value: bool) {
404 if field.name() == RATE_LIMIT_FIELD {
405 self.limit = Some(value);
406 }
407 }
408
409 fn record_i64(&mut self, field: &Field, value: i64) {
410 if field.name() == RATE_LIMIT_SECS_FIELD {
411 self.limit = Some(true); self.limit_secs = Some(u64::try_from(value).unwrap_or_default()); }
414 }
415
416 fn record_u64(&mut self, field: &Field, value: u64) {
417 if field.name() == RATE_LIMIT_SECS_FIELD {
418 self.limit = Some(true); self.limit_secs = Some(value); }
421 }
422
423 fn record_debug(&mut self, _field: &Field, _value: &dyn fmt::Debug) {}
424}
425
426#[derive(Default)]
427struct MessageVisitor {
428 pub message: Option<String>,
429}
430
431impl Visit for MessageVisitor {
432 fn record_str(&mut self, field: &Field, value: &str) {
433 if self.message.is_none() && field.name() == MESSAGE_FIELD {
434 self.message = Some(value.to_string());
435 }
436 }
437
438 fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) {
439 if self.message.is_none() && field.name() == MESSAGE_FIELD {
440 self.message = Some(format!("{value:?}"));
441 }
442 }
443}
444
#[cfg(test)]
mod test {
    use std::{
        sync::{Arc, Mutex, MutexGuard, PoisonError},
        time::Duration,
    };

    use mock_instant::global::MockClock;
    use tracing_subscriber::layer::SubscriberExt;

    use super::*;

    /// Serializes tests that advance the mock clock.
    ///
    /// `mock_instant::global::MockClock` is one process-wide clock, but the test harness runs
    /// tests on parallel threads. Without this lock, one test's `MockClock::advance` calls shift
    /// the elapsed time another test observes, making the asserted event sequences flaky.
    static MOCK_CLOCK_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `MOCK_CLOCK_LOCK` for the lifetime of the returned guard. Poisoning is ignored so
    /// one failing test does not cascade into failures of the remaining tests.
    fn lock_mock_clock() -> MutexGuard<'static, ()> {
        MOCK_CLOCK_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Test layer that records the `message` of every event it receives, in order.
    #[derive(Default)]
    struct RecordingLayer<S> {
        events: Arc<Mutex<Vec<String>>>,

        _subscriber: std::marker::PhantomData<S>,
    }

    impl<S> RecordingLayer<S> {
        fn new(events: Arc<Mutex<Vec<String>>>) -> Self {
            RecordingLayer {
                events,

                _subscriber: std::marker::PhantomData,
            }
        }
    }

    impl<S> Layer<S> for RecordingLayer<S>
    where
        S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
    {
        fn register_callsite(&self, _metadata: &'static Metadata<'static>) -> Interest {
            Interest::always()
        }

        fn enabled(&self, _metadata: &Metadata<'_>, _ctx: Context<'_, S>) -> bool {
            true
        }

        fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) {
            let mut visitor = MessageVisitor::default();
            event.record(&mut visitor);

            let mut events = self.events.lock().unwrap();
            events.push(visitor.message.unwrap_or_default());
        }
    }

    #[test]
    fn rate_limits() {
        let _clock = lock_mock_clock();
        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(1));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                info!(message = "Hello world!", internal_log_rate_limit = true);
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }

    #[test]
    fn override_rate_limit_at_callsite() {
        let _clock = lock_mock_clock();
        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(100));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                info!(
                    message = "Hello world!",
                    internal_log_rate_limit = true,
                    internal_log_rate_secs = 1
                );
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }

    #[test]
    fn rate_limit_by_span_key() {
        let _clock = lock_mock_clock();
        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(1));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                for key in &["foo", "bar"] {
                    for line_number in &[1, 2] {
                        let span =
                            info_span!("span", component_id = &key, vrl_position = &line_number);
                        let _enter = span.enter();
                        info!(
                            message = format!("Hello {key} on line_number {line_number}!").as_str(),
                            internal_log_rate_limit = true
                        );
                    }
                }
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello foo on line_number 1!",
                "Hello foo on line_number 2!",
                "Hello bar on line_number 1!",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }

    #[test]
    fn rate_limit_by_event_key() {
        let _clock = lock_mock_clock();
        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(1));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                for key in &["foo", "bar"] {
                    for line_number in &[1, 2] {
                        info!(
                            message = format!("Hello {key} on line_number {line_number}!").as_str(),
                            internal_log_rate_limit = true,
                            component_id = &key,
                            vrl_position = &line_number
                        );
                    }
                }
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello foo on line_number 1!",
                "Hello foo on line_number 2!",
                "Hello bar on line_number 1!",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }
}