vector/sources/host_metrics/tcp.rs

1use crate::sources::host_metrics::HostMetricsScrapeDetailError;
2use byteorder::{ByteOrder, NativeEndian};
3use std::{collections::HashMap, io, path::Path};
4use vector_lib::event::MetricTags;
5
6use netlink_packet_core::{
7    NetlinkHeader, NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST,
8};
9use netlink_packet_sock_diag::{
10    constants::*,
11    inet::{ExtensionFlags, InetRequest, InetResponseHeader, SocketId, StateFlags},
12    SockDiagMessage,
13};
14use netlink_sys::{
15    protocols::NETLINK_SOCK_DIAG, AsyncSocket, AsyncSocketExt, SocketAddr, TokioSocket,
16};
17use snafu::{ResultExt, Snafu};
18
19use super::HostMetrics;
20
/// procfs file whose presence is taken as the signal that IPv6 is available on the host.
const PROC_IPV6_FILE: &str = "/proc/net/if_inet6";

/// Gauge name: number of TCP sockets, one series per connection state.
const TCP_CONNS_TOTAL: &str = "tcp_connections_total";
/// Gauge name: sum of the `send_queue` field over all TCP sockets.
const TCP_TX_QUEUED_BYTES_TOTAL: &str = "tcp_tx_queued_bytes_total";
/// Gauge name: sum of the `recv_queue` field over all TCP sockets.
const TCP_RX_QUEUED_BYTES_TOTAL: &str = "tcp_rx_queued_bytes_total";

