1#![deny(warnings)]
2
3use std::fmt;
4
5use dashmap::DashMap;
6use tracing_core::{
7 Event, Metadata, Subscriber,
8 callsite::Identifier,
9 field::{Field, Value, Visit, display},
10 span,
11 subscriber::Interest,
12};
13use tracing_subscriber::layer::{Context, Layer};
14
15#[cfg(test)]
16#[macro_use]
17extern crate tracing;
18
19#[cfg(not(test))]
20use std::time::Instant;
21
22#[cfg(test)]
23use mock_instant::global::Instant;
24
/// Event field that opts an event into rate limiting when set to `true`.
const RATE_LIMIT_FIELD: &str = "internal_log_rate_limit";
/// Event field that overrides the rate-limit window (in seconds) for an event; its presence
/// alone also opts the event into rate limiting.
const RATE_LIMIT_SECS_FIELD: &str = "internal_log_rate_secs";
/// Field from which an event's message is read when building suppression notices.
const MESSAGE_FIELD: &str = "message";

/// Rate-limit key field: its value, taken from the event and its enclosing spans, becomes
/// part of the rate-limit key, so each distinct value is limited independently.
const COMPONENT_ID_FIELD: &str = "component_id";
/// Rate-limit key field, handled exactly like `COMPONENT_ID_FIELD`.
const VRL_POSITION: &str = "vrl_position";
33
34#[derive(Eq, PartialEq, Hash, Clone)]
35struct RateKeyIdentifier {
36 callsite: Identifier,
37 rate_limit_key_values: RateLimitedSpanKeys,
38}
39
40pub struct RateLimitedLayer<S, L>
41where
42 L: Layer<S> + Sized,
43 S: Subscriber,
44{
45 events: DashMap<RateKeyIdentifier, State>,
46 inner: L,
47 internal_log_rate_limit: u64,
48 _subscriber: std::marker::PhantomData<S>,
49}
50
51impl<S, L> RateLimitedLayer<S, L>
52where
53 L: Layer<S> + Sized,
54 S: Subscriber,
55{
56 pub fn new(layer: L) -> Self {
57 RateLimitedLayer {
58 events: Default::default(),
59 internal_log_rate_limit: 10,
60 inner: layer,
61 _subscriber: std::marker::PhantomData,
62 }
63 }
64
65 pub fn with_default_limit(mut self, internal_log_rate_limit: u64) -> Self {
66 self.internal_log_rate_limit = internal_log_rate_limit;
67 self
68 }
69}
70
71impl<S, L> Layer<S> for RateLimitedLayer<S, L>
72where
73 L: Layer<S>,
74 S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
75{
76 #[inline]
77 fn register_callsite(&self, metadata: &'static Metadata<'static>) -> Interest {
78 self.inner.register_callsite(metadata)
79 }
80
81 #[inline]
82 fn enabled(&self, metadata: &Metadata<'_>, ctx: Context<'_, S>) -> bool {
83 self.inner.enabled(metadata, ctx)
84 }
85
86 fn on_new_span(&self, attrs: &span::Attributes<'_>, id: &span::Id, ctx: Context<'_, S>) {
88 {
89 let span = ctx.span(id).expect("Span not found, this is a bug");
90 let mut extensions = span.extensions_mut();
91
92 if extensions.get_mut::<RateLimitedSpanKeys>().is_none() {
93 let mut fields = RateLimitedSpanKeys::default();
94 attrs.record(&mut fields);
95 extensions.insert(fields);
96 };
97 }
98 self.inner.on_new_span(attrs, id, ctx);
99 }
100
101 fn on_record(&self, id: &span::Id, values: &span::Record<'_>, ctx: Context<'_, S>) {
103 {
104 let span = ctx.span(id).expect("Span not found, this is a bug");
105 let mut extensions = span.extensions_mut();
106
107 match extensions.get_mut::<RateLimitedSpanKeys>() {
108 Some(fields) => {
109 values.record(fields);
110 }
111 None => {
112 let mut fields = RateLimitedSpanKeys::default();
113 values.record(&mut fields);
114 extensions.insert(fields);
115 }
116 };
117 }
118 self.inner.on_record(id, values, ctx);
119 }
120
121 #[inline]
122 fn on_follows_from(&self, span: &span::Id, follows: &span::Id, ctx: Context<'_, S>) {
123 self.inner.on_follows_from(span, follows, ctx);
124 }
125
126 fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
127 let mut limit_visitor = LimitVisitor::default();
130 event.record(&mut limit_visitor);
131
132 let limit_exists = limit_visitor.limit.unwrap_or(false);
133 if !limit_exists {
134 return self.inner.on_event(event, ctx);
135 }
136
137 let limit = match limit_visitor.limit_secs {
138 Some(limit_secs) => limit_secs, None => self.internal_log_rate_limit,
140 };
141
142 let rate_limit_key_values = {
146 let mut keys = RateLimitedSpanKeys::default();
147 event.record(&mut keys);
148
149 ctx.lookup_current()
150 .into_iter()
151 .flat_map(|span| span.scope().from_root())
152 .fold(keys, |mut keys, span| {
153 let extensions = span.extensions();
154 if let Some(span_keys) = extensions.get::<RateLimitedSpanKeys>() {
155 keys.merge(span_keys);
156 }
157 keys
158 })
159 };
160
161 let metadata = event.metadata();
164 let id = RateKeyIdentifier {
165 callsite: metadata.callsite(),
166 rate_limit_key_values,
167 };
168
169 let mut state = self.events.entry(id).or_insert_with(|| {
170 let mut message_visitor = MessageVisitor::default();
171 event.record(&mut message_visitor);
172
173 let message = message_visitor
174 .message
175 .unwrap_or_else(|| metadata.name().into());
176
177 State::new(message, limit)
178 });
179
180 let previous_count = state.increment_count();
186 if state.should_limit() {
187 match previous_count {
188 0 => self.inner.on_event(event, ctx),
189 1 => {
190 let message = format!(
191 "Internal log [{}] is being suppressed to avoid flooding.",
192 state.message
193 );
194 self.create_event(&ctx, metadata, message, state.limit);
195 }
196 _ => {}
197 }
198 } else {
199 if previous_count > 1 {
202 let message = format!(
203 "Internal log [{}] has been suppressed {} times.",
204 state.message,
205 previous_count - 1
206 );
207
208 self.create_event(&ctx, metadata, message, state.limit);
209 }
210
211 self.inner.on_event(event, ctx);
214
215 state.reset();
216 }
217 }
218
219 #[inline]
220 fn on_enter(&self, id: &span::Id, ctx: Context<'_, S>) {
221 self.inner.on_enter(id, ctx);
222 }
223
224 #[inline]
225 fn on_exit(&self, id: &span::Id, ctx: Context<'_, S>) {
226 self.inner.on_exit(id, ctx);
227 }
228
229 #[inline]
230 fn on_close(&self, id: span::Id, ctx: Context<'_, S>) {
231 self.inner.on_close(id, ctx);
232 }
233
234 #[inline]
235 fn on_id_change(&self, old: &span::Id, new: &span::Id, ctx: Context<'_, S>) {
236 self.inner.on_id_change(old, new, ctx);
237 }
238
239 #[inline]
240 fn on_layer(&mut self, subscriber: &mut S) {
241 self.inner.on_layer(subscriber);
242 }
243}
244
245impl<S, L> RateLimitedLayer<S, L>
246where
247 S: Subscriber,
248 L: Layer<S>,
249{
250 fn create_event(
251 &self,
252 ctx: &Context<S>,
253 metadata: &'static Metadata<'static>,
254 message: String,
255 rate_limit: u64,
256 ) {
257 let fields = metadata.fields();
258
259 let message = display(message);
260
261 if let Some(message_field) = fields.field("message") {
262 let values = [(&message_field, Some(&message as &dyn Value))];
263
264 let valueset = fields.value_set(&values);
265 let event = Event::new(metadata, &valueset);
266 self.inner.on_event(&event, ctx.clone());
267 } else {
268 let values = [(
269 &fields.field(RATE_LIMIT_FIELD).unwrap(),
270 Some(&rate_limit as &dyn Value),
271 )];
272
273 let valueset = fields.value_set(&values);
274 let event = Event::new(metadata, &valueset);
275 self.inner.on_event(&event, ctx.clone());
276 }
277 }
278}
279
/// Suppression bookkeeping for a single rate-limit key.
#[derive(Debug)]
struct State {
    /// Instant at which the current window began.
    start: Instant,
    /// Events counted in the current window.
    count: u64,
    /// Window length in seconds.
    limit: u64,
    /// Message quoted in suppression notices for this key.
    message: String,
}

