vector/sources/
postgresql_metrics.rs

1use std::{
2    collections::HashSet,
3    fmt::Write as _,
4    iter,
5    path::PathBuf,
6    time::{Duration, Instant},
7};
8
9use chrono::{DateTime, Utc};
10use futures::{
11    FutureExt, StreamExt,
12    future::{join_all, try_join_all},
13};
14use openssl::{
15    error::ErrorStack,
16    ssl::{SslConnector, SslMethod},
17};
18use postgres_openssl::MakeTlsConnector;
19use serde_with::serde_as;
20use snafu::{ResultExt, Snafu};
21use tokio::time;
22use tokio_postgres::{
23    Client, Config, Error as PgError, NoTls, Row,
24    config::{ChannelBinding, Host, SslMode, TargetSessionAttrs},
25    types::FromSql,
26};
27use tokio_stream::wrappers::IntervalStream;
28use vector_lib::{
29    ByteSizeOf, EstimatedJsonEncodedSizeOf,
30    config::LogNamespace,
31    configurable::configurable_component,
32    internal_event::{CountByteSize, InternalEventHandle as _, Registered},
33    json_size::JsonSize,
34    metric_tags,
35};
36
37use crate::{
38    config::{SourceConfig, SourceContext, SourceOutput},
39    event::metric::{Metric, MetricKind, MetricTags, MetricValue},
40    internal_events::{
41        CollectionCompleted, EndpointBytesReceived, EventsReceived, PostgresqlMetricsCollectError,
42        StreamClosedError,
43    },
44};
45
// Clones a base `MetricTags` set, optionally overriding/adding key => value
// pairs. Used so every metric carries the source-level tags (`endpoint`,
// `host`) plus per-row tags such as `db`.
macro_rules! tags {
    // No extra pairs: just clone the base set.
    ($tags:expr_2021) => { $tags.clone() };
    // With pairs: clone, then `replace` each key (insert-or-overwrite).
    ($tags:expr_2021, $($key:expr_2021 => $value:expr_2021),*) => {
        {
            let mut tags = $tags.clone();
            $(
                tags.replace($key.into(), String::from($value));
            )*
            tags
        }
    };
}
58
// Builds a `MetricValue::Counter` from any numeric expression, casting to `f64`.
macro_rules! counter {
    ($value:expr_2021) => {
        MetricValue::Counter {
            value: $value as f64,
        }
    };
}
66
// Builds a `MetricValue::Gauge` from any numeric expression, casting to `f64`.
macro_rules! gauge {
    ($value:expr_2021) => {
        MetricValue::Gauge {
            value: $value as f64,
        }
    };
}
74
/// Errors raised while building the source from its configuration
/// (endpoint parsing and host validation).
#[derive(Debug, Snafu)]
enum BuildError {
    #[snafu(display("invalid endpoint: {}", source))]
    InvalidEndpoint { source: PgError },
    #[snafu(display("host missing"))]
    HostMissing,
    /// The connection string listed more than one host; only a single
    /// host per endpoint is supported.
    #[snafu(display("multiple hosts not supported: {:?}", hosts))]
    MultipleHostsNotSupported { hosts: Vec<Host> },
}
84
/// Errors raised while establishing (or validating) a PostgreSQL connection.
#[derive(Debug, Snafu)]
enum ConnectError {
    #[snafu(display("failed to create tls connector: {}", source))]
    TlsFailed { source: ErrorStack },
    #[snafu(display("failed to connect ({}): {}", endpoint, source))]
    ConnectionFailed { source: PgError, endpoint: String },
    #[snafu(display("failed to get PostgreSQL version ({}): {}", endpoint, source))]
    SelectVersionFailed { source: PgError, endpoint: String },
    /// Server version is below the minimum supported (checked in `build_client`).
    #[snafu(display("version ({}) is not supported", version))]
    InvalidVersion { version: String },
}
96
/// Errors raised while collecting metrics over an established connection.
#[derive(Debug, Snafu)]
enum CollectError {
    /// A row column could not be decoded into the expected Rust type.
    #[snafu(display("failed to get value by key: {} (reason: {})", key, source))]
    PostgresGetValue { source: PgError, key: &'static str },
    #[snafu(display("query failed: {}", source))]
    QueryError { source: PgError },
}
104
/// Configuration of TLS when connecting to PostgreSQL.
#[configurable_component]
#[derive(Clone, Debug)]
#[serde(deny_unknown_fields)]
struct PostgresqlMetricsTlsConfig {
    // Only a CA override is supported here; client certificates/keys are not
    // configurable (see `PostgresqlClient::build_client`, which only calls
    // `set_ca_file`).
    /// Absolute path to an additional CA certificate file.
    ///
    /// The certificate must be in the DER or PEM (X.509) format.
    #[configurable(metadata(docs::examples = "certs/ca.pem"))]
    ca_file: PathBuf,
}
116
/// Configuration for the `postgresql_metrics` source.
#[serde_as]
#[configurable_component(source(
    "postgresql_metrics",
    "Collect metrics from the PostgreSQL database."
))]
#[derive(Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct PostgresqlMetricsConfig {
    // NOTE: plain `//` comments are used for reviewer notes below so the
    // user-facing docs generated from `///` comments stay unchanged.

    /// A list of PostgreSQL instances to scrape.
    ///
    /// Each endpoint must be in the [Connection URI
    /// format](https://www.postgresql.org/docs/current/libpq-connect.html#id-1.7.3.8.3.6).
    #[configurable(metadata(
        docs::examples = "postgresql://postgres:vector@localhost:5432/postgres"
    ))]
    endpoints: Vec<String>,

    /// A list of databases to match (by using [POSIX Regular
    /// Expressions](https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP)) against
    /// the `datname` column for which you want to collect metrics from.
    ///
    /// If not set, metrics are collected from all databases. Specifying `""` includes metrics where `datname` is
    /// `NULL`.
    ///
    /// This can be used in conjunction with `exclude_databases`.
    #[configurable(metadata(
        docs::examples = "^postgres$",
        docs::examples = "^vector$",
        docs::examples = "^foo",
    ))]
    include_databases: Option<Vec<String>>,

    /// A list of databases to match (by using [POSIX Regular
    /// Expressions](https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP)) against
    /// the `datname` column for which you don’t want to collect metrics from.
    ///
    /// Specifying `""` includes metrics where `datname` is `NULL`.
    ///
    /// This can be used in conjunction with `include_databases`.
    #[configurable(metadata(docs::examples = "^postgres$", docs::examples = "^template.*",))]
    exclude_databases: Option<Vec<String>>,

    /// The interval between scrapes.
    #[serde(default = "default_scrape_interval_secs")]
    #[serde_as(as = "serde_with::DurationSeconds<u64>")]
    #[configurable(metadata(docs::human_name = "Scrape Interval"))]
    scrape_interval_secs: Duration,

    /// Overrides the default namespace for the metrics emitted by the source.
    // An empty string disables the namespace entirely (see `build`, which
    // filters out empty values).
    #[serde(default = "default_namespace")]
    namespace: String,

    // When unset, connections are made without TLS (`NoTls`).
    #[configurable(derived)]
    tls: Option<PostgresqlMetricsTlsConfig>,
}
173
174impl Default for PostgresqlMetricsConfig {
175    fn default() -> Self {
176        Self {
177            endpoints: vec![],
178            include_databases: None,
179            exclude_databases: None,
180            scrape_interval_secs: Duration::from_secs(15),
181            namespace: "postgresql".to_owned(),
182            tls: None,
183        }
184    }
185}
186
187impl_generate_config_from_default!(PostgresqlMetricsConfig);
188
/// Default scrape interval (15 seconds), used by serde when
/// `scrape_interval_secs` is omitted from the configuration.
pub const fn default_scrape_interval_secs() -> Duration {
    const DEFAULT_SCRAPE_INTERVAL: Duration = Duration::new(15, 0);
    DEFAULT_SCRAPE_INTERVAL
}
192
/// Default metric namespace ("postgresql"), used by serde when `namespace`
/// is omitted from the configuration.
pub fn default_namespace() -> String {
    String::from("postgresql")
}
196
#[async_trait::async_trait]
#[typetag::serde(name = "postgresql_metrics")]
impl SourceConfig for PostgresqlMetricsConfig {
    /// Builds the source: validates every endpoint up front (failing the
    /// whole build if any endpoint is invalid), then returns a task that
    /// scrapes all endpoints on each interval tick until shutdown.
    async fn build(&self, mut cx: SourceContext) -> crate::Result<super::Source> {
        let datname_filter = DatnameFilter::new(
            self.include_databases.clone().unwrap_or_default(),
            self.exclude_databases.clone().unwrap_or_default(),
        );
        // An empty namespace string means "no namespace".
        let namespace = Some(self.namespace.clone()).filter(|namespace| !namespace.is_empty());

        // `try_join_all` short-circuits: one bad endpoint fails the build.
        let mut sources = try_join_all(self.endpoints.iter().map(|endpoint| {
            PostgresqlMetrics::new(
                endpoint.clone(),
                datname_filter.clone(),
                namespace.clone(),
                self.tls.clone(),
            )
        }))
        .await?;

        let duration = self.scrape_interval_secs;
        let shutdown = cx.shutdown;
        Ok(Box::pin(async move {
            // Tick on the scrape interval; the stream ends when shutdown fires.
            let mut interval = IntervalStream::new(time::interval(duration)).take_until(shutdown);
            while interval.next().await.is_some() {
                let start = Instant::now();
                // Scrape all endpoints concurrently; per-endpoint errors are
                // handled inside `collect` (which emits an "up" = 0 metric).
                let metrics = join_all(sources.iter_mut().map(|source| source.collect())).await;
                emit!(CollectionCompleted {
                    start,
                    end: Instant::now()
                });

                let metrics: Vec<Metric> = metrics.into_iter().flatten().collect();
                let count = metrics.len();

                // A send error means the downstream closed; stop the source.
                if (cx.out.send_batch(metrics).await).is_err() {
                    emit!(StreamClosedError { count });
                    return Err(());
                }
            }

            Ok(())
        }))
    }