/// Tag key that carries the connection state on the `TCP_CONNS_TOTAL` gauge.
const STATE: &str = "state";
26
27impl HostMetrics {
28    pub async fn tcp_metrics(&self, output: &mut super::MetricsBuffer) {
29        match build_tcp_stats().await {
30            Ok(stats) => {
31                output.name = "tcp";
32                for (state, count) in stats.conn_states {
33                    let tags = metric_tags! {
34                        STATE => state
35                    };
36                    output.gauge(TCP_CONNS_TOTAL, count, tags);
37                }
38
39                output.gauge(
40                    TCP_TX_QUEUED_BYTES_TOTAL,
41                    stats.tx_queued_bytes,
42                    MetricTags::default(),
43                );
44                output.gauge(
45                    TCP_RX_QUEUED_BYTES_TOTAL,
46                    stats.rx_queued_bytes,
47                    MetricTags::default(),
48                );
49            }
50            Err(error) => {
51                emit!(HostMetricsScrapeDetailError {
52                    message: "Failed to load tcp connection info.",
53                    error,
54                });
55            }
56        }
57    }
58}
59
60#[derive(Debug, Snafu)]
61enum TcpError {
62    #[snafu(display("Could not open new netlink socket: {}", source))]
63    NetlinkSocket { source: io::Error },
64    #[snafu(display("Could not send netlink message: {}", source))]
65    NetlinkSend { source: io::Error },
66    #[snafu(display("Could not parse netlink response: {}", source))]
67    NetlinkParse {
68        source: netlink_packet_utils::DecodeError,
69    },
70    #[snafu(display("Could not recognize TCP state {state}"))]
71    InvalidTcpState { state: u8 },
72    #[snafu(display("Received an error message from netlink; code: {code}"))]
73    NetlinkMsgError { code: i32 },
74    #[snafu(display("Invalid message length: {length}"))]
75    InvalidLength { length: usize },
76}
77
78#[repr(u8)]
79enum TcpState {
80    Established = 1,
81    SynSent = 2,
82    SynRecv = 3,
83    FinWait1 = 4,
84    FinWait2 = 5,
85    TimeWait = 6,
86    Close = 7,
87    CloseWait = 8,
88    LastAck = 9,
89    Listen = 10,
90    Closing = 11,
91}
92
93impl From<TcpState> for String {
94    fn from(val: TcpState) -> Self {
95        match val {
96            TcpState::Established => "established".into(),
97            TcpState::SynSent => "syn_sent".into(),
98            TcpState::SynRecv => "syn_recv".into(),
99            TcpState::FinWait1 => "fin_wait1".into(),
100            TcpState::FinWait2 => "fin_wait2".into(),
101            TcpState::TimeWait => "time_wait".into(),
102            TcpState::Close => "close".into(),
103            TcpState::CloseWait => "close_wait".into(),
104            TcpState::LastAck => "last_ack".into(),
105            TcpState::Listen => "listen".into(),
106            TcpState::Closing => "closing".into(),
107        }
108    }
109}
110
111impl TryFrom<u8> for TcpState {
112    type Error = TcpError;
113
114    fn try_from(value: u8) -> Result<Self, Self::Error> {
115        match value {
116            1 => Ok(TcpState::Established),
117            2 => Ok(TcpState::SynSent),
118            3 => Ok(TcpState::SynRecv),
119            4 => Ok(TcpState::FinWait1),
120            5 => Ok(TcpState::FinWait2),
121            6 => Ok(TcpState::TimeWait),
122            7 => Ok(TcpState::Close),
123            8 => Ok(TcpState::CloseWait),
124            9 => Ok(TcpState::LastAck),
125            10 => Ok(TcpState::Listen),
126            11 => Ok(TcpState::Closing),
127            _ => Err(TcpError::InvalidTcpState { state: value }),
128        }
129    }
130}
131
132#[derive(Debug, Default)]
133struct TcpStats {
134    conn_states: HashMap<String, f64>,
135    rx_queued_bytes: f64,
136    tx_queued_bytes: f64,
137}
138
139/// Parses Netlink messages from a buffer, extracting [`InetResponseHeader`]s.
140///
141/// # Arguments
142/// * `buffer` - Raw byte slice containing Netlink messages.
143/// * `headers` - Mutable vector to store parsed [`InetResponseHeader`]s.
144///
145/// # Returns
146/// * `Ok(true)` if parsing is complete (Done message received).
147/// * `Ok(false)` if more data is expected. In this case, this function can be called again.
148/// * `Err(TcpError)` on invalid length, deserialization failure, or Netlink error.
149///
150/// # Errors
151/// Returns [`TcpError`] variants for invalid message lengths or Netlink errors.
152fn parse_netlink_messages(
153    buffer: &[u8],
154    headers: &mut Vec<InetResponseHeader>,
155) -> Result<bool, TcpError> {
156    let mut offset = 0;
157    let mut done = false;
158
159    while offset < buffer.len() {
160        let remaining_bytes = &buffer[offset..];
161        if remaining_bytes.len() < 4 {
162            // Still treat this as an error since we can't even read the length
163            return Err(TcpError::InvalidLength {
164                length: remaining_bytes.len(),
165            });
166        }
167        // This function panics if the buffer length is less than 4.
168        let length = NativeEndian::read_u32(&remaining_bytes[0..4]) as usize;
169        if length == 0 {
170            break;
171        }
172        if length > remaining_bytes.len() {
173            return Err(TcpError::InvalidLength { length });
174        }
175
176        let msg_bytes = &remaining_bytes[..length];
177        let rx_packet =
178            <NetlinkMessage<SockDiagMessage>>::deserialize(msg_bytes).context(NetlinkParseSnafu)?;
179
180        match rx_packet.payload {
181            NetlinkPayload::InnerMessage(SockDiagMessage::InetResponse(response)) => {
182                headers.push(response.header);
183            }
184            NetlinkPayload::Done(_) => {
185                done = true;
186                break;
187            }
188            NetlinkPayload::Error(error) => {
189                if let Some(code) = error.code {
190                    return Err(TcpError::NetlinkMsgError { code: code.get() });
191                }
192            }
193            _ => {}
194        }
195
196        offset += length;
197    }
198
199    Ok(done)
200}
201
202/// Fetches [`InetResponseHeader`]s for TCP sockets using Netlink.
203///
204/// # Arguments
205/// * `addr_family` - Address family (`AF_INET` for IPv4, `AF_INET6` for IPv6).
206///
207/// # Returns
208/// * `Ok(Vec<InetResponseHeader>)` containing headers for active TCP sockets.
209/// * `Err(TcpError)` on socket creation, send, receive, or parsing errors.
210///
211/// # Errors
212/// Returns [`TcpError`] for socket-related or message parsing failures.
213///
214/// # Notes
215/// Asynchronously queries the kernel via a Netlink socket for TCP socket info.
216async fn fetch_netlink_inet_headers(addr_family: u8) -> Result<Vec<InetResponseHeader>, TcpError> {
217    let unicast_socket: SocketAddr = SocketAddr::new(0, 0);
218    let mut socket = TokioSocket::new(NETLINK_SOCK_DIAG).context(NetlinkSocketSnafu)?;
219
220    let mut inet_req = InetRequest {
221        family: addr_family,
222        protocol: IPPROTO_TCP,
223        extensions: ExtensionFlags::INFO,
224        states: StateFlags::all(),
225        socket_id: SocketId::new_v4(),
226    };
227    if addr_family == AF_INET6 {
228        inet_req.socket_id = SocketId::new_v6();
229    }
230
231    let mut hdr = NetlinkHeader::default();
232    hdr.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP;
233    let mut msg = NetlinkMessage::new(hdr, SockDiagMessage::InetRequest(inet_req).into());
234    msg.finalize();
235
236    let mut buf = vec![0; msg.header.length as usize];
237    msg.serialize(&mut buf[..]);
238
239    socket
240        .send_to(&buf[..msg.buffer_len()], &unicast_socket)
241        .await
242        .context(NetlinkSendSnafu)?;
243
244    let mut receive_buffer = vec![0; 4096];
245    let mut inet_resp_hdrs = Vec::with_capacity(32); // Pre-allocate with an estimate
246
247    while let Ok(()) = socket.recv(&mut &mut receive_buffer[..]).await {
248        let done = parse_netlink_messages(&receive_buffer, &mut inet_resp_hdrs)?;
249        if done {
250            break;
251        }
252    }
253
254    Ok(inet_resp_hdrs)
255}
256
257fn parse_nl_inet_hdrs(
258    hdrs: Vec<InetResponseHeader>,
259    tcp_stats: &mut TcpStats,
260) -> Result<(), TcpError> {
261    for hdr in hdrs {
262        let state: TcpState = hdr.state.try_into()?;
263        let state_str: String = state.into();
264        *tcp_stats.conn_states.entry(state_str).or_insert(0.0) += 1.0;
265        tcp_stats.tx_queued_bytes += f64::from(hdr.send_queue);
266        tcp_stats.rx_queued_bytes += f64::from(hdr.recv_queue)
267    }
268
269    Ok(())
270}
271
272async fn build_tcp_stats() -> Result<TcpStats, TcpError> {
273    let mut tcp_stats = TcpStats::default();
274    let resp = fetch_netlink_inet_headers(AF_INET).await?;
275    parse_nl_inet_hdrs(resp, &mut tcp_stats)?;
276
277    if is_ipv6_enabled() {
278        let resp = fetch_netlink_inet_headers(AF_INET6).await?;
279        parse_nl_inet_hdrs(resp, &mut tcp_stats)?;
280    }
281
282    Ok(tcp_stats)
283}
284
285fn is_ipv6_enabled() -> bool {
286    Path::new(PROC_IPV6_FILE).exists()
287}
288
#[cfg(test)]
mod tests {
    use tokio::net::{TcpListener, TcpStream};

