vector/sources/
postgresql_metrics.rs

1use std::{
2    collections::HashSet,
3    fmt::Write as _,
4    iter,
5    path::PathBuf,
6    time::{Duration, Instant},
7};
8
9use chrono::{DateTime, Utc};
10use futures::{
11    future::{join_all, try_join_all},
12    FutureExt, StreamExt,
13};
14use openssl::{
15    error::ErrorStack,
16    ssl::{SslConnector, SslMethod},
17};
18use postgres_openssl::MakeTlsConnector;
19use serde_with::serde_as;
20use snafu::{ResultExt, Snafu};
21use tokio::time;
22use tokio_postgres::{
23    config::{ChannelBinding, Host, SslMode, TargetSessionAttrs},
24    types::FromSql,
25    Client, Config, Error as PgError, NoTls, Row,
26};
27use tokio_stream::wrappers::IntervalStream;
28use vector_lib::config::LogNamespace;
29use vector_lib::configurable::configurable_component;
30use vector_lib::{
31    internal_event::{CountByteSize, InternalEventHandle as _, Registered},
32    json_size::JsonSize,
33};
34use vector_lib::{metric_tags, ByteSizeOf, EstimatedJsonEncodedSizeOf};
35
36use crate::{
37    config::{SourceConfig, SourceContext, SourceOutput},
38    event::metric::{Metric, MetricKind, MetricTags, MetricValue},
39    internal_events::{
40        CollectionCompleted, EndpointBytesReceived, EventsReceived, PostgresqlMetricsCollectError,
41        StreamClosedError,
42    },
43};
44
/// Clones a base `MetricTags` set, optionally overriding/adding `key => value`
/// pairs on the clone. Used to stamp per-row tags (e.g. `db`) on top of the
/// shared endpoint/host tags without mutating the original set.
macro_rules! tags {
    ($tags:expr_2021) => { $tags.clone() };
    ($tags:expr_2021, $($key:expr_2021 => $value:expr_2021),*) => {
        {
            let mut tags = $tags.clone();
            $(
                // `replace` overwrites any existing value for the key.
                tags.replace($key.into(), String::from($value));
            )*
            tags
        }
    };
}
57
/// Builds a `MetricValue::Counter` from any numeric expression (cast to `f64`).
macro_rules! counter {
    ($value:expr_2021) => {
        MetricValue::Counter {
            value: $value as f64,
        }
    };
}
65
/// Builds a `MetricValue::Gauge` from any numeric expression (cast to `f64`).
macro_rules! gauge {
    ($value:expr_2021) => {
        MetricValue::Gauge {
            value: $value as f64,
        }
    };
}
73
/// Errors raised while turning a configured endpoint into a source instance.
#[derive(Debug, Snafu)]
enum BuildError {
    /// The endpoint string failed to parse as a PostgreSQL connection config.
    #[snafu(display("invalid endpoint: {}", source))]
    InvalidEndpoint { source: PgError },
    /// The parsed config contained no host entries at all.
    #[snafu(display("host missing"))]
    HostMissing,
    /// Multi-host connection strings are rejected by this source.
    #[snafu(display("multiple hosts not supported: {:?}", hosts))]
    MultipleHostsNotSupported { hosts: Vec<Host> },
}
83
/// Errors raised while (re)establishing a client connection to the server.
#[derive(Debug, Snafu)]
enum ConnectError {
    /// Building the OpenSSL connector (or loading the CA file) failed.
    #[snafu(display("failed to create tls connector: {}", source))]
    TlsFailed { source: ErrorStack },
    /// The TCP/Unix connection or handshake to the endpoint failed.
    #[snafu(display("failed to connect ({}): {}", endpoint, source))]
    ConnectionFailed { source: PgError, endpoint: String },
    /// Querying `version()` / `server_version_num` failed after connecting.
    #[snafu(display("failed to get PostgreSQL version ({}): {}", endpoint, source))]
    SelectVersionFailed { source: PgError, endpoint: String },
    /// The reported server version is below the supported minimum.
    #[snafu(display("version ({}) is not supported", version))]
    InvalidVersion { version: String },
}
95
/// Errors raised while running the stats queries or decoding row values.
#[derive(Debug, Snafu)]
enum CollectError {
    /// A column value could not be decoded into the requested Rust type.
    #[snafu(display("failed to get value by key: {} (reason: {})", key, source))]
    PostgresGetValue { source: PgError, key: &'static str },
    /// A stats query itself failed.
    #[snafu(display("query failed: {}", source))]
    QueryError { source: PgError },
}
103
/// Configuration of TLS when connecting to PostgreSQL.
// NOTE(review): only server verification via an extra CA file is configurable
// here; client certificates/keys are not supported by this struct.
#[configurable_component]
#[derive(Clone, Debug)]
#[serde(deny_unknown_fields)]
struct PostgresqlMetricsTlsConfig {
    /// Absolute path to an additional CA certificate file.
    ///
    /// The certificate must be in the DER or PEM (X.509) format.
    #[configurable(metadata(docs::examples = "certs/ca.pem"))]
    ca_file: PathBuf,
}
115
/// Configuration for the `postgresql_metrics` source.
#[serde_as]
#[configurable_component(source(
    "postgresql_metrics",
    "Collect metrics from the PostgreSQL database."
))]
#[derive(Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct PostgresqlMetricsConfig {
    /// A list of PostgreSQL instances to scrape.
    ///
    /// Each endpoint must be in the [Connection URI
    /// format](https://www.postgresql.org/docs/current/libpq-connect.html#id-1.7.3.8.3.6).
    #[configurable(metadata(
        docs::examples = "postgresql://postgres:vector@localhost:5432/postgres"
    ))]
    endpoints: Vec<String>,

    /// A list of databases to match (by using [POSIX Regular
    /// Expressions](https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP)) against
    /// the `datname` column for which you want to collect metrics from.
    ///
    /// If not set, metrics are collected from all databases. Specifying `""` includes metrics where `datname` is
    /// `NULL`.
    ///
    /// This can be used in conjunction with `exclude_databases`.
    #[configurable(metadata(
        docs::examples = "^postgres$",
        docs::examples = "^vector$",
        docs::examples = "^foo",
    ))]
    include_databases: Option<Vec<String>>,

    /// A list of databases to match (by using [POSIX Regular
    /// Expressions](https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP)) against
    /// the `datname` column for which you don’t want to collect metrics from.
    ///
    /// Specifying `""` includes metrics where `datname` is `NULL`.
    ///
    /// This can be used in conjunction with `include_databases`.
    #[configurable(metadata(docs::examples = "^postgres$", docs::examples = "^template.*",))]
    exclude_databases: Option<Vec<String>>,

    /// The interval between scrapes.
    // Deserialized from a plain integer number of seconds via `serde_as`.
    #[serde(default = "default_scrape_interval_secs")]
    #[serde_as(as = "serde_with::DurationSeconds<u64>")]
    #[configurable(metadata(docs::human_name = "Scrape Interval"))]
    scrape_interval_secs: Duration,

    /// Overrides the default namespace for the metrics emitted by the source.
    // An empty string disables namespacing (see `build`, which filters it out).
    #[serde(default = "default_namespace")]
    namespace: String,

    #[configurable(derived)]
    tls: Option<PostgresqlMetricsTlsConfig>,
}
172
173impl Default for PostgresqlMetricsConfig {
174    fn default() -> Self {
175        Self {
176            endpoints: vec![],
177            include_databases: None,
178            exclude_databases: None,
179            scrape_interval_secs: Duration::from_secs(15),
180            namespace: "postgresql".to_owned(),
181            tls: None,
182        }
183    }
184}
185
186impl_generate_config_from_default!(PostgresqlMetricsConfig);
187
/// Default scrape interval used when `scrape_interval_secs` is unset: 15 seconds.
pub const fn default_scrape_interval_secs() -> Duration {
    const DEFAULT_INTERVAL_SECS: u64 = 15;
    Duration::from_secs(DEFAULT_INTERVAL_SECS)
}
191
/// Default metric namespace used when `namespace` is unset.
pub fn default_namespace() -> String {
    String::from("postgresql")
}
195
#[async_trait::async_trait]
#[typetag::serde(name = "postgresql_metrics")]
impl SourceConfig for PostgresqlMetricsConfig {
    /// Builds the source: one `PostgresqlMetrics` collector per endpoint, then
    /// a future that scrapes all of them on a fixed interval until shutdown.
    async fn build(&self, mut cx: SourceContext) -> crate::Result<super::Source> {
        let datname_filter = DatnameFilter::new(
            self.include_databases.clone().unwrap_or_default(),
            self.exclude_databases.clone().unwrap_or_default(),
        );
        // An empty namespace string means "no namespace".
        let namespace = Some(self.namespace.clone()).filter(|namespace| !namespace.is_empty());

        // Fail the build if any endpoint is invalid (try_join_all short-circuits).
        let mut sources = try_join_all(self.endpoints.iter().map(|endpoint| {
            PostgresqlMetrics::new(
                endpoint.clone(),
                datname_filter.clone(),
                namespace.clone(),
                self.tls.clone(),
            )
        }))
        .await?;

        let duration = self.scrape_interval_secs;
        let shutdown = cx.shutdown;
        Ok(Box::pin(async move {
            // Tick until the shutdown signal resolves.
            let mut interval = IntervalStream::new(time::interval(duration)).take_until(shutdown);
            while interval.next().await.is_some() {
                let start = Instant::now();
                // Scrape all endpoints concurrently; per-endpoint errors are
                // handled inside `collect` (which emits an `up = 0` metric).
                let metrics = join_all(sources.iter_mut().map(|source| source.collect())).await;
                emit!(CollectionCompleted {
                    start,
                    end: Instant::now()
                });

                let metrics: Vec<Metric> = metrics.into_iter().flatten().collect();
                let count = metrics.len();

                // A closed output stream means downstream is gone; stop scraping.
                if (cx.out.send_batch(metrics).await).is_err() {
                    emit!(StreamClosedError { count });
                    return Err(());
                }
            }

            Ok(())
        }))
    }