    fn outputs(&self, _global_log_namespace: LogNamespace) -> Vec<SourceOutput> {
        vec![SourceOutput::new_metrics()]
    }

    /// Scrape sources cannot acknowledge delivery end-to-end.
    fn can_acknowledge(&self) -> bool {
        false
    }
}
250
/// Lazily-connected PostgreSQL client wrapper: the connection is built on
/// first use and cached together with the server version.
#[derive(Debug)]
struct PostgresqlClient {
    config: Config,
    tls_config: Option<PostgresqlMetricsTlsConfig>,
    /// Cached `(client, server_version_num)` pair; `None` until the first
    /// successful connect, or after `take()` without a matching `set()`.
    client: Option<(Client, usize)>,
    /// Sanitized endpoint string, used in error/log messages.
    endpoint: String,
}
258
259impl PostgresqlClient {
260    fn new(config: Config, tls_config: Option<PostgresqlMetricsTlsConfig>) -> Self {
261        let endpoint = config_to_endpoint(&config);
262        Self {
263            config,
264            tls_config,
265            client: None,
266            endpoint,
267        }
268    }
269
270    async fn take(&mut self) -> Result<(Client, usize), ConnectError> {
271        match self.client.take() {
272            Some((client, version)) => Ok((client, version)),
273            None => self.build_client().await,
274        }
275    }
276
277    fn set(&mut self, value: (Client, usize)) {
278        self.client.replace(value);
279    }
280
281    async fn build_client(&self) -> Result<(Client, usize), ConnectError> {
282        // Create postgresql client
283        let client = match &self.tls_config {
284            Some(tls_config) => {
285                let mut builder =
286                    SslConnector::builder(SslMethod::tls_client()).context(TlsFailedSnafu)?;
287                builder
288                    .set_ca_file(tls_config.ca_file.clone())
289                    .context(TlsFailedSnafu)?;
290                let connector = MakeTlsConnector::new(builder.build());
291
292                let (client, connection) =
293                    self.config.connect(connector).await.with_context(|_| {
294                        ConnectionFailedSnafu {
295                            endpoint: &self.endpoint,
296                        }
297                    })?;
298                tokio::spawn(connection);
299                client
300            }
301            None => {
302                let (client, connection) =
303                    self.config
304                        .connect(NoTls)
305                        .await
306                        .with_context(|_| ConnectionFailedSnafu {
307                            endpoint: &self.endpoint,
308                        })?;
309                tokio::spawn(connection);
310                client
311            }
312        };
313
314        // Log version if required
315        if tracing::level_enabled!(tracing::Level::DEBUG) {
316            let version_row = client
317                .query_one("SELECT version()", &[])
318                .await
319                .with_context(|_| SelectVersionFailedSnafu {
320                    endpoint: &self.endpoint,
321                })?;
322            let version = version_row
323                .try_get::<&str, &str>("version")
324                .with_context(|_| SelectVersionFailedSnafu {
325                    endpoint: &self.endpoint,
326                })?;
327            debug!(message = "Connected to server.", endpoint = %self.endpoint, server_version = %version);
328        }
329
330        // Get server version and check that we support it
331        let row = client
332            .query_one("SHOW server_version_num", &[])
333            .await
334            .with_context(|_| SelectVersionFailedSnafu {
335                endpoint: &self.endpoint,
336            })?;
337
338        let version = row
339            .try_get::<&str, &str>("server_version_num")
340            .with_context(|_| SelectVersionFailedSnafu {
341                endpoint: &self.endpoint,
342            })?;
343
344        let version = match version.parse::<usize>() {
345            Ok(version) if version >= 90600 => version,
346            Ok(_) | Err(_) => {
347                return Err(ConnectError::InvalidVersion {
348                    version: version.to_string(),
349                });
350            }
351        };
352
353        //
354        Ok((client, version))
355    }
356}
357
/// Pre-built SQL statements (plus their bind parameters) for querying
/// `pg_stat_database` and `pg_stat_database_conflicts`, with WHERE clauses
/// derived from the configured include/exclude database regexes.
#[derive(Debug, Clone)]
struct DatnameFilter {
    pg_stat_database_sql: String,
    pg_stat_database_conflicts_sql: String,
    /// Regex strings bound as `$1..$n` in both statements.
    match_params: Vec<String>,
}
364
365impl DatnameFilter {
366    fn new(include: Vec<String>, exclude: Vec<String>) -> Self {
367        let (include_databases, include_null) = Self::clean_databases(include);
368        let (exclude_databases, exclude_null) = Self::clean_databases(exclude);
369        let (match_sql, match_params) =
370            Self::build_match_params(include_databases, exclude_databases);
371
372        let mut pg_stat_database_sql = "SELECT * FROM pg_stat_database".to_owned();
373        if !match_sql.is_empty() {
374            pg_stat_database_sql += " WHERE";
375            pg_stat_database_sql += &match_sql;
376        }
377        match (include_null, exclude_null) {
378            // Nothing
379            (false, false) => {}
380            // Include tracking objects not in database
381            (true, false) => {
382                pg_stat_database_sql += if match_sql.is_empty() {
383                    " WHERE"
384                } else {
385                    " OR"
386                };
387                pg_stat_database_sql += " datname IS NULL";
388            }
389            // Exclude tracking objects not in database, precedence over include
390            (false, true) | (true, true) => {
391                pg_stat_database_sql += if match_sql.is_empty() {
392                    " WHERE"
393                } else {
394                    " AND"
395                };
396                pg_stat_database_sql += " datname IS NOT NULL";
397            }
398        }
399
400        let mut pg_stat_database_conflicts_sql =
401            "SELECT * FROM pg_stat_database_conflicts".to_owned();
402        if !match_sql.is_empty() {
403            pg_stat_database_conflicts_sql += " WHERE";
404            pg_stat_database_conflicts_sql += &match_sql;
405        }
406
407        Self {
408            pg_stat_database_sql,
409            pg_stat_database_conflicts_sql,
410            match_params,
411        }
412    }
413
414    fn clean_databases(names: Vec<String>) -> (Vec<String>, bool) {
415        let mut set = names.into_iter().collect::<HashSet<_>>();
416        let null = set.remove("");
417        (set.into_iter().collect(), null)
418    }
419
420    fn build_match_params(include: Vec<String>, exclude: Vec<String>) -> (String, Vec<String>) {
421        let mut query = String::new();
422        let mut params = vec![];
423
424        if !include.is_empty() {
425            query.push_str(" (");
426            for (i, name) in include.into_iter().enumerate() {
427                params.push(name);
428                if i > 0 {
429                    query.push_str(" OR");
430                }
431                write!(query, " datname ~ ${}", params.len()).expect("write to String never fails");
432            }
433            query.push(')');
434        }
435
436        if !exclude.is_empty() {
437            if !query.is_empty() {
438                query.push_str(" AND");
439            }
440
441            query.push_str(" NOT (");
442            for (i, name) in exclude.into_iter().enumerate() {
443                params.push(name);
444                if i > 0 {
445                    query.push_str(" OR");
446                }
447                write!(query, " datname ~ ${}", params.len()).expect("write to String never fails");
448            }
449            query.push(')');
450        }
451
452        (query, params)
453    }
454
455    fn get_match_params(&self) -> Vec<&(dyn tokio_postgres::types::ToSql + Sync)> {
456        let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
457            Vec::with_capacity(self.match_params.len());
458        for item in self.match_params.iter() {
459            params.push(item);
460        }
461        params
462    }
463
464    async fn pg_stat_database(&self, client: &Client) -> Result<Vec<Row>, PgError> {
465        client
466            .query(
467                self.pg_stat_database_sql.as_str(),
468                self.get_match_params().as_slice(),
469            )
470            .await
471    }
472
473    async fn pg_stat_database_conflicts(&self, client: &Client) -> Result<Vec<Row>, PgError> {
474        client
475            .query(
476                self.pg_stat_database_conflicts_sql.as_str(),
477                self.get_match_params().as_slice(),
478            )
479            .await
480    }
481
482    async fn pg_stat_bgwriter(&self, client: &Client) -> Result<Row, PgError> {
483        client
484            .query_one("SELECT * FROM pg_stat_bgwriter", &[])
485            .await
486    }
487}
488
/// Per-endpoint scraper: owns the lazily-connected client, the query filter,
/// and the static tags (`endpoint`, `host`) attached to every emitted metric.
struct PostgresqlMetrics {
    client: PostgresqlClient,
    /// Sanitized endpoint string (credentials dropped), used for telemetry.
    endpoint: String,
    /// Metric namespace; `None` when the configured namespace was empty.
    namespace: Option<String>,
    tags: MetricTags,
    datname_filter: DatnameFilter,
    events_received: Registered<EventsReceived>,
}
497
impl PostgresqlMetrics {
    /// Parses and sanitizes the endpoint (credentials are dropped by
    /// `config_to_endpoint`), validates that exactly one host is configured,
    /// and prepares the static `endpoint`/`host` tags. Does not connect.
    async fn new(
        endpoint: String,
        datname_filter: DatnameFilter,
        namespace: Option<String>,
        tls_config: Option<PostgresqlMetricsTlsConfig>,
    ) -> Result<Self, BuildError> {
        // Takes the raw endpoint, parses it into a configuration, and then we set `endpoint` back to a sanitized
        // version of the original value, dropping things like username/password, etc.
        let config: Config = endpoint.parse().context(InvalidEndpointSnafu)?;
        let endpoint = config_to_endpoint(&config);

        let hosts = config.get_hosts();
        let host = match hosts.len() {
            0 => return Err(BuildError::HostMissing),
            1 => match &hosts[0] {
                Host::Tcp(host) => host.clone(),
                // Unix-socket hosts are identified by their path.
                #[cfg(unix)]
                Host::Unix(path) => path.to_string_lossy().to_string(),
            },
            _ => {
                return Err(BuildError::MultipleHostsNotSupported {
                    hosts: config.get_hosts().to_owned(),
                });
            }
        };

        let tags = metric_tags!(
            "endpoint" => endpoint.clone(),
            "host" => host,
        );

        Ok(Self {
            client: PostgresqlClient::new(config, tls_config),
            endpoint,
            namespace,
            tags,
            datname_filter,
            events_received: register!(EventsReceived),
        })
    }