    use netlink_packet_sock_diag::{
        inet::{InetResponseHeader, SocketId},
        AF_INET,
    };

    use super::{
        fetch_netlink_inet_headers, parse_nl_inet_hdrs, TcpStats, STATE, TCP_CONNS_TOTAL,
        TCP_RX_QUEUED_BYTES_TOTAL, TCP_TX_QUEUED_BYTES_TOTAL,
    };
    use crate::sources::host_metrics::{HostMetrics, HostMetricsConfig, MetricsBuffer};
    use crate::test_util::next_addr;

    #[test]
    fn parses_nl_inet_hdrs() {
        // One socket in each of the states 1..=3 (established, syn_sent, syn_recv).
        let mut hdrs: Vec<InetResponseHeader> = Vec::new();
        for i in 1..4 {
            let hdr = InetResponseHeader {
                family: 0,
                state: i,
                timer: None,
                socket_id: SocketId::new_v4(),
                recv_queue: 3,
                send_queue: 5,
                uid: 0,
                inode: 0,
            };
            hdrs.push(hdr);
        }

        let mut tcp_stats = TcpStats::default();
        parse_nl_inet_hdrs(hdrs, &mut tcp_stats).unwrap();

        assert_eq!(tcp_stats.tx_queued_bytes, 15.0);
        assert_eq!(tcp_stats.rx_queued_bytes, 9.0);
        assert_eq!(tcp_stats.conn_states.len(), 3);
        assert_eq!(*tcp_stats.conn_states.get("established").unwrap(), 1.0);
        assert_eq!(*tcp_stats.conn_states.get("syn_sent").unwrap(), 1.0);
        assert_eq!(*tcp_stats.conn_states.get("syn_recv").unwrap(), 1.0);
    }

