1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
use std::convert::Infallible;

use async_graphql::{
    connection::{self, Connection, CursorType, Edge, EmptyFields},
    Result, SimpleObject,
};
use base64::prelude::{Engine as _, BASE64_URL_SAFE_NO_PAD};

/// Reasons a string could not be turned back into a [`Base64Cursor`].
///
/// Every variant renders as the same `Invalid cursor` message through its `Display` impl;
/// `Debug` exposes the detail for logging and tests.
#[derive(Debug)]
pub enum Base64CursorError {
    /// Invalid cursor. The string was valid base64, but the decoded bytes were not UTF-8 or did
    /// not conform to the `name:index` pattern.
    Invalid,
    /// Decoding error. If this happens, the string isn't valid base64.
    DecodeError(base64::DecodeError),
}

impl std::fmt::Display for Base64CursorError {
    /// Formats every variant as the fixed message `Invalid cursor`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Invalid cursor")
    }
}

/// Opaque Relay cursor identifying a zero-based position in a result set.
///
/// When encoded it is the URL-safe, unpadded base64 form of `"<name>:<index>"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Base64Cursor {
    /// Label written ahead of the index when the cursor is encoded.
    name: &'static str,
    /// Zero-based position of the item this cursor points at.
    index: usize,
}

impl Base64Cursor {
    /// Label embedded ahead of the index in every cursor this module creates.
    const NAME: &'static str = "Cursor";

    /// Creates a cursor pointing at the zero-based position `index`.
    const fn new(index: usize) -> Self {
        Self {
            name: Self::NAME,
            index,
        }
    }

    /// Returns the URL-safe, unpadded base64 encoding of `"<name>:<index>"`.
    fn encode(&self) -> String {
        BASE64_URL_SAFE_NO_PAD.encode(format!("{}:{}", self.name, self.index))
    }

    /// Decodes a string produced by [`Base64Cursor::encode`].
    ///
    /// # Errors
    ///
    /// - [`Base64CursorError::DecodeError`] if `s` is not URL-safe, unpadded base64.
    /// - [`Base64CursorError::Invalid`] if the decoded bytes are not UTF-8, do not have the
    ///   `name:index` shape, carry a name other than [`Self::NAME`], or the index does not
    ///   parse as a `usize`.
    fn decode(s: &str) -> Result<Self, Base64CursorError> {
        let bytes = BASE64_URL_SAFE_NO_PAD
            .decode(s)
            .map_err(Base64CursorError::DecodeError)?;

        // Validate the UTF-8 in place; no need to allocate a `String` just to split it.
        let text = std::str::from_utf8(&bytes).map_err(|_| Base64CursorError::Invalid)?;

        // Enforce the `name:index` shape. Splitting on the *first* ':' means payloads such as
        // "5" (no separator) or "Cursor:x:5" (extra segments) are rejected rather than having
        // their trailing number silently accepted.
        let (name, index) = text.split_once(':').ok_or(Base64CursorError::Invalid)?;
        if name != Self::NAME {
            return Err(Base64CursorError::Invalid);
        }
        let index = index
            .parse::<usize>()
            .map_err(|_| Base64CursorError::Invalid)?;

        Ok(Self::new(index))
    }

    /// Returns the position immediately after this cursor. Uses `saturating_add` so a cursor
    /// at `usize::MAX` cannot overflow.
    const fn increment(&self) -> usize {
        self.index.saturating_add(1)
    }
}

impl From<Base64Cursor> for usize {
    fn from(cursor: Base64Cursor) -> Self {
        cursor.index
    }
}

/// Plugs [`Base64Cursor`] into async-graphql's Relay connection support by delegating to the
/// cursor's own `encode`/`decode` helpers.
impl CursorType for Base64Cursor {
    type Error = Base64CursorError;

    fn decode_cursor(s: &str) -> Result<Self, Self::Error> {
        Self::decode(s)
    }

    fn encode_cursor(&self) -> String {
        Base64Cursor::encode(self)
    }
}