impl State {
    /// Opens a window starting now, with no events counted yet.
    fn new(message: String, limit: u64) -> Self {
        let start = Instant::now();
        Self {
            message,
            limit,
            count: 0,
            start,
        }
    }

    /// Opens a new window starting now. The count restarts at 1 so that the event that
    /// triggered the reset is included in the new window.
    fn reset(&mut self) {
        self.count = 1;
        self.start = Instant::now();
    }

    /// Counts one more event and returns the count as it was *before* this event.
    fn increment_count(&mut self) -> u64 {
        let next = self.count + 1;
        std::mem::replace(&mut self.count, next)
    }

    /// True while fewer than `limit` whole seconds have elapsed since the window began.
    fn should_limit(&self) -> bool {
        let elapsed_secs = self.start.elapsed().as_secs();
        elapsed_secs < self.limit
    }
}
313
/// Owned, hashable copy of a recorded field value, stored inside rate-limit keys.
#[derive(Clone, Hash, PartialEq, Eq)]
enum TraceValue {
    String(String),
    Int(i64),
    Uint(u64),
    Bool(bool),
}

impl From<bool> for TraceValue {
    fn from(flag: bool) -> Self {
        Self::Bool(flag)
    }
}

impl From<i64> for TraceValue {
    fn from(signed: i64) -> Self {
        Self::Int(signed)
    }
}