    // These tests are flaky and need reworking.
    // As a workaround they run serially from this single test: `generates_tcp_metrics`
    // (through `HostMetrics::tcp_metrics`) and `fetches_nl_net_hdrs` both query the kernel
    // via `fetch_netlink_inet_headers`.
    #[ignore]
    #[tokio::test]
    async fn tcp_metrics_tests() {
        fetches_nl_net_hdrs().await;
        generates_tcp_metrics().await;
    }

    async fn fetches_nl_net_hdrs() {
        // start a TCP server
        let bind_addr = next_addr();
        let listener = TcpListener::bind(bind_addr).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            // accept a connection
            let (_stream, _socket) = listener.accept().await.unwrap();
        });
        // initiate a connection
        let _stream = TcpStream::connect(addr).await.unwrap();

        let hdrs = fetch_netlink_inet_headers(AF_INET).await.unwrap();
        // there should be at least two sockets: the server's and the client's
        assert!(hdrs.len() >= 2);

        // assert that one socket has the server's port as its source port and
        // one has it as its destination port
        let mut source = false;
        let mut dst = false;
        for hdr in hdrs {
            if hdr.socket_id.source_port == addr.port() {
                source = true;
            }
            if hdr.socket_id.destination_port == addr.port() {
                dst = true;
            }
        }
        assert!(source);
        assert!(dst);
    }

    async fn generates_tcp_metrics() {
        let bind_addr = next_addr();
        let _listener = TcpListener::bind(bind_addr).await.unwrap();

        let mut buffer = MetricsBuffer::new(None);
        HostMetrics::new(HostMetricsConfig::default())
            .tcp_metrics(&mut buffer)
            .await;
        let metrics = buffer.metrics;

        assert!(!metrics.is_empty());

        let mut n_tx_queued_bytes_metric = 0;
        let mut n_rx_queued_bytes_metric = 0;

        for metric in metrics {
            if metric.name() == TCP_CONNS_TOTAL {
                let tags = metric
                    .tags()
                    .expect("Metric tcp_connections_total must have a tag");
                assert!(
                    tags.contains_key(STATE),
                    "Metric tcp_connections_total must have a state tag"
                );
            } else if metric.name() == TCP_TX_QUEUED_BYTES_TOTAL {
                n_tx_queued_bytes_metric += 1;
            } else if metric.name() == TCP_RX_QUEUED_BYTES_TOTAL {
                n_rx_queued_bytes_metric += 1;
            } else {
                panic!("unrecognized metric name: {}", metric.name());
            }
        }

        assert_eq!(n_tx_queued_bytes_metric, 1);
        assert_eq!(n_rx_queued_bytes_metric, 1);
    }
}
415}