    /// Runs one scrape. On success yields an `up` gauge of 1 followed by all
    /// collected metrics; on failure emits a collect-error event and yields
    /// only an `up` gauge of 0, so the endpoint's health is always reported.
    async fn collect(&mut self) -> Box<dyn Iterator<Item = Metric> + Send> {
        match self.collect_metrics().await {
            Ok(metrics) => Box::new(
                iter::once(self.create_metric("up", gauge!(1.0), tags!(self.tags))).chain(metrics),
            ),
            Err(error) => {
                emit!(PostgresqlMetricsCollectError {
                    error,
                    endpoint: &self.endpoint,
                });
                Box::new(iter::once(self.create_metric(
                    "up",
                    gauge!(0.0),
                    tags!(self.tags),
                )))
            }
        }
    }

    /// Takes (or builds) the connection, runs the three collectors
    /// concurrently, emits the received-bytes/events telemetry, and returns
    /// the connection to the cache. On any error the connection is NOT
    /// returned, forcing a reconnect on the next scrape.
    async fn collect_metrics(&mut self) -> Result<impl Iterator<Item = Metric> + use<>, String> {
        let (client, client_version) = self
            .client
            .take()
            .await
            .map_err(|error| error.to_string())?;

        match try_join_all(vec![
            self.collect_pg_stat_database(&client, client_version)
                .boxed(),
            self.collect_pg_stat_database_conflicts(&client).boxed(),
            self.collect_pg_stat_bgwriter(&client).boxed(),
        ])
        .await
        {
            Ok(result) => {
                // Sum metric count, estimated JSON size, and raw byte size
                // across the three collector results.
                let (count, json_byte_size, received_byte_size) =
                    result
                        .iter()
                        .fold((0, JsonSize::zero(), 0), |res, (set, size)| {
                            (
                                res.0 + set.len(),
                                res.1 + set.estimated_json_encoded_size_of(),
                                res.2 + size,
                            )
                        });
                emit!(EndpointBytesReceived {
                    byte_size: received_byte_size,
                    protocol: "tcp",
                    endpoint: &self.endpoint,
                });
                self.events_received
                    .emit(CountByteSize(count, json_byte_size));
                // Hand the healthy connection back for reuse.
                self.client.set((client, client_version));
                Ok(result.into_iter().flat_map(|(metrics, _)| metrics))
            }
            Err(error) => Err(error.to_string()),
        }
    }