    fn outputs(&self, _global_log_namespace: LogNamespace) -> Vec<SourceOutput> {
        vec![SourceOutput::new_metrics()]
    }

    /// Scraped metrics have no upstream to acknowledge to.
    fn can_acknowledge(&self) -> bool {
        false
    }
}
249
/// Lazily-connected PostgreSQL client with a cached connection.
#[derive(Debug)]
struct PostgresqlClient {
    // Parsed connection parameters used to (re)connect.
    config: Config,
    // Optional extra CA for TLS connections.
    tls_config: Option<PostgresqlMetricsTlsConfig>,
    // Cached `(client, server_version_num)` pair; `None` until first use or
    // after the client is taken and not returned (e.g. on a failed scrape).
    client: Option<(Client, usize)>,
    // Sanitized endpoint string (credentials stripped) for logs/errors.
    endpoint: String,
}
257
impl PostgresqlClient {
    /// Creates an unconnected client wrapper; no I/O happens here.
    fn new(config: Config, tls_config: Option<PostgresqlMetricsTlsConfig>) -> Self {
        let endpoint = config_to_endpoint(&config);
        Self {
            config,
            tls_config,
            client: None,
            endpoint,
        }
    }

    /// Moves the cached connection out, connecting first if there is none.
    /// Callers return it via `set` after a successful scrape; on failure the
    /// connection is simply dropped and rebuilt on the next call.
    async fn take(&mut self) -> Result<(Client, usize), ConnectError> {
        match self.client.take() {
            Some((client, version)) => Ok((client, version)),
            None => self.build_client().await,
        }
    }

