vector/api/schema/
relay.rs

1use std::convert::Infallible;
2
3use async_graphql::{
4    connection::{self, Connection, CursorType, Edge, EmptyFields},
5    Result, SimpleObject,
6};
7use base64::prelude::{Engine as _, BASE64_URL_SAFE_NO_PAD};
8
9/// Base64 invalid states, used by `Base64Cursor`.
10pub enum Base64CursorError {
11    /// Invalid cursor. This can happen if the base64 string is valid, but its contents don't
12    /// conform to the `name:index` pattern.
13    Invalid,
14    /// Decoding error. If this happens, the string isn't valid base64.
15    DecodeError(base64::DecodeError),
16}
17
18impl std::fmt::Display for Base64CursorError {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        write!(f, "Invalid cursor")
21    }
22}
23
/// Opaque Relay cursor identifying an item by its zero-based position in a result set.
///
/// Serialized as the URL-safe, unpadded base64 encoding of `name:index`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Base64Cursor {
    /// Label written in front of the index in the encoded form.
    name: &'static str,
    /// Zero-based position of the item this cursor points at.
    index: usize,
}
29
30impl Base64Cursor {
31    const fn new(index: usize) -> Self {
32        Self {
33            name: "Cursor",
34            index,
35        }
36    }
37
38    /// Returns a base64 string representation of the cursor
39    fn encode(&self) -> String {
40        BASE64_URL_SAFE_NO_PAD.encode(format!("{}:{}", self.name, self.index))
41    }
42
43    /// Decodes a base64 string into a cursor result
44    fn decode(s: &str) -> Result<Self, Base64CursorError> {
45        let bytes = BASE64_URL_SAFE_NO_PAD
46            .decode(s)
47            .map_err(Base64CursorError::DecodeError)?;
48
49        let cursor = String::from_utf8(bytes).map_err(|_| Base64CursorError::Invalid)?;
50        let index = cursor
51            .split(':')
52            .next_back()
53            .map(|s| s.parse::<usize>())
54            .ok_or(Base64CursorError::Invalid)?
55            .map_err(|_| Base64CursorError::Invalid)?;
56
57        Ok(Self::new(index))
58    }
59
60    /// Increment and return the index. Uses saturating_add to avoid overflow
61    /// issues.
62    const fn increment(&self) -> usize {
63        self.index.saturating_add(1)
64    }
65}
66
67impl From<Base64Cursor> for usize {
68    fn from(cursor: Base64Cursor) -> Self {
69        cursor.index
70    }
71}
72
73/// Makes the `Base64Cursor` compatible with Relay connections
74impl CursorType for Base64Cursor {
75    type Error = Base64CursorError;
76
77    fn decode_cursor(s: &str) -> Result<Self, Self::Error> {
78        Base64Cursor::decode(s)
79    }
80
81    fn encode_cursor(&self) -> String {
82        self.encode()
83    }
84}
85
// Connection-level data returned alongside the page of edges.
// NOTE(review): async-graphql derives such as `SimpleObject` appear to use `///` doc comments
// as GraphQL descriptions — confirm before editing the `///` lines below, since that would
// change the published schema.
/// Additional fields to attach to the connection
#[derive(SimpleObject)]
pub struct ConnectionFields {
    // Set by `query` to the length of the iterator being paged (not the page size).
    /// Total result set count
    total_count: usize,
}

/// Result type returned by `query`: a Relay connection of `T` nodes addressed by
/// `Base64Cursor`, carrying `ConnectionFields` at the connection level and `EmptyFields` in the
/// final type position (presumably no extra per-edge fields), wrapped in `async_graphql::Result`.
pub type ConnectionResult<T> = Result<Connection<Base64Cursor, T, ConnectionFields, EmptyFields>>;
95
/// Relay-compliant connection parameters to page results by cursor/page size.
///
/// The four values are forwarded unchanged to `async_graphql::connection::query` by `query`;
/// `Params::default()` represents a request with no paging arguments.
#[derive(Debug, Clone, Default)]
pub struct Params {
    /// Encoded cursor; only items after it are returned.
    after: Option<String>,
    /// Encoded cursor; only items before it are returned.
    before: Option<String>,
    /// Maximum number of items to take from the start of the cursor window.
    first: Option<i32>,
    /// Maximum number of items to take from the end of the cursor window.
    last: Option<i32>,
}

impl Params {
    /// Bundles the Relay `after` / `before` / `first` / `last` arguments.
    pub const fn new(
        after: Option<String>,
        before: Option<String>,
        first: Option<i32>,
        last: Option<i32>,
    ) -> Self {
        Self {
            after,
            before,
            first,
            last,
        }
    }
}
119
120/// Creates a new Relay-compliant connection. Iterator must implement `ExactSizeIterator` to
121/// determine page position in the total result set.
122pub async fn query<T: async_graphql::OutputType, I: ExactSizeIterator<Item = T>>(
123    iter: I,
124    p: Params,
125    default_page_size: usize,
126) -> ConnectionResult<T> {
127    connection::query::<_, _, Base64Cursor, _, _, ConnectionFields, _, _, _, Infallible>(
128        p.after,
129        p.before,
130        p.first,
131        p.last,
132        |after, before, first, last| async move {
133            let iter_len = iter.len();
134
135            let (start, end) = {
136                let after = after.map(|a| a.increment()).unwrap_or(0);
137                let before: usize = before.map(|b| b.into()).unwrap_or(iter_len);
138
139                // Calculate start/end based on the provided first/last. Note that async-graphql disallows
140                // providing both (returning an error), so we can safely assume we have, at most, one of
141                // first or last.
142                match (first, last) {
143                    // First
144                    (Some(first), _) => (after, (after.saturating_add(first)).min(before)),
145                    // Last
146                    (_, Some(last)) => ((before.saturating_sub(last)).max(after), before),
147                    // Default page size
148                    _ => (after, default_page_size.min(before)),
149                }
150            };
151
152            let mut connection = Connection::with_additional_fields(
153                start > 0,
154                end < iter_len,
155                ConnectionFields {
156                    total_count: iter_len,
157                },
158            );
159            connection.edges.extend(
160                (start..end)
161                    .zip(iter.skip(start))
162                    .map(|(cursor, node)| Edge::new(Base64Cursor::new(cursor), node)),
163            );
164            Ok(connection)
165        },
166    )
167    .await
168}