impl From<u64> for TraceValue {
    fn from(unsigned: u64) -> Self {
        Self::Uint(unsigned)
    }
}

impl From<String> for TraceValue {
    fn from(text: String) -> Self {
        Self::String(text)
    }
}
345
346#[derive(Default, Eq, PartialEq, Hash, Clone)]
350struct RateLimitedSpanKeys {
351 component_id: Option<TraceValue>,
352 vrl_position: Option<TraceValue>,
353}
354
355impl RateLimitedSpanKeys {
356 fn record(&mut self, field: &Field, value: TraceValue) {
357 match field.name() {
358 COMPONENT_ID_FIELD => self.component_id = Some(value),
359 VRL_POSITION => self.vrl_position = Some(value),
360 _ => {}
361 }
362 }
363
364 fn merge(&mut self, other: &Self) {
365 if let Some(component_id) = &other.component_id {
366 self.component_id = Some(component_id.clone());
367 }
368 if let Some(vrl_position) = &other.vrl_position {
369 self.vrl_position = Some(vrl_position.clone());
370 }
371 }
372}
373
374impl Visit for RateLimitedSpanKeys {
375 fn record_i64(&mut self, field: &Field, value: i64) {
376 self.record(field, value.into());
377 }
378
379 fn record_u64(&mut self, field: &Field, value: u64) {
380 self.record(field, value.into());
381 }
382
383 fn record_bool(&mut self, field: &Field, value: bool) {
384 self.record(field, value.into());
385 }
386
387 fn record_str(&mut self, field: &Field, value: &str) {
388 self.record(field, value.to_owned().into());
389 }
390
391 fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) {
392 self.record(field, format!("{value:?}").into());
393 }
394}
395
396#[derive(Default)]
397struct LimitVisitor {
398 pub limit: Option<bool>,
399 pub limit_secs: Option<u64>,
400}
401
402impl Visit for LimitVisitor {
403 fn record_bool(&mut self, field: &Field, value: bool) {
404 if field.name() == RATE_LIMIT_FIELD {
405 self.limit = Some(value);
406 }
407 }
408
409 fn record_i64(&mut self, field: &Field, value: i64) {
410 if field.name() == RATE_LIMIT_SECS_FIELD {
411 self.limit = Some(true); self.limit_secs = Some(u64::try_from(value).unwrap_or_default()); }
414 }
415
416 fn record_u64(&mut self, field: &Field, value: u64) {
417 if field.name() == RATE_LIMIT_SECS_FIELD {
418 self.limit = Some(true); self.limit_secs = Some(value); }
421 }
422
423 fn record_debug(&mut self, _field: &Field, _value: &dyn fmt::Debug) {}
424}
425
426#[derive(Default)]
427struct MessageVisitor {
428 pub message: Option<String>,
429}
430
431impl Visit for MessageVisitor {
432 fn record_str(&mut self, field: &Field, value: &str) {
433 if self.message.is_none() && field.name() == MESSAGE_FIELD {
434 self.message = Some(value.to_string());
435 }
436 }
437
438 fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) {
439 if self.message.is_none() && field.name() == MESSAGE_FIELD {
440 self.message = Some(format!("{value:?}"));
441 }
442 }
443}
444
#[cfg(test)]
mod test {
    use std::{
        sync::{Arc, LazyLock, Mutex},
        time::Duration,
    };