    /// Returns a working `(client, version)` pair to the cache.
    fn set(&mut self, value: (Client, usize)) {
        self.client.replace(value);
    }

    /// Connects (with or without TLS), spawns the connection driver task, and
    /// verifies the server version is supported before handing the client out.
    async fn build_client(&self) -> Result<(Client, usize), ConnectError> {
        // Create postgresql client
        let client = match &self.tls_config {
            Some(tls_config) => {
                let mut builder =
                    SslConnector::builder(SslMethod::tls_client()).context(TlsFailedSnafu)?;
                builder
                    .set_ca_file(tls_config.ca_file.clone())
                    .context(TlsFailedSnafu)?;
                let connector = MakeTlsConnector::new(builder.build());

                let (client, connection) =
                    self.config.connect(connector).await.with_context(|_| {
                        ConnectionFailedSnafu {
                            endpoint: &self.endpoint,
                        }
                    })?;
                // The connection future must be polled for the client to work.
                tokio::spawn(connection);
                client
            }
            None => {
                let (client, connection) =
                    self.config
                        .connect(NoTls)
                        .await
                        .with_context(|_| ConnectionFailedSnafu {
                            endpoint: &self.endpoint,
                        })?;
                // The connection future must be polled for the client to work.
                tokio::spawn(connection);
                client
            }
        };

        // Log version if required
        if tracing::level_enabled!(tracing::Level::DEBUG) {
            let version_row = client
                .query_one("SELECT version()", &[])
                .await
                .with_context(|_| SelectVersionFailedSnafu {
                    endpoint: &self.endpoint,
                })?;
            let version = version_row
                .try_get::<&str, &str>("version")
                .with_context(|_| SelectVersionFailedSnafu {
                    endpoint: &self.endpoint,
                })?;
            debug!(message = "Connected to server.", endpoint = %self.endpoint, server_version = %version);
        }

        // Get server version and check that we support it
        let row = client
            .query_one("SHOW server_version_num", &[])
            .await
            .with_context(|_| SelectVersionFailedSnafu {
                endpoint: &self.endpoint,
            })?;

        let version = row
            .try_get::<&str, &str>("server_version_num")
            .with_context(|_| SelectVersionFailedSnafu {
                endpoint: &self.endpoint,
            })?;

        // 90600 == PostgreSQL 9.6, the oldest supported server.
        let version = match version.parse::<usize>() {
            Ok(version) if version >= 90600 => version,
            Ok(_) | Err(_) => {
                return Err(ConnectError::InvalidVersion {
                    version: version.to_string(),
                })
            }
        };

        // Hand back the live client together with its numeric server version.
        Ok((client, version))
    }
}
356
/// Pre-built stats queries with `datname` include/exclude filtering baked in.
#[derive(Debug, Clone)]
struct DatnameFilter {
    // Full `pg_stat_database` query (WHERE clause included when filters are set).
    pg_stat_database_sql: String,
    // Full `pg_stat_database_conflicts` query with the same regex filters.
    pg_stat_database_conflicts_sql: String,
    // Positional parameters ($1..$n) — the regex patterns referenced by the SQL.
    match_params: Vec<String>,
}
363
impl DatnameFilter {
    /// Builds the two filtered stats queries from include/exclude regex lists.
    /// The empty-string pattern is special-cased to mean "the NULL datname row"
    /// (objects not belonging to any database).
    fn new(include: Vec<String>, exclude: Vec<String>) -> Self {
        let (include_databases, include_null) = Self::clean_databases(include);
        let (exclude_databases, exclude_null) = Self::clean_databases(exclude);
        let (match_sql, match_params) =
            Self::build_match_params(include_databases, exclude_databases);

        let mut pg_stat_database_sql = "SELECT * FROM pg_stat_database".to_owned();
        if !match_sql.is_empty() {
            pg_stat_database_sql += " WHERE";
            pg_stat_database_sql += &match_sql;
        }
        match (include_null, exclude_null) {
            // Nothing
            (false, false) => {}
            // Include tracking objects not in database
            (true, false) => {
                pg_stat_database_sql += if match_sql.is_empty() {
                    " WHERE"
                } else {
                    " OR"
                };
                pg_stat_database_sql += " datname IS NULL";
            }
            // Exclude tracking objects not in database, precedence over include
            (false, true) | (true, true) => {
                pg_stat_database_sql += if match_sql.is_empty() {
                    " WHERE"
                } else {
                    " AND"
                };
                pg_stat_database_sql += " datname IS NOT NULL";
            }
        }

        // The conflicts view has no NULL-datname row handling; only the regex
        // filters apply.
        let mut pg_stat_database_conflicts_sql =
            "SELECT * FROM pg_stat_database_conflicts".to_owned();
        if !match_sql.is_empty() {
            pg_stat_database_conflicts_sql += " WHERE";
            pg_stat_database_conflicts_sql += &match_sql;
        }

        Self {
            pg_stat_database_sql,
            pg_stat_database_conflicts_sql,
            match_params,
        }
    }