    /// Collects per-database metrics from `pg_stat_database`.
    ///
    /// Returns the metrics plus the total byte size read (via `RowReader`).
    /// A NULL `datname` (shared objects) is reported with an empty `db` tag.
    async fn collect_pg_stat_database(
        &self,
        client: &Client,
        client_version: usize,
    ) -> Result<(Vec<Metric>, usize), CollectError> {
        let rows = self
            .datname_filter
            .pg_stat_database(client)
            .await
            .context(QuerySnafu)?;

        let mut metrics = Vec::with_capacity(20 * rows.len());
        let mut reader = RowReader::default();
        for row in rows.iter() {
            let db = reader.read::<Option<&str>>(row, "datname")?.unwrap_or("");

            metrics.extend_from_slice(&[
                self.create_metric(
                    "pg_stat_database_datid",
                    gauge!(reader.read::<u32>(row, "datid")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_numbackends",
                    gauge!(reader.read::<i32>(row, "numbackends")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_xact_commit_total",
                    counter!(reader.read::<i64>(row, "xact_commit")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_xact_rollback_total",
                    counter!(reader.read::<i64>(row, "xact_rollback")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_blks_read_total",
                    counter!(reader.read::<i64>(row, "blks_read")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_blks_hit_total",
                    counter!(reader.read::<i64>(row, "blks_hit")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_tup_returned_total",
                    counter!(reader.read::<i64>(row, "tup_returned")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_tup_fetched_total",
                    counter!(reader.read::<i64>(row, "tup_fetched")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_tup_inserted_total",
                    counter!(reader.read::<i64>(row, "tup_inserted")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_tup_updated_total",
                    counter!(reader.read::<i64>(row, "tup_updated")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_tup_deleted_total",
                    counter!(reader.read::<i64>(row, "tup_deleted")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_conflicts_total",
                    counter!(reader.read::<i64>(row, "conflicts")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_temp_files_total",
                    counter!(reader.read::<i64>(row, "temp_files")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_temp_bytes_total",
                    counter!(reader.read::<i64>(row, "temp_bytes")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_deadlocks_total",
                    counter!(reader.read::<i64>(row, "deadlocks")?),
                    tags!(self.tags, "db" => db),
                ),
            ]);
            // Checksum columns only exist on PostgreSQL 12+ (120000).
            if client_version >= 120000 {
                metrics.extend_from_slice(&[
                    self.create_metric(
                        "pg_stat_database_checksum_failures_total",
                        counter!(
                            reader
                                .read::<Option<i64>>(row, "checksum_failures")?
                                .unwrap_or(0)
                        ),
                        tags!(self.tags, "db" => db),
                    ),
                    self.create_metric(
                        "pg_stat_database_checksum_last_failure",
                        // NULL timestamp (no failure yet) is reported as 0.
                        gauge!(
                            reader
                                .read::<Option<DateTime<Utc>>>(row, "checksum_last_failure")?
                                .map(|t| t.timestamp())
                                .unwrap_or(0)
                        ),
                        tags!(self.tags, "db" => db),
                    ),
                ]);
            }
            metrics.extend_from_slice(&[
                self.create_metric(
                    "pg_stat_database_blk_read_time_seconds_total",
                    // Server reports milliseconds; convert to seconds.
                    counter!(reader.read::<f64>(row, "blk_read_time")? / 1000f64),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_blk_write_time_seconds_total",
                    counter!(reader.read::<f64>(row, "blk_write_time")? / 1000f64),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_stats_reset",
                    gauge!(
                        reader
                            .read::<Option<DateTime<Utc>>>(row, "stats_reset")?
                            .map(|t| t.timestamp())
                            .unwrap_or(0)
                    ),
                    tags!(self.tags, "db" => db),
                ),
            ]);
        }
        Ok((metrics, reader.into_inner()))
    }

    /// Collects per-database recovery-conflict counters from
    /// `pg_stat_database_conflicts`.
    async fn collect_pg_stat_database_conflicts(
        &self,
        client: &Client,
    ) -> Result<(Vec<Metric>, usize), CollectError> {
        let rows = self
            .datname_filter
            .pg_stat_database_conflicts(client)
            .await
            .context(QuerySnafu)?;

        let mut metrics = Vec::with_capacity(5 * rows.len());
        let mut reader = RowReader::default();
        for row in rows.iter() {
            let db = reader.read::<&str>(row, "datname")?;

            metrics.extend_from_slice(&[
                self.create_metric(
                    "pg_stat_database_conflicts_confl_tablespace_total",
                    counter!(reader.read::<i64>(row, "confl_tablespace")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_conflicts_confl_lock_total",
                    counter!(reader.read::<i64>(row, "confl_lock")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_conflicts_confl_snapshot_total",
                    counter!(reader.read::<i64>(row, "confl_snapshot")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_conflicts_confl_bufferpin_total",
                    counter!(reader.read::<i64>(row, "confl_bufferpin")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_conflicts_confl_deadlock_total",
                    counter!(reader.read::<i64>(row, "confl_deadlock")?),
                    tags!(self.tags, "db" => db),
                ),
            ]);
        }
        Ok((metrics, reader.into_inner()))
    }

    /// Collects cluster-wide background-writer metrics from the single
    /// `pg_stat_bgwriter` row.
    async fn collect_pg_stat_bgwriter(
        &self,
        client: &Client,
    ) -> Result<(Vec<Metric>, usize), CollectError> {
        let row = self
            .datname_filter
            .pg_stat_bgwriter(client)
            .await
            .context(QuerySnafu)?;
        let mut reader = RowReader::default();

        Ok((
            vec![
                self.create_metric(
                    "pg_stat_bgwriter_checkpoints_timed_total",
                    counter!(reader.read::<i64>(&row, "checkpoints_timed")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_checkpoints_req_total",
                    counter!(reader.read::<i64>(&row, "checkpoints_req")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_checkpoint_write_time_seconds_total",
                    // Server reports milliseconds; convert to seconds.
                    counter!(reader.read::<f64>(&row, "checkpoint_write_time")? / 1000f64),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_checkpoint_sync_time_seconds_total",
                    counter!(reader.read::<f64>(&row, "checkpoint_sync_time")? / 1000f64),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_buffers_checkpoint_total",
                    counter!(reader.read::<i64>(&row, "buffers_checkpoint")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_buffers_clean_total",
                    counter!(reader.read::<i64>(&row, "buffers_clean")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_maxwritten_clean_total",
                    counter!(reader.read::<i64>(&row, "maxwritten_clean")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_buffers_backend_total",
                    counter!(reader.read::<i64>(&row, "buffers_backend")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_buffers_backend_fsync_total",
                    counter!(reader.read::<i64>(&row, "buffers_backend_fsync")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_buffers_alloc_total",
                    counter!(reader.read::<i64>(&row, "buffers_alloc")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_stats_reset",
                    gauge!(
                        reader
                            .read::<DateTime<Utc>>(&row, "stats_reset")?
                            .timestamp()
                    ),
                    tags!(self.tags),
                ),
            ],
            reader.into_inner(),
        ))
    }

    /// Builds an absolute metric stamped with the current time and the
    /// source's namespace.
    fn create_metric(&self, name: &str, value: MetricValue, tags: MetricTags) -> Metric {
        Metric::new(name, MetricKind::Absolute, value)
            .with_namespace(self.namespace.clone())
            .with_tags(Some(tags))
            .with_timestamp(Some(Utc::now()))
    }
}
871
872#[derive(Default)]
873struct RowReader(usize);
874
875impl RowReader {
876    pub fn read<'a, T: FromSql<'a> + ByteSizeOf>(
877        &mut self,
878        row: &'a Row,
879        key: &'static str,
880    ) -> Result<T, CollectError> {
881        let value = row_get_value::<T>(row, key)?;
882        self.0 += value.size_of();
883        Ok(value)
884    }
885
886    pub const fn into_inner(self) -> usize {
887        self.0
888    }
889}
890
891fn row_get_value<'a, T: FromSql<'a> + ByteSizeOf>(
892    row: &'a Row,
893    key: &'static str,
894) -> Result<T, CollectError> {
895    row.try_get::<&str, T>(key)
896        .map_err(|source| CollectError::PostgresGetValue { source, key })
897}
898
899fn config_to_endpoint(config: &Config) -> String {
900    let mut params: Vec<(&'static str, String)> = vec![];
901
902    // options
903    if let Some(options) = config.get_options() {
904        params.push(("options", options.to_string()));
905    }
906
907    // application_name
908    if let Some(name) = config.get_application_name() {
909        params.push(("application_name", name.to_string()));
910    }
911
912    // ssl_mode, ignore default value (SslMode::Prefer)
913    match config.get_ssl_mode() {
914        SslMode::Disable => params.push(("sslmode", "disable".to_string())),
915        SslMode::Prefer => {} // default, ignore
916        SslMode::Require => params.push(("sslmode", "require".to_string())),
917        // non_exhaustive enum
918        _ => {
919            warn!("Unknown variant of \"SslMode\".");
920        }
921    };
922
923    // host
924    for host in config.get_hosts() {
925        match host {
926            Host::Tcp(host) => params.push(("host", host.to_string())),
927            #[cfg(unix)]
928            Host::Unix(path) => params.push(("host", path.to_string_lossy().to_string())),
929        }
930    }
931
932    // port
933    for port in config.get_ports() {
934        params.push(("port", port.to_string()));
935    }
936
937    // connect_timeout
938    if let Some(connect_timeout) = config.get_connect_timeout() {
939        params.push(("connect_timeout", connect_timeout.as_secs().to_string()));
940    }
941
942    // keepalives, ignore default value (true)
943    if !config.get_keepalives() {
944        params.push(("keepalives", "1".to_owned()));
945    }
946
947    // keepalives_idle, ignore default value (2 * 60 * 60)
948    let keepalives_idle = config.get_keepalives_idle().as_secs();
949    if keepalives_idle != 2 * 60 * 60 {
950        params.push(("keepalives_idle", keepalives_idle.to_string()));
951    }
952
953    // target_session_attrs, ignore default value (TargetSessionAttrs::Any)
954    match config.get_target_session_attrs() {
955        TargetSessionAttrs::Any => {} // default, ignore
956        TargetSessionAttrs::ReadWrite => {
957            params.push(("target_session_attrs", "read-write".to_owned()))
958        }
959        // non_exhaustive enum
960        _ => {
961            warn!("Unknown variant of \"TargetSessionAttr\".");
962        }
963    }
964
965    // channel_binding, ignore default value (ChannelBinding::Prefer)
966    match config.get_channel_binding() {
967        ChannelBinding::Disable => params.push(("channel_binding", "disable".to_owned())),
968        ChannelBinding::Prefer => {} // default, ignore
969        ChannelBinding::Require => params.push(("channel_binding", "require".to_owned())),
970        // non_exhaustive enum
971        _ => {
972            warn!("Unknown variant of \"ChannelBinding\".");
973        }
974    }
975
976    format!(
977        "postgresql:///{}?{}",
978        config.get_dbname().unwrap_or(""),
979        params
980            .into_iter()
981            .map(|(k, v)| format!(
982                "{}={}",
983                percent_encoding(k.as_bytes()),
984                percent_encoding(v.as_bytes())
985            ))
986            .collect::<Vec<String>>()
987            .join("&")
988    )
989}
990
991fn percent_encoding(input: &'_ [u8]) -> String {
992    percent_encoding::percent_encode(input, percent_encoding::NON_ALPHANUMERIC).to_string()
993}
994
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips the documented example configuration through the shared
    // test helper; presumably verifies it deserializes into
    // `PostgresqlMetricsConfig` — see `crate::test_util::test_generate_config`.
    #[test]
    fn generate_config() {
        crate::test_util::test_generate_config::<PostgresqlMetricsConfig>();
    }
}
1004
#[cfg(all(test, feature = "postgresql_metrics-integration-tests"))]
mod integration_tests {
    use super::*;
    use crate::{
        SourceSender,
        event::Event,
        test_util::{
            components::{PULL_SOURCE_TAGS, assert_source_compliance},
            integration::postgres::{pg_socket, pg_url},
        },
        tls,
    };

    /// Builds the source against `endpoint` with the given TLS and database
    /// filters, drains one batch of emitted metrics, and runs shared sanity
    /// checks (`up` gauge, namespace, endpoint/host tags, metrics from
    /// several distinct queries). Returns the drained events so callers can
    /// add filter-specific assertions.
    async fn test_postgresql_metrics(
        endpoint: String,
        tls: Option<PostgresqlMetricsTlsConfig>,
        include_databases: Option<Vec<String>>,
        exclude_databases: Option<Vec<String>>,
    ) -> Vec<Event> {
        assert_source_compliance(&PULL_SOURCE_TAGS, async move {
            // Expected tag values are derived from the same endpoint string
            // the source is configured with.
            let config: Config = endpoint.parse().unwrap();
            let tags_endpoint = config_to_endpoint(&config);
            let tags_host = match config.get_hosts().first().unwrap() {
                Host::Tcp(host) => host.clone(),
                #[cfg(unix)]
                Host::Unix(path) => path.to_string_lossy().to_string(),
            };

            let (sender, mut recv) = SourceSender::new_test();

            // Run the source in the background; it pushes events into `recv`.
            tokio::spawn(async move {
                PostgresqlMetricsConfig {
                    endpoints: vec![endpoint],
                    tls,
                    include_databases,
                    exclude_databases,
                    ..Default::default()
                }
                .build(SourceContext::new_test(sender, None))
                .await
                .unwrap()
                .await
                .unwrap()
            });

            // Wait (up to 3s) for the first metric, then keep draining until
            // the stream stays quiet for 10ms or closes.
            let event = time::timeout(time::Duration::from_secs(3), recv.next())
                .await
                .expect("fetch metrics timeout")
                .expect("failed to get metrics from a stream");
            let mut events = vec![event];
            loop {
                match time::timeout(time::Duration::from_millis(10), recv.next()).await {
                    Ok(Some(event)) => events.push(event),
                    Ok(None) => break,
                    Err(_) => break,
                }
            }

            assert!(events.len() > 1);

            // test up metric
            assert_eq!(
                events
                    .iter()
                    .map(|e| e.as_metric())
                    .find(|e| e.name() == "up")
                    .unwrap()
                    .value(),
                &gauge!(1)
            );

            // test namespace and tags
            for event in &events {
                let metric = event.as_metric();

                assert_eq!(metric.namespace(), Some("postgresql"));
                assert_eq!(
                    metric.tags().unwrap().get("endpoint").unwrap(),
                    &tags_endpoint
                );
                assert_eq!(metric.tags().unwrap().get("host").unwrap(), &tags_host);
            }

            // test metrics from different queries
            let names = vec![
                "pg_stat_database_datid",
                "pg_stat_database_conflicts_confl_tablespace_total",
                "pg_stat_bgwriter_checkpoints_timed_total",
            ];
            for name in names {
                assert!(events.iter().any(|e| e.as_metric().name() == name));
            }

            events
        })
        .await
    }

    // Plain TCP connection, no TLS, no filters.
    #[tokio::test]
    async fn test_host() {
        test_postgresql_metrics(pg_url(), None, None, None).await;
    }

    // Unix-domain-socket connection: `host` points at the socket directory.
    #[tokio::test]
    async fn test_local() {
        let endpoint = format!(
            "postgresql:///postgres?host={}&user=vector&password=vector",
            pg_socket().to_str().unwrap()
        );
        test_postgresql_metrics(endpoint, None, None, None).await;
    }

    // TLS connection with `sslmode=require`, validated against the test CA.
    #[tokio::test]
    async fn test_host_ssl() {
        test_postgresql_metrics(
            format!("{}?sslmode=require", pg_url()),
            Some(PostgresqlMetricsTlsConfig {
                ca_file: tls::TEST_PEM_CA_PATH.into(),
            }),
            None,
            None,
        )
        .await;
    }

    // With include patterns, any per-database metric must belong to a
    // matching database ("vector" or "postgres").
    #[tokio::test]
    async fn test_host_include_databases() {
        let events = test_postgresql_metrics(
            pg_url(),
            None,
            Some(vec!["^vec".to_owned(), "gres$".to_owned()]),
            None,
        )
        .await;

        for event in events {
            let metric = event.into_metric();

            if let Some(db) = metric.tags().unwrap().get("db") {
                assert!(db == "vector" || db == "postgres");
            }
        }
    }

    // With exclude patterns, matching databases must not appear in any
    // per-database metric.
    #[tokio::test]
    async fn test_host_exclude_databases() {
        let events = test_postgresql_metrics(
            pg_url(),
            None,
            None,
            Some(vec!["^vec".to_owned(), "gres$".to_owned()]),
        )
        .await;

        for event in events {
            let metric = event.into_metric();

            if let Some(db) = metric.tags().unwrap().get("db") {
                assert!(db != "vector" && db != "postgres");
            }
        }
    }

    // An empty exclude pattern must still yield a working collection run.
    #[tokio::test]
    async fn test_host_exclude_databases_empty() {
        test_postgresql_metrics(pg_url(), None, None, Some(vec!["".to_owned()])).await;
    }

    // Exclude applies within the included set: `template\d+` is included,
    // template0 is excluded, so only template1 may remain.
    #[tokio::test]
    async fn test_host_include_databases_and_exclude_databases() {
        let events = test_postgresql_metrics(
            pg_url(),
            None,
            Some(vec!["template\\d+".to_owned()]),
            Some(vec!["template0".to_owned()]),
        )
        .await;

        for event in events {
            let metric = event.into_metric();

            if let Some(db) = metric.tags().unwrap().get("db") {
                assert!(db == "template1");
            }
        }
    }
}