tracing_limit/
lib.rs

1#![deny(warnings)]
2
3use std::fmt;
4
5use dashmap::DashMap;
6use tracing_core::{
7    callsite::Identifier,
8    field::{display, Field, Value, Visit},
9    span,
10    subscriber::Interest,
11    Event, Metadata, Subscriber,
12};
13use tracing_subscriber::layer::{Context, Layer};
14
15#[cfg(test)]
16#[macro_use]
17extern crate tracing;
18
19#[cfg(not(test))]
20use std::time::Instant;
21
22#[cfg(test)]
23use mock_instant::global::Instant;
24
// Name of the boolean event field that opts an event into rate limiting; `LimitVisitor` reads it.
const RATE_LIMIT_FIELD: &str = "internal_log_rate_limit";
// Name of the integer event field that both enables rate limiting and overrides the layer's
// default window (in seconds) for that event; `LimitVisitor` reads it.
const RATE_LIMIT_SECS_FIELD: &str = "internal_log_rate_secs";
// Name of the event field whose value `MessageVisitor` captures as the event's message.
const MESSAGE_FIELD: &str = "message";

// These fields will cause events to be independently rate limited by the values
// for these keys. They are collected by `RateLimitedSpanKeys` from both spans and events.
const COMPONENT_ID_FIELD: &str = "component_id";
const VRL_POSITION: &str = "vrl_position";
33
/// Key under which rate-limiting state is stored: one entry per callsite *and* per distinct set
/// of rate-limit key values, so the same callsite is limited separately for each combination.
#[derive(Eq, PartialEq, Hash, Clone)]
struct RateKeyIdentifier {
    // The callsite that produced the event.
    callsite: Identifier,
    // Values of the grouping fields (`component_id`, `vrl_position`) gathered for the event.
    rate_limit_key_values: RateLimitedSpanKeys,
}
39
/// A layer that wraps another layer and rate limits the events forwarded to it.
///
/// Only events that opt in through the `internal_log_rate_limit` or `internal_log_rate_secs`
/// fields are limited; every other event and all span notifications are passed to the wrapped
/// layer unchanged.
pub struct RateLimitedLayer<S, L>
where
    L: Layer<S> + Sized,
    S: Subscriber,
{
    // Per-key suppression state (window start, count, window length, captured message).
    events: DashMap<RateKeyIdentifier, State>,
    // The wrapped layer that receives forwarded and synthesized events.
    inner: L,
    // Default window length in seconds, used when an event does not override it.
    internal_log_rate_limit: u64,
    // Ties the otherwise unused subscriber type parameter `S` to the struct.
    _subscriber: std::marker::PhantomData<S>,
}
50
51impl<S, L> RateLimitedLayer<S, L>
52where
53    L: Layer<S> + Sized,
54    S: Subscriber,
55{
56    pub fn new(layer: L) -> Self {
57        RateLimitedLayer {
58            events: Default::default(),
59            internal_log_rate_limit: 10,
60            inner: layer,
61            _subscriber: std::marker::PhantomData,
62        }
63    }
64
65    pub fn with_default_limit(mut self, internal_log_rate_limit: u64) -> Self {
66        self.internal_log_rate_limit = internal_log_rate_limit;
67        self
68    }
69}
70
71impl<S, L> Layer<S> for RateLimitedLayer<S, L>
72where
73    L: Layer<S>,
74    S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
75{
76    #[inline]
77    fn register_callsite(&self, metadata: &'static Metadata<'static>) -> Interest {
78        self.inner.register_callsite(metadata)
79    }
80
81    #[inline]
82    fn enabled(&self, metadata: &Metadata<'_>, ctx: Context<'_, S>) -> bool {
83        self.inner.enabled(metadata, ctx)
84    }
85
86    // keep track of any span fields we use for grouping rate limiting
87    fn on_new_span(&self, attrs: &span::Attributes<'_>, id: &span::Id, ctx: Context<'_, S>) {
88        {
89            let span = ctx.span(id).expect("Span not found, this is a bug");
90            let mut extensions = span.extensions_mut();
91
92            if extensions.get_mut::<RateLimitedSpanKeys>().is_none() {
93                let mut fields = RateLimitedSpanKeys::default();
94                attrs.record(&mut fields);
95                extensions.insert(fields);
96            };
97        }
98        self.inner.on_new_span(attrs, id, ctx);
99    }
100
101    // keep track of any span fields we use for grouping rate limiting
102    fn on_record(&self, id: &span::Id, values: &span::Record<'_>, ctx: Context<'_, S>) {
103        {
104            let span = ctx.span(id).expect("Span not found, this is a bug");
105            let mut extensions = span.extensions_mut();
106
107            match extensions.get_mut::<RateLimitedSpanKeys>() {
108                Some(fields) => {
109                    values.record(fields);
110                }
111                None => {
112                    let mut fields = RateLimitedSpanKeys::default();
113                    values.record(&mut fields);
114                    extensions.insert(fields);
115                }
116            };
117        }
118        self.inner.on_record(id, values, ctx);
119    }
120
121    #[inline]
122    fn on_follows_from(&self, span: &span::Id, follows: &span::Id, ctx: Context<'_, S>) {
123        self.inner.on_follows_from(span, follows, ctx);
124    }
125
126    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
127        // Visit the event, grabbing the limit status if one is defined. If we can't find a rate limit field, or the rate limit
128        // is set as false, then we let it pass through untouched.
129        let mut limit_visitor = LimitVisitor::default();
130        event.record(&mut limit_visitor);
131
132        let limit_exists = limit_visitor.limit.unwrap_or(false);
133        if !limit_exists {
134            return self.inner.on_event(event, ctx);
135        }
136
137        let limit = match limit_visitor.limit_secs {
138            Some(limit_secs) => limit_secs, // override the cli limit
139            None => self.internal_log_rate_limit,
140        };
141
142        // Visit all of the spans in the scope of this event, looking for specific fields that we use to differentiate
143        // rate-limited events. This ensures that we don't rate limit an event's _callsite_, but the specific usage of a
144        // callsite, since multiple copies of the same component could be running, etc.
145        let rate_limit_key_values = {
146            let mut keys = RateLimitedSpanKeys::default();
147            event.record(&mut keys);
148
149            ctx.lookup_current()
150                .into_iter()
151                .flat_map(|span| span.scope().from_root())
152                .fold(keys, |mut keys, span| {
153                    let extensions = span.extensions();
154                    if let Some(span_keys) = extensions.get::<RateLimitedSpanKeys>() {
155                        keys.merge(span_keys);
156                    }
157                    keys
158                })
159        };
160
161        // Build the key to represent this event, given its span fields, and see if we're already rate limiting it. If
162        // not, we'll initialize an entry for it.
163        let metadata = event.metadata();
164        let id = RateKeyIdentifier {
165            callsite: metadata.callsite(),
166            rate_limit_key_values,
167        };
168
169        let mut state = self.events.entry(id).or_insert_with(|| {
170            let mut message_visitor = MessageVisitor::default();
171            event.record(&mut message_visitor);
172
173            let message = message_visitor
174                .message
175                .unwrap_or_else(|| metadata.name().into());
176
177            State::new(message, limit)
178        });
179
180        // Update our suppressed state for this event, and see if we should still be suppressing it.
181        //
182        // When this is the first time seeing the event, we emit it like we normally would. The second time we see it in
183        // the limit period, we emit a new event to indicate that the original event is being actively suppressed.
184        // Otherwise, we don't emit anything.
185        let previous_count = state.increment_count();
186        if state.should_limit() {
187            match previous_count {
188                0 => self.inner.on_event(event, ctx),
189                1 => {
190                    let message = format!(
191                        "Internal log [{}] is being suppressed to avoid flooding.",
192                        state.message
193                    );
194                    self.create_event(&ctx, metadata, message, state.limit);
195                }
196                _ => {}
197            }
198        } else {
199            // If we saw this event 3 or more times total, emit an event that indicates the total number of times we
200            // suppressed the event in the limit period.
201            if previous_count > 1 {
202                let message = format!(
203                    "Internal log [{}] has been suppressed {} times.",
204                    state.message,
205                    previous_count - 1
206                );
207
208                self.create_event(&ctx, metadata, message, state.limit);
209            }
210
211            // We're not suppressing anymore, so we also emit the current event as normal.. but we update our rate
212            // limiting state since this is effectively equivalent to seeing the event again for the first time.
213            self.inner.on_event(event, ctx);
214
215            state.reset();
216        }
217    }
218
219    #[inline]
220    fn on_enter(&self, id: &span::Id, ctx: Context<'_, S>) {
221        self.inner.on_enter(id, ctx);
222    }
223
224    #[inline]
225    fn on_exit(&self, id: &span::Id, ctx: Context<'_, S>) {
226        self.inner.on_exit(id, ctx);
227    }
228
229    #[inline]
230    fn on_close(&self, id: span::Id, ctx: Context<'_, S>) {
231        self.inner.on_close(id, ctx);
232    }
233
234    #[inline]
235    fn on_id_change(&self, old: &span::Id, new: &span::Id, ctx: Context<'_, S>) {
236        self.inner.on_id_change(old, new, ctx);
237    }
238
239    #[inline]
240    fn on_layer(&mut self, subscriber: &mut S) {
241        self.inner.on_layer(subscriber);
242    }
243}
244
245impl<S, L> RateLimitedLayer<S, L>
246where
247    S: Subscriber,
248    L: Layer<S>,
249{
250    fn create_event(
251        &self,
252        ctx: &Context<S>,
253        metadata: &'static Metadata<'static>,
254        message: String,
255        rate_limit: u64,
256    ) {
257        let fields = metadata.fields();
258
259        let message = display(message);
260
261        if let Some(message_field) = fields.field("message") {
262            let values = [(&message_field, Some(&message as &dyn Value))];
263
264            let valueset = fields.value_set(&values);
265            let event = Event::new(metadata, &valueset);
266            self.inner.on_event(&event, ctx.clone());
267        } else {
268            let values = [(
269                &fields.field(RATE_LIMIT_FIELD).unwrap(),
270                Some(&rate_limit as &dyn Value),
271            )];
272
273            let valueset = fields.value_set(&values);
274            let event = Event::new(metadata, &valueset);
275            self.inner.on_event(&event, ctx.clone());
276        }
277    }
278}
279
/// Suppression bookkeeping for one rate-limited key.
#[derive(Debug)]
struct State {
    // When the current window began.
    start: Instant,
    // Number of occurrences counted in the current window.
    count: u64,
    // Window length in seconds.
    limit: u64,
    // Message captured when the key was first seen; used in suppression notices.
    message: String,
}

