vector/internal_telemetry/allocations/allocator/
tracing.rs

1use std::{any::TypeId, marker::PhantomData, ptr::addr_of};
2
3use tracing::{Dispatch, Id, Subscriber};
4use tracing_subscriber::{layer::Context, registry::LookupSpan, Layer};
5
6use super::AllocationGroupToken;
7
8pub(crate) struct WithAllocationGroup {
9    pub with_allocation_group: fn(&Dispatch, &Id, AllocationGroupToken),
10}
11
12/// [`AllocationLayer`] is a [`tracing_subscriber::Layer`] that handles entering and exiting an allocation
13/// group as the span it is attached to is itself entered and exited.
14///
15/// More information on using this layer can be found in the examples, or directly in the
16/// `tracing_subscriber` docs, found [here][tracing_subscriber::layer].
17#[cfg_attr(docsrs, doc(cfg(feature = "tracing-compat")))]
18pub struct AllocationLayer<S> {
19    ctx: WithAllocationGroup,
20    _subscriber: PhantomData<S>,
21}
22
23impl<S> AllocationLayer<S>
24where
25    S: Subscriber + for<'span> LookupSpan<'span>,
26{
27    /// Creates a new [`AllocationLayer`].
28    #[must_use]
29    pub fn new() -> Self {
30        let ctx = WithAllocationGroup {
31            with_allocation_group: Self::with_allocation_group,
32        };
33
34        Self {
35            ctx,
36            _subscriber: PhantomData,
37        }
38    }
39
40    fn with_allocation_group(dispatch: &Dispatch, id: &Id, unsafe_token: AllocationGroupToken) {
41        let subscriber = dispatch
42            .downcast_ref::<S>()
43            .expect("subscriber should downcast to expected type; this is a bug!");
44        let span = subscriber
45            .span(id)
46            .expect("registry should have a span for the current ID");
47
48        span.extensions_mut().insert(unsafe_token);
49    }
50}
51
52impl<S> Layer<S> for AllocationLayer<S>
53where
54    S: Subscriber + for<'a> LookupSpan<'a>,
55{
56    fn on_enter(&self, id: &Id, ctx: Context<'_, S>) {
57        if let Some(span_ref) = ctx.span(id) {
58            if let Some(token) = span_ref.extensions().get::<AllocationGroupToken>() {
59                token.enter();
60            }
61        }
62    }
63
64    fn on_exit(&self, id: &Id, ctx: Context<'_, S>) {
65        if let Some(span_ref) = ctx.span(id) {
66            if let Some(token) = span_ref.extensions().get::<AllocationGroupToken>() {
67                token.exit();
68            }
69        }
70    }
71
72    unsafe fn downcast_raw(&self, id: TypeId) -> Option<*const ()> {
73        match id {
74            id if id == TypeId::of::<Self>() => Some(addr_of!(self).cast::<()>()),
75            id if id == TypeId::of::<WithAllocationGroup>() => {
76                Some(addr_of!(self.ctx).cast::<()>())
77            }
78            _ => None,
79        }
80    }
81}
82
83impl<S> Default for AllocationLayer<S>
84where
85    S: Subscriber + for<'span> LookupSpan<'span>,
86{
87    fn default() -> Self {
88        AllocationLayer::new()
89    }
90}