/// Additional fields to attach to the connection
// NOTE(review): with `#[derive(SimpleObject)]`, async-graphql is expected to publish the `///`
// comments on this type and its fields as GraphQL schema descriptions, so rewording them changes
// the public schema — confirm before editing. Plain `//` comments like this one are not exported.
#[derive(SimpleObject)]
pub struct ConnectionFields {
    /// Total result set count
    // `query` fills this with the full iterator length, independent of the requested page.
    total_count: usize,
}

/// Relay connection result
pub type ConnectionResult<T> = Result<Connection<Base64Cursor, T, ConnectionFields, EmptyFields>>;

/// Relay-compliant connection parameters to page results by cursor/page size.
///
/// Values are stored as given and handed to [`query`] unvalidated. `Params::default()`
/// means "no paging arguments", so the default page size applies.
#[derive(Debug, Clone, Default)]
pub struct Params {
    /// Encoded cursor; only items strictly after it are returned.
    after: Option<String>,
    /// Encoded cursor; only items strictly before it are returned.
    before: Option<String>,
    /// Maximum number of items to take from the start of the cursor window.
    first: Option<i32>,
    /// Maximum number of items to take from the end of the cursor window.
    last: Option<i32>,
}

impl Params {
    /// Bundles the four Relay paging arguments as-is; no validation happens here.
    pub const fn new(
        after: Option<String>,
        before: Option<String>,
        first: Option<i32>,
        last: Option<i32>,
    ) -> Self {
        Self { first, last, after, before }
    }
}

/// Creates a new Relay-compliant connection over `iter`.
///
/// The iterator must implement [`ExactSizeIterator`] so the size of the whole result set is
/// known up front. That size becomes `total_count`, drives the "has next page" flag, and is the
/// upper bound of the window when no `before` cursor is given.
///
/// Paging rules (cursor bounds are exclusive):
/// - `after` / `before` restrict the window to positions strictly between them. A `before`
///   cursor beyond the end of the result set is clamped to its length.
/// - `first: n` returns up to `n` items from the start of that window.
/// - `last: n` returns up to `n` items from the end of that window.
/// - With neither, up to `default_page_size` items are returned from the start of the window,
///   exactly as if `first: default_page_size` had been passed.
///
/// # Errors
///
/// Argument handling is delegated to `async_graphql::connection::query`, which reports
/// failures such as undecodable cursors as errors. Per the note in the body, it also rejects
/// `first` and `last` given together. The paging logic itself cannot fail (`Infallible`).
pub async fn query<T: async_graphql::OutputType, I: ExactSizeIterator<Item = T>>(
    iter: I,
    p: Params,
    default_page_size: usize,
) -> ConnectionResult<T> {
    connection::query::<_, _, Base64Cursor, _, _, ConnectionFields, _, _, _, Infallible>(
        p.after,
        p.before,
        p.first,
        p.last,
        |after, before, first, last| async move {
            let iter_len = iter.len();

            // Both cursors are exclusive: the window starts just past `after`...
            let lower = after.map_or(0, |a| a.increment());
            // ...and stops just short of `before`. Clamp it so a cursor pointing past the end
            // of the result set behaves like "no before cursor" instead of producing an empty
            // `last` page.
            let upper = before.map_or(iter_len, usize::from).min(iter_len);

            // Calculate start/end based on the provided first/last. Note that async-graphql
            // disallows providing both (returning an error), so we can safely assume we have,
            // at most, one of first or last.
            let (start, end) = match (first, last) {
                (Some(first), _) => (lower, lower.saturating_add(first).min(upper)),
                (_, Some(last)) => (upper.saturating_sub(last).max(lower), upper),
                // The default page is measured from `lower`, not from 0. Otherwise any `after`
                // cursor at or beyond `default_page_size` would yield an empty page.
                (None, None) => (lower, lower.saturating_add(default_page_size).min(upper)),
            };

            let mut connection = Connection::with_additional_fields(
                start > 0,
                end < iter_len,
                ConnectionFields {
                    total_count: iter_len,
                },
            );
            // `start..end` is empty when the cursors cross (start >= end), so no edges are added.
            connection.edges.extend(
                (start..end)
                    .zip(iter.skip(start))
                    .map(|(position, node)| Edge::new(Base64Cursor::new(position), node)),
            );
            Ok(connection)
        },
    )
    .await
}