    use mock_instant::global::MockClock;
    use tracing_subscriber::layer::SubscriberExt;

    // Each test holds this lock for its whole duration, so the tests run one at a time.
    // Every test installs a subscriber via `with_default` and advances `MockClock`, whose
    // time is presumably shared across tests (it lives in `mock_instant::global`).
    static TRACING_DEFAULT_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));

    use super::*;

    /// Test layer that appends the message of every event it receives (or an empty string
    /// when the event has none) to a shared vector.
    #[derive(Default)]
    struct RecordingLayer<S> {
        events: Arc<Mutex<Vec<String>>>,

        _subscriber: std::marker::PhantomData<S>,
    }

    impl<S> RecordingLayer<S> {
        fn new(events: Arc<Mutex<Vec<String>>>) -> Self {
            RecordingLayer {
                events,

                _subscriber: std::marker::PhantomData,
            }
        }
    }

    impl<S> Layer<S> for RecordingLayer<S>
    where
        S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
    {
        // Accept every callsite and event; `RateLimitedLayer` delegates both decisions here.
        fn register_callsite(&self, _metadata: &'static Metadata<'static>) -> Interest {
            Interest::always()
        }

        fn enabled(&self, _metadata: &Metadata<'_>, _ctx: Context<'_, S>) -> bool {
            true
        }

        fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) {
            let mut visitor = MessageVisitor::default();
            event.record(&mut visitor);

            let mut events = self.events.lock().unwrap();
            events.push(visitor.message.unwrap_or_default());
        }
    }

    // 1s default window with one event every 100ms. Each 10-event window yields:
    // - the event itself;
    // - one "is being suppressed" notice;
    // - when the next window opens, a "suppressed 9 times" summary followed by the event.
    #[test]
    fn rate_limits() {
        let _guard = TRACING_DEFAULT_LOCK.lock().unwrap();

        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(1));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                info!(message = "Hello world!", internal_log_rate_limit = true);
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }

    // The layer default (100s) is overridden by the callsite's `internal_log_rate_secs = 1`,
    // so the output matches `rate_limits`.
    #[test]
    fn override_rate_limit_at_callsite() {
        let _guard = TRACING_DEFAULT_LOCK.lock().unwrap();

        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(100));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                info!(
                    message = "Hello world!",
                    internal_log_rate_limit = true,
                    internal_log_rate_secs = 1
                );
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }

    // Key fields set on an enclosing span split rate limiting. Each of the four
    // (component_id, vrl_position) combinations follows the `rate_limits` sequence
    // independently, interleaved in loop order.
    #[test]
    fn rate_limit_by_span_key() {
        let _guard = TRACING_DEFAULT_LOCK.lock().unwrap();

        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(1));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                for key in &["foo", "bar"] {
                    for line_number in &[1, 2] {
                        let span =
                            info_span!("span", component_id = &key, vrl_position = &line_number);
                        let _enter = span.enter();
                        info!(
                            message = format!("Hello {key} on line_number {line_number}!").as_str(),
                            internal_log_rate_limit = true
                        );
                    }
                }
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello foo on line_number 1!",
                "Hello foo on line_number 2!",
                "Hello bar on line_number 1!",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }

    // Same as `rate_limit_by_span_key`, but the key fields are set directly on the event
    // instead of on an enclosing span.
    #[test]
    fn rate_limit_by_event_key() {
        let _guard = TRACING_DEFAULT_LOCK.lock().unwrap();

        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(1));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                for key in &["foo", "bar"] {
                    for line_number in &[1, 2] {
                        info!(
                            message = format!("Hello {key} on line_number {line_number}!").as_str(),
                            internal_log_rate_limit = true,
                            component_id = &key,
                            vrl_position = &line_number
                        );
                    }
                }
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello foo on line_number 1!",
                "Hello foo on line_number 2!",
                "Hello bar on line_number 1!",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }
}
701}