impl State {
    /// Starts a fresh window beginning now, with nothing counted yet.
    fn new(message: String, limit: u64) -> Self {
        let start = Instant::now();
        Self {
            start,
            count: 0,
            limit,
            message,
        }
    }

    /// Restarts the window at the current instant, counting the triggering occurrence as the
    /// first one of the new window.
    fn reset(&mut self) {
        self.count = 1;
        self.start = Instant::now();
    }

    /// Counts one more occurrence and returns the count as it was *before* this call.
    fn increment_count(&mut self) -> u64 {
        let previous = self.count;
        self.count = previous + 1;
        previous
    }

    /// Returns `true` while fewer than `limit` whole seconds have elapsed since the window
    /// began.
    fn should_limit(&self) -> bool {
        let elapsed_secs = self.start.elapsed().as_secs();
        elapsed_secs < self.limit
    }
}
313
/// An owned, hashable copy of a recorded field value, used as part of a rate-limit key.
#[derive(PartialEq, Eq, Clone, Hash)]
enum TraceValue {
    String(String),
    Int(i64),
    Uint(u64),
    Bool(bool),
}

/// Wraps a boolean as [`TraceValue::Bool`].
impl From<bool> for TraceValue {
    fn from(flag: bool) -> Self {
        Self::Bool(flag)
    }
}