    /// Dedupes the name list and splits off the empty-string marker, returning
    /// `(names, had_empty_string)`.
    fn clean_databases(names: Vec<String>) -> (Vec<String>, bool) {
        let mut set = names.into_iter().collect::<HashSet<_>>();
        let null = set.remove("");
        (set.into_iter().collect(), null)
    }

    /// Builds the `datname ~ $n`-based WHERE fragment and its parameter list:
    /// `(p1 OR p2 ...) AND NOT (q1 OR q2 ...)` with positional placeholders.
    fn build_match_params(include: Vec<String>, exclude: Vec<String>) -> (String, Vec<String>) {
        let mut query = String::new();
        let mut params = vec![];

        if !include.is_empty() {
            query.push_str(" (");
            for (i, name) in include.into_iter().enumerate() {
                params.push(name);
                if i > 0 {
                    query.push_str(" OR");
                }
                // `params.len()` is the 1-based placeholder index after push.
                write!(query, " datname ~ ${}", params.len()).expect("write to String never fails");
            }
            query.push(')');
        }

        if !exclude.is_empty() {
            if !query.is_empty() {
                query.push_str(" AND");
            }

            query.push_str(" NOT (");
            for (i, name) in exclude.into_iter().enumerate() {
                params.push(name);
                if i > 0 {
                    query.push_str(" OR");
                }
                write!(query, " datname ~ ${}", params.len()).expect("write to String never fails");
            }
            query.push(')');
        }

        (query, params)
    }

    /// Re-borrows the stored parameters as the trait-object slice
    /// `tokio_postgres::query` expects.
    fn get_match_params(&self) -> Vec<&(dyn tokio_postgres::types::ToSql + Sync)> {
        let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
            Vec::with_capacity(self.match_params.len());
        for item in self.match_params.iter() {
            params.push(item);
        }
        params
    }

    /// Runs the filtered `pg_stat_database` query.
    async fn pg_stat_database(&self, client: &Client) -> Result<Vec<Row>, PgError> {
        client
            .query(
                self.pg_stat_database_sql.as_str(),
                self.get_match_params().as_slice(),
            )
            .await
    }

    /// Runs the filtered `pg_stat_database_conflicts` query.
    async fn pg_stat_database_conflicts(&self, client: &Client) -> Result<Vec<Row>, PgError> {
        client
            .query(
                self.pg_stat_database_conflicts_sql.as_str(),
                self.get_match_params().as_slice(),
            )
            .await
    }

    /// Runs the (unfiltered, single-row) `pg_stat_bgwriter` query.
    async fn pg_stat_bgwriter(&self, client: &Client) -> Result<Row, PgError> {
        client
            .query_one("SELECT * FROM pg_stat_bgwriter", &[])
            .await
    }
}
487
/// Per-endpoint metrics collector: owns the connection, the filter, and the
/// base tags stamped onto every emitted metric.
struct PostgresqlMetrics {
    client: PostgresqlClient,
    // Sanitized endpoint string used in tags and telemetry events.
    endpoint: String,
    // Optional metric namespace (None when configured as "").
    namespace: Option<String>,
    // Base tags: `endpoint` and `host`.
    tags: MetricTags,
    datname_filter: DatnameFilter,
    events_received: Registered<EventsReceived>,
}
496
impl PostgresqlMetrics {
    /// Validates the endpoint (exactly one host) and prepares the collector;
    /// no connection is made until the first scrape.
    async fn new(
        endpoint: String,
        datname_filter: DatnameFilter,
        namespace: Option<String>,
        tls_config: Option<PostgresqlMetricsTlsConfig>,
    ) -> Result<Self, BuildError> {
        // Takes the raw endpoint, parses it into a configuration, and then we set `endpoint` back to a sanitized
        // version of the original value, dropping things like username/password, etc.
        let config: Config = endpoint.parse().context(InvalidEndpointSnafu)?;
        let endpoint = config_to_endpoint(&config);

        let hosts = config.get_hosts();
        let host = match hosts.len() {
            0 => return Err(BuildError::HostMissing),
            1 => match &hosts[0] {
                Host::Tcp(host) => host.clone(),
                #[cfg(unix)]
                // Unix-socket hosts are tagged by their socket directory path.
                Host::Unix(path) => path.to_string_lossy().to_string(),
            },
            _ => {
                return Err(BuildError::MultipleHostsNotSupported {
                    hosts: config.get_hosts().to_owned(),
                })
            }
        };

        let tags = metric_tags!(
            "endpoint" => endpoint.clone(),
            "host" => host,
        );

        Ok(Self {
            client: PostgresqlClient::new(config, tls_config),
            endpoint,
            namespace,
            tags,
            datname_filter,
            events_received: register!(EventsReceived),
        })
    }