/// Wraps a signed integer as [`TraceValue::Int`].
impl From<i64> for TraceValue {
    fn from(signed: i64) -> Self {
        Self::Int(signed)
    }
}

/// Wraps an unsigned integer as [`TraceValue::Uint`].
impl From<u64> for TraceValue {
    fn from(unsigned: u64) -> Self {
        Self::Uint(unsigned)
    }
}

/// Wraps an owned string as [`TraceValue::String`].
impl From<String> for TraceValue {
    fn from(text: String) -> Self {
        Self::String(text)
    }
}
345
346/// RateLimitedSpanKeys records span keys that we use to rate limit callsites separately by. For
347/// example, if a given trace callsite is called from two different components, then they will be
348/// rate limited separately.
349#[derive(Default, Eq, PartialEq, Hash, Clone)]
350struct RateLimitedSpanKeys {
351    component_id: Option<TraceValue>,
352    vrl_position: Option<TraceValue>,
353}
354
355impl RateLimitedSpanKeys {
356    fn record(&mut self, field: &Field, value: TraceValue) {
357        match field.name() {
358            COMPONENT_ID_FIELD => self.component_id = Some(value),
359            VRL_POSITION => self.vrl_position = Some(value),
360            _ => {}
361        }
362    }
363
364    fn merge(&mut self, other: &Self) {
365        if let Some(component_id) = &other.component_id {
366            self.component_id = Some(component_id.clone());
367        }
368        if let Some(vrl_position) = &other.vrl_position {
369            self.vrl_position = Some(vrl_position.clone());
370        }
371    }
372}
373
374impl Visit for RateLimitedSpanKeys {
375    fn record_i64(&mut self, field: &Field, value: i64) {
376        self.record(field, value.into());
377    }
378
379    fn record_u64(&mut self, field: &Field, value: u64) {
380        self.record(field, value.into());
381    }
382
383    fn record_bool(&mut self, field: &Field, value: bool) {
384        self.record(field, value.into());
385    }
386
387    fn record_str(&mut self, field: &Field, value: &str) {
388        self.record(field, value.to_owned().into());
389    }
390
391    fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) {
392        self.record(field, format!("{value:?}").into());
393    }
394}
395
396#[derive(Default)]
397struct LimitVisitor {
398    pub limit: Option<bool>,
399    pub limit_secs: Option<u64>,
400}
401
402impl Visit for LimitVisitor {
403    fn record_bool(&mut self, field: &Field, value: bool) {
404        if field.name() == RATE_LIMIT_FIELD {
405            self.limit = Some(value);
406        }
407    }
408
409    fn record_i64(&mut self, field: &Field, value: i64) {
410        if field.name() == RATE_LIMIT_SECS_FIELD {
411            self.limit = Some(true); // limit if we have this field
412            self.limit_secs = Some(u64::try_from(value).unwrap_or_default()); // override the cli passed limit
413        }
414    }
415
416    fn record_u64(&mut self, field: &Field, value: u64) {
417        if field.name() == RATE_LIMIT_SECS_FIELD {
418            self.limit = Some(true); // limit if we have this field
419            self.limit_secs = Some(value); // override the cli passed limit
420        }
421    }
422
423    fn record_debug(&mut self, _field: &Field, _value: &dyn fmt::Debug) {}
424}
425
426#[derive(Default)]
427struct MessageVisitor {
428    pub message: Option<String>,
429}
430
431impl Visit for MessageVisitor {
432    fn record_str(&mut self, field: &Field, value: &str) {
433        if self.message.is_none() && field.name() == MESSAGE_FIELD {
434            self.message = Some(value.to_string());
435        }
436    }
437
438    fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) {
439        if self.message.is_none() && field.name() == MESSAGE_FIELD {
440            self.message = Some(format!("{value:?}"));
441        }
442    }
443}
444
#[cfg(test)]
mod test {
    use std::{
        sync::{Arc, Mutex, MutexGuard, PoisonError},
        time::Duration,
    };