    /// Runs one scrape. Always yields at least an `up` gauge: 1 followed by the
    /// collected metrics on success, 0 alone on failure (error is emitted as an
    /// internal event, not propagated).
    async fn collect(&mut self) -> Box<dyn Iterator<Item = Metric> + Send> {
        match self.collect_metrics().await {
            Ok(metrics) => Box::new(
                iter::once(self.create_metric("up", gauge!(1.0), tags!(self.tags))).chain(metrics),
            ),
            Err(error) => {
                emit!(PostgresqlMetricsCollectError {
                    error,
                    endpoint: &self.endpoint,
                });
                Box::new(iter::once(self.create_metric(
                    "up",
                    gauge!(0.0),
                    tags!(self.tags),
                )))
            }
        }
    }

    /// Runs the three stats queries concurrently, emits telemetry for the
    /// received bytes/events, and returns the flattened metrics. The cached
    /// client is only put back via `set` on full success, so any failure forces
    /// a reconnect on the next scrape.
    async fn collect_metrics(&mut self) -> Result<impl Iterator<Item = Metric> + use<>, String> {
        let (client, client_version) = self
            .client
            .take()
            .await
            .map_err(|error| error.to_string())?;

        match try_join_all(vec![
            self.collect_pg_stat_database(&client, client_version)
                .boxed(),
            self.collect_pg_stat_database_conflicts(&client).boxed(),
            self.collect_pg_stat_bgwriter(&client).boxed(),
        ])
        .await
        {
            Ok(result) => {
                // Sum metric counts, their JSON-encoded sizes, and the raw
                // byte sizes accumulated by each query's RowReader.
                let (count, json_byte_size, received_byte_size) =
                    result
                        .iter()
                        .fold((0, JsonSize::zero(), 0), |res, (set, size)| {
                            (
                                res.0 + set.len(),
                                res.1 + set.estimated_json_encoded_size_of(),
                                res.2 + size,
                            )
                        });
                emit!(EndpointBytesReceived {
                    byte_size: received_byte_size,
                    protocol: "tcp",
                    endpoint: &self.endpoint,
                });
                self.events_received
                    .emit(CountByteSize(count, json_byte_size));
                self.client.set((client, client_version));
                Ok(result.into_iter().flat_map(|(metrics, _)| metrics))
            }
            Err(error) => Err(error.to_string()),
        }
    }