    use mock_instant::global::MockClock;
    use tracing_subscriber::layer::SubscriberExt;

    use super::*;

    // Every test below advances `mock_instant::global::MockClock`. Per the mock_instant
    // documentation that clock is shared by all threads, so tests running in parallel would
    // move each other's time forward and make the exact suppression counts unreliable. Each
    // test therefore holds this lock for its whole duration.
    static CLOCK_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires exclusive use of the mocked clock for the calling test.
    ///
    /// A poisoned lock (left behind by another test that panicked) is still usable: the lock
    /// protects no data, so the guard is recovered instead of cascading the failure.
    fn lock_clock() -> MutexGuard<'static, ()> {
        CLOCK_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Test layer that appends the message of every event it receives to a shared list.
    #[derive(Default)]
    struct RecordingLayer<S> {
        events: Arc<Mutex<Vec<String>>>,

        _subscriber: std::marker::PhantomData<S>,
    }

    impl<S> RecordingLayer<S> {
        /// Creates a recorder that pushes into `events`.
        fn new(events: Arc<Mutex<Vec<String>>>) -> Self {
            RecordingLayer {
                events,

                _subscriber: std::marker::PhantomData,
            }
        }
    }

    impl<S> Layer<S> for RecordingLayer<S>
    where
        S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
    {
        fn register_callsite(&self, _metadata: &'static Metadata<'static>) -> Interest {
            Interest::always()
        }

        fn enabled(&self, _metadata: &Metadata<'_>, _ctx: Context<'_, S>) -> bool {
            true
        }

        // Events without a message are recorded as an empty string.
        fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) {
            let mut visitor = MessageVisitor::default();
            event.record(&mut visitor);

            let mut events = self.events.lock().unwrap();
            events.push(visitor.message.unwrap_or_default());
        }
    }

    /// One callsite firing every 100ms under a 1s default window.
    #[test]
    fn rate_limits() {
        let _clock = lock_clock();

        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(1));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                info!(message = "Hello world!", internal_log_rate_limit = true);
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }

    /// `internal_log_rate_secs` on the event overrides the layer's (much larger) default.
    #[test]
    fn override_rate_limit_at_callsite() {
        let _clock = lock_clock();

        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(100));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                info!(
                    message = "Hello world!",
                    internal_log_rate_limit = true,
                    internal_log_rate_secs = 1
                );
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
                "Internal log [Hello world!] is being suppressed to avoid flooding.",
                "Internal log [Hello world!] has been suppressed 9 times.",
                "Hello world!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }

    /// The same callsite is limited independently per `component_id` / `vrl_position` taken
    /// from the enclosing span.
    #[test]
    fn rate_limit_by_span_key() {
        let _clock = lock_clock();

        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(1));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                for key in &["foo", "bar"] {
                    for line_number in &[1, 2] {
                        let span =
                            info_span!("span", component_id = &key, vrl_position = &line_number);
                        let _enter = span.enter();
                        info!(
                            message = format!("Hello {key} on line_number {line_number}!").as_str(),
                            internal_log_rate_limit = true
                        );
                    }
                }
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello foo on line_number 1!",
                "Hello foo on line_number 2!",
                "Hello bar on line_number 1!",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }

    /// The same callsite is limited independently per `component_id` / `vrl_position` given
    /// directly on the event.
    #[test]
    fn rate_limit_by_event_key() {
        let _clock = lock_clock();

        let events: Arc<Mutex<Vec<String>>> = Default::default();

        let recorder = RecordingLayer::new(Arc::clone(&events));
        let sub = tracing_subscriber::registry::Registry::default()
            .with(RateLimitedLayer::new(recorder).with_default_limit(1));
        tracing::subscriber::with_default(sub, || {
            for _ in 0..21 {
                for key in &["foo", "bar"] {
                    for line_number in &[1, 2] {
                        info!(
                            message = format!("Hello {key} on line_number {line_number}!").as_str(),
                            internal_log_rate_limit = true,
                            component_id = &key,
                            vrl_position = &line_number
                        );
                    }
                }
                MockClock::advance(Duration::from_millis(100));
            }
        });

        let events = events.lock().unwrap();

        assert_eq!(
            *events,
            vec![
                "Hello foo on line_number 1!",
                "Hello foo on line_number 2!",
                "Hello bar on line_number 1!",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
                "Internal log [Hello foo on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 1!] is being suppressed to avoid flooding.",
                "Internal log [Hello bar on line_number 2!] is being suppressed to avoid flooding.",
                "Internal log [Hello foo on line_number 1!] has been suppressed 9 times.",
                "Hello foo on line_number 1!",
                "Internal log [Hello foo on line_number 2!] has been suppressed 9 times.",
                "Hello foo on line_number 2!",
                "Internal log [Hello bar on line_number 1!] has been suppressed 9 times.",
                "Hello bar on line_number 1!",
                "Internal log [Hello bar on line_number 2!] has been suppressed 9 times.",
                "Hello bar on line_number 2!",
            ]
            .into_iter()
            .map(std::borrow::ToOwned::to_owned)
            .collect::<Vec<String>>()
        );
    }
}