    /// Converts `pg_stat_database` rows into metrics, one tag set per `db`.
    /// Returns the metrics plus the total raw byte size read from the rows.
    async fn collect_pg_stat_database(
        &self,
        client: &Client,
        client_version: usize,
    ) -> Result<(Vec<Metric>, usize), CollectError> {
        let rows = self
            .datname_filter
            .pg_stat_database(client)
            .await
            .context(QuerySnafu)?;

        let mut metrics = Vec::with_capacity(20 * rows.len());
        let mut reader = RowReader::default();
        for row in rows.iter() {
            // NULL datname (shared objects) is tagged as an empty `db`.
            let db = reader.read::<Option<&str>>(row, "datname")?.unwrap_or("");

            metrics.extend_from_slice(&[
                self.create_metric(
                    "pg_stat_database_datid",
                    gauge!(reader.read::<u32>(row, "datid")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_numbackends",
                    gauge!(reader.read::<i32>(row, "numbackends")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_xact_commit_total",
                    counter!(reader.read::<i64>(row, "xact_commit")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_xact_rollback_total",
                    counter!(reader.read::<i64>(row, "xact_rollback")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_blks_read_total",
                    counter!(reader.read::<i64>(row, "blks_read")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_blks_hit_total",
                    counter!(reader.read::<i64>(row, "blks_hit")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_tup_returned_total",
                    counter!(reader.read::<i64>(row, "tup_returned")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_tup_fetched_total",
                    counter!(reader.read::<i64>(row, "tup_fetched")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_tup_inserted_total",
                    counter!(reader.read::<i64>(row, "tup_inserted")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_tup_updated_total",
                    counter!(reader.read::<i64>(row, "tup_updated")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_tup_deleted_total",
                    counter!(reader.read::<i64>(row, "tup_deleted")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_conflicts_total",
                    counter!(reader.read::<i64>(row, "conflicts")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_temp_files_total",
                    counter!(reader.read::<i64>(row, "temp_files")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_temp_bytes_total",
                    counter!(reader.read::<i64>(row, "temp_bytes")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_deadlocks_total",
                    counter!(reader.read::<i64>(row, "deadlocks")?),
                    tags!(self.tags, "db" => db),
                ),
            ]);
            // Checksum columns exist only on PostgreSQL 12+ (version 120000).
            if client_version >= 120000 {
                metrics.extend_from_slice(&[
                    self.create_metric(
                        "pg_stat_database_checksum_failures_total",
                        counter!(reader
                            .read::<Option<i64>>(row, "checksum_failures")?
                            .unwrap_or(0)),
                        tags!(self.tags, "db" => db),
                    ),
                    self.create_metric(
                        "pg_stat_database_checksum_last_failure",
                        gauge!(reader
                            .read::<Option<DateTime<Utc>>>(row, "checksum_last_failure")?
                            .map(|t| t.timestamp())
                            .unwrap_or(0)),
                        tags!(self.tags, "db" => db),
                    ),
                ]);
            }
            metrics.extend_from_slice(&[
                self.create_metric(
                    "pg_stat_database_blk_read_time_seconds_total",
                    // Server reports milliseconds; convert to seconds.
                    counter!(reader.read::<f64>(row, "blk_read_time")? / 1000f64),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_blk_write_time_seconds_total",
                    counter!(reader.read::<f64>(row, "blk_write_time")? / 1000f64),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_stats_reset",
                    gauge!(reader
                        .read::<Option<DateTime<Utc>>>(row, "stats_reset")?
                        .map(|t| t.timestamp())
                        .unwrap_or(0)),
                    tags!(self.tags, "db" => db),
                ),
            ]);
        }
        Ok((metrics, reader.into_inner()))
    }

    /// Converts `pg_stat_database_conflicts` rows into metrics.
    async fn collect_pg_stat_database_conflicts(
        &self,
        client: &Client,
    ) -> Result<(Vec<Metric>, usize), CollectError> {
        let rows = self
            .datname_filter
            .pg_stat_database_conflicts(client)
            .await
            .context(QuerySnafu)?;

        let mut metrics = Vec::with_capacity(5 * rows.len());
        let mut reader = RowReader::default();
        for row in rows.iter() {
            let db = reader.read::<&str>(row, "datname")?;

            metrics.extend_from_slice(&[
                self.create_metric(
                    "pg_stat_database_conflicts_confl_tablespace_total",
                    counter!(reader.read::<i64>(row, "confl_tablespace")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_conflicts_confl_lock_total",
                    counter!(reader.read::<i64>(row, "confl_lock")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_conflicts_confl_snapshot_total",
                    counter!(reader.read::<i64>(row, "confl_snapshot")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_conflicts_confl_bufferpin_total",
                    counter!(reader.read::<i64>(row, "confl_bufferpin")?),
                    tags!(self.tags, "db" => db),
                ),
                self.create_metric(
                    "pg_stat_database_conflicts_confl_deadlock_total",
                    counter!(reader.read::<i64>(row, "confl_deadlock")?),
                    tags!(self.tags, "db" => db),
                ),
            ]);
        }
        Ok((metrics, reader.into_inner()))
    }

    /// Converts the single `pg_stat_bgwriter` row into cluster-wide metrics
    /// (no `db` tag — the bgwriter is global).
    async fn collect_pg_stat_bgwriter(
        &self,
        client: &Client,
    ) -> Result<(Vec<Metric>, usize), CollectError> {
        let row = self
            .datname_filter
            .pg_stat_bgwriter(client)
            .await
            .context(QuerySnafu)?;
        let mut reader = RowReader::default();

        Ok((
            vec![
                self.create_metric(
                    "pg_stat_bgwriter_checkpoints_timed_total",
                    counter!(reader.read::<i64>(&row, "checkpoints_timed")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_checkpoints_req_total",
                    counter!(reader.read::<i64>(&row, "checkpoints_req")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_checkpoint_write_time_seconds_total",
                    // Server reports milliseconds; convert to seconds.
                    counter!(reader.read::<f64>(&row, "checkpoint_write_time")? / 1000f64),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_checkpoint_sync_time_seconds_total",
                    counter!(reader.read::<f64>(&row, "checkpoint_sync_time")? / 1000f64),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_buffers_checkpoint_total",
                    counter!(reader.read::<i64>(&row, "buffers_checkpoint")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_buffers_clean_total",
                    counter!(reader.read::<i64>(&row, "buffers_clean")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_maxwritten_clean_total",
                    counter!(reader.read::<i64>(&row, "maxwritten_clean")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_buffers_backend_total",
                    counter!(reader.read::<i64>(&row, "buffers_backend")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_buffers_backend_fsync_total",
                    counter!(reader.read::<i64>(&row, "buffers_backend_fsync")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_buffers_alloc_total",
                    counter!(reader.read::<i64>(&row, "buffers_alloc")?),
                    tags!(self.tags),
                ),
                self.create_metric(
                    "pg_stat_bgwriter_stats_reset",
                    gauge!(reader
                        .read::<DateTime<Utc>>(&row, "stats_reset")?
                        .timestamp()),
                    tags!(self.tags),
                ),
            ],
            reader.into_inner(),
        ))
    }

    /// Wraps a value into an absolute `Metric` with this collector's namespace,
    /// the given tags, and a current timestamp.
    fn create_metric(&self, name: &str, value: MetricValue, tags: MetricTags) -> Metric {
        Metric::new(name, MetricKind::Absolute, value)
            .with_namespace(self.namespace.clone())
            .with_tags(Some(tags))
            .with_timestamp(Some(Utc::now()))
    }
}
862
/// Tracks the cumulative estimated byte size of every value read out of
/// `tokio_postgres` rows; the running total feeds the bytes-received
/// instrumentation for a collection pass.
#[derive(Default)]
struct RowReader(usize);
865
866impl RowReader {
867    pub fn read<'a, T: FromSql<'a> + ByteSizeOf>(
868        &mut self,
869        row: &'a Row,
870        key: &'static str,
871    ) -> Result<T, CollectError> {
872        let value = row_get_value::<T>(row, key)?;
873        self.0 += value.size_of();
874        Ok(value)
875    }
876
877    pub const fn into_inner(self) -> usize {
878        self.0
879    }
880}
881
882fn row_get_value<'a, T: FromSql<'a> + ByteSizeOf>(
883    row: &'a Row,
884    key: &'static str,
885) -> Result<T, CollectError> {
886    row.try_get::<&str, T>(key)
887        .map_err(|source| CollectError::PostgresGetValue { source, key })
888}
889
890fn config_to_endpoint(config: &Config) -> String {
891    let mut params: Vec<(&'static str, String)> = vec![];
892
893    // options
894    if let Some(options) = config.get_options() {
895        params.push(("options", options.to_string()));
896    }
897
898    // application_name
899    if let Some(name) = config.get_application_name() {
900        params.push(("application_name", name.to_string()));
901    }
902
903    // ssl_mode, ignore default value (SslMode::Prefer)
904    match config.get_ssl_mode() {
905        SslMode::Disable => params.push(("sslmode", "disable".to_string())),
906        SslMode::Prefer => {} // default, ignore
907        SslMode::Require => params.push(("sslmode", "require".to_string())),
908        // non_exhaustive enum
909        _ => {
910            warn!("Unknown variant of \"SslMode\".");
911        }
912    };
913
914    // host
915    for host in config.get_hosts() {
916        match host {
917            Host::Tcp(host) => params.push(("host", host.to_string())),
918            #[cfg(unix)]
919            Host::Unix(path) => params.push(("host", path.to_string_lossy().to_string())),
920        }
921    }
922
923    // port
924    for port in config.get_ports() {
925        params.push(("port", port.to_string()));
926    }
927
928    // connect_timeout
929    if let Some(connect_timeout) = config.get_connect_timeout() {
930        params.push(("connect_timeout", connect_timeout.as_secs().to_string()));
931    }
932
933    // keepalives, ignore default value (true)
934    if !config.get_keepalives() {
935        params.push(("keepalives", "1".to_owned()));
936    }
937
938    // keepalives_idle, ignore default value (2 * 60 * 60)
939    let keepalives_idle = config.get_keepalives_idle().as_secs();
940    if keepalives_idle != 2 * 60 * 60 {
941        params.push(("keepalives_idle", keepalives_idle.to_string()));
942    }
943
944    // target_session_attrs, ignore default value (TargetSessionAttrs::Any)
945    match config.get_target_session_attrs() {
946        TargetSessionAttrs::Any => {} // default, ignore
947        TargetSessionAttrs::ReadWrite => {
948            params.push(("target_session_attrs", "read-write".to_owned()))
949        }
950        // non_exhaustive enum
951        _ => {
952            warn!("Unknown variant of \"TargetSessionAttr\".");
953        }
954    }
955
956    // channel_binding, ignore default value (ChannelBinding::Prefer)
957    match config.get_channel_binding() {
958        ChannelBinding::Disable => params.push(("channel_binding", "disable".to_owned())),
959        ChannelBinding::Prefer => {} // default, ignore
960        ChannelBinding::Require => params.push(("channel_binding", "require".to_owned())),
961        // non_exhaustive enum
962        _ => {
963            warn!("Unknown variant of \"ChannelBinding\".");
964        }
965    }
966
967    format!(
968        "postgresql:///{}?{}",
969        config.get_dbname().unwrap_or(""),
970        params
971            .into_iter()
972            .map(|(k, v)| format!(
973                "{}={}",
974                percent_encoding(k.as_bytes()),
975                percent_encoding(v.as_bytes())
976            ))
977            .collect::<Vec<String>>()
978            .join("&")
979    )
980}
981
982fn percent_encoding(input: &'_ [u8]) -> String {
983    percent_encoding::percent_encode(input, percent_encoding::NON_ALPHANUMERIC).to_string()
984}
985
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: the source's example configuration must round-trip through
    // the configurable-component machinery without errors.
    #[test]
    fn generate_config() {
        crate::test_util::test_generate_config::<PostgresqlMetricsConfig>();
    }
}
995
//! Integration tests against a live PostgreSQL instance. Only compiled when
//! the `postgresql_metrics-integration-tests` feature is enabled, since they
//! require the docker-based postgres fixture from `test_util::integration`.
#[cfg(all(test, feature = "postgresql_metrics-integration-tests"))]
mod integration_tests {
    use super::*;
    use crate::{
        event::Event,
        test_util::{
            components::{assert_source_compliance, PULL_SOURCE_TAGS},
            integration::postgres::{pg_socket, pg_url},
        },
        tls, SourceSender,
    };

    /// Shared driver: builds the source against `endpoint` with the given
    /// TLS/database-filter options, collects one scrape's worth of metric
    /// events, and runs the assertions common to every test case before
    /// returning the events for test-specific checks.
    async fn test_postgresql_metrics(
        endpoint: String,
        tls: Option<PostgresqlMetricsTlsConfig>,
        include_databases: Option<Vec<String>>,
        exclude_databases: Option<Vec<String>>,
    ) -> Vec<Event> {
        assert_source_compliance(&PULL_SOURCE_TAGS, async move {
            // Derive the expected `endpoint` and `host` tag values from the
            // same parsed config the source itself will use.
            let config: Config = endpoint.parse().unwrap();
            let tags_endpoint = config_to_endpoint(&config);
            let tags_host = match config.get_hosts().first().unwrap() {
                Host::Tcp(host) => host.clone(),
                #[cfg(unix)]
                Host::Unix(path) => path.to_string_lossy().to_string(),
            };

            let (sender, mut recv) = SourceSender::new_test();

            // Run the source in the background; it feeds metrics into `sender`.
            tokio::spawn(async move {
                PostgresqlMetricsConfig {
                    endpoints: vec![endpoint],
                    tls,
                    include_databases,
                    exclude_databases,
                    ..Default::default()
                }
                .build(SourceContext::new_test(sender, None))
                .await
                .unwrap()
                .await
                .unwrap()
            });

            // Wait (bounded) for the first metric, then drain whatever else
            // arrives within a short window — one scrape emits many events.
            let event = time::timeout(time::Duration::from_secs(3), recv.next())
                .await
                .expect("fetch metrics timeout")
                .expect("failed to get metrics from a stream");
            let mut events = vec![event];
            loop {
                match time::timeout(time::Duration::from_millis(10), recv.next()).await {
                    Ok(Some(event)) => events.push(event),
                    Ok(None) => break,
                    Err(_) => break,
                }
            }

            assert!(events.len() > 1);

            // test up metric
            assert_eq!(
                events
                    .iter()
                    .map(|e| e.as_metric())
                    .find(|e| e.name() == "up")
                    .unwrap()
                    .value(),
                &gauge!(1)
            );

            // test namespace and tags
            for event in &events {
                let metric = event.as_metric();

                assert_eq!(metric.namespace(), Some("postgresql"));
                assert_eq!(
                    metric.tags().unwrap().get("endpoint").unwrap(),
                    &tags_endpoint
                );
                assert_eq!(metric.tags().unwrap().get("host").unwrap(), &tags_host);
            }

            // test metrics from different queries — one representative name
            // per collector (pg_stat_database, pg_stat_database_conflicts,
            // pg_stat_bgwriter).
            let names = vec![
                "pg_stat_database_datid",
                "pg_stat_database_conflicts_confl_tablespace_total",
                "pg_stat_bgwriter_checkpoints_timed_total",
            ];
            for name in names {
                assert!(events.iter().any(|e| e.as_metric().name() == name));
            }

            events
        })
        .await
    }

    // Plain TCP connection using the fixture's URL.
    #[tokio::test]
    async fn test_host() {
        test_postgresql_metrics(pg_url(), None, None, None).await;
    }

    // Connection over the local Unix domain socket.
    #[tokio::test]
    async fn test_local() {
        let endpoint = format!(
            "postgresql:///postgres?host={}&user=vector&password=vector",
            pg_socket().to_str().unwrap()
        );
        test_postgresql_metrics(endpoint, None, None, None).await;
    }

    // TLS-required connection, trusting the test CA certificate.
    #[tokio::test]
    async fn test_host_ssl() {
        test_postgresql_metrics(
            format!("{}?sslmode=require", pg_url()),
            Some(PostgresqlMetricsTlsConfig {
                ca_file: tls::TEST_PEM_CA_PATH.into(),
            }),
            None,
            None,
        )
        .await;
    }

    // Include filters are regexes; only matching databases may appear in the
    // `db` tag.
    #[tokio::test]
    async fn test_host_include_databases() {
        let events = test_postgresql_metrics(
            pg_url(),
            None,
            Some(vec!["^vec".to_owned(), "gres$".to_owned()]),
            None,
        )
        .await;

        for event in events {
            let metric = event.into_metric();

            if let Some(db) = metric.tags().unwrap().get("db") {
                assert!(db == "vector" || db == "postgres");
            }
        }
    }

    // Exclude filters are regexes; matching databases must NOT appear in the
    // `db` tag.
    #[tokio::test]
    async fn test_host_exclude_databases() {
        let events = test_postgresql_metrics(
            pg_url(),
            None,
            None,
            Some(vec!["^vec".to_owned(), "gres$".to_owned()]),
        )
        .await;

        for event in events {
            let metric = event.into_metric();

            if let Some(db) = metric.tags().unwrap().get("db") {
                assert!(db != "vector" && db != "postgres");
            }
        }
    }

    // An empty-string exclude pattern matches every name; the source must
    // still start up and emit its own metrics.
    #[tokio::test]
    async fn test_host_exclude_databases_empty() {
        test_postgresql_metrics(pg_url(), None, None, Some(vec!["".to_owned()])).await;
    }

    // Exclude takes precedence within the include set: of template0/template1,
    // only template1 may remain.
    #[tokio::test]
    async fn test_host_include_databases_and_exclude_databases() {
        let events = test_postgresql_metrics(
            pg_url(),
            None,
            Some(vec!["template\\d+".to_owned()]),
            Some(vec!["template0".to_owned()]),
        )
        .await;

        for event in events {
            let metric = event.into_metric();

            if let Some(db) = metric.tags().unwrap().get("db") {
                assert!(db == "template1");
            }
        }
    }
}