use std::convert::Infallible;
use async_graphql::{
connection::{self, Connection, CursorType, Edge, EmptyFields},
Result, SimpleObject,
};
use base64::prelude::{Engine as _, BASE64_URL_SAFE_NO_PAD};
/// Reasons a client-supplied cursor string is rejected by [`Base64Cursor`].
#[derive(Debug)]
pub enum Base64CursorError {
    /// The input decoded as base64, but the text inside is not a well-formed cursor.
    Invalid,
    /// The input is not valid base64 for the engine used by the cursor codec.
    DecodeError(base64::DecodeError),
}

impl std::fmt::Display for Base64CursorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant renders the same message; the underlying base64 error
        // (if any) is exposed through `std::error::Error::source` instead.
        f.write_str("Invalid cursor")
    }
}

impl std::error::Error for Base64CursorError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Invalid => None,
            Self::DecodeError(err) => Some(err),
        }
    }
}
/// Opaque pagination cursor identifying an item by its zero-based position.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Base64Cursor {
    // Tag written in front of the index in the encoded form.
    name: &'static str,
    // Zero-based position of the item in the paginated sequence.
    index: usize,
}
impl Base64Cursor {
    /// Tag written before the index; `decode` rejects cursors carrying any other tag.
    const NAME: &'static str = "Cursor";

    /// Creates a cursor pointing at the item with zero-based position `index`.
    const fn new(index: usize) -> Self {
        Self {
            name: Self::NAME,
            index,
        }
    }

    /// Encodes the cursor as `"<name>:<index>"` using the URL-safe, unpadded base64 engine.
    fn encode(&self) -> String {
        BASE64_URL_SAFE_NO_PAD.encode(format!("{}:{}", self.name, self.index))
    }

    /// Parses a cursor previously produced by [`Base64Cursor::encode`].
    ///
    /// # Errors
    ///
    /// * [`Base64CursorError::DecodeError`] if `s` is not valid for the URL-safe,
    ///   unpadded base64 engine.
    /// * [`Base64CursorError::Invalid`] if the decoded bytes are not UTF-8, are not
    ///   of the form `"Cursor:<n>"`, or `<n>` does not parse as a `usize`.
    fn decode(s: &str) -> Result<Self, Base64CursorError> {
        let bytes = BASE64_URL_SAFE_NO_PAD
            .decode(s)
            .map_err(Base64CursorError::DecodeError)?;
        let text = String::from_utf8(bytes).map_err(|_| Base64CursorError::Invalid)?;
        // Require the exact tag written by `encode`, so foreign or hand-crafted
        // strings (e.g. a bare number, or another type's "Name:<n>") are rejected.
        let (_, raw_index) = text
            .split_once(':')
            .filter(|(name, _)| *name == Self::NAME)
            .ok_or(Base64CursorError::Invalid)?;
        let index = raw_index
            .parse::<usize>()
            .map_err(|_| Base64CursorError::Invalid)?;
        Ok(Self::new(index))
    }

    /// Returns the position immediately following this cursor, saturating at `usize::MAX`.
    const fn increment(&self) -> usize {
        self.index.saturating_add(1)
    }
}
impl From<Base64Cursor> for usize {
fn from(cursor: Base64Cursor) -> Self {
cursor.index
}
}
/// Plugs the base64 cursor codec into async-graphql's connection machinery.
impl CursorType for Base64Cursor {
    type Error = Base64CursorError;

    // Thin adapter over the inherent decoder.
    fn decode_cursor(s: &str) -> Result<Self, Self::Error> {
        Self::decode(s)
    }

    // Thin adapter over the inherent encoder.
    fn encode_cursor(&self) -> String {
        Base64Cursor::encode(self)
    }
}
// Extra connection-level fields attached to every connection built by `query`.
// NOTE(review): plain `//` comments are used on purpose — `///` doc comments on a
// `SimpleObject` are likely turned into GraphQL descriptions; confirm before changing.
#[derive(SimpleObject)]
pub struct ConnectionFields {
    // Length of the whole iterator passed to `query`, not just the current page.
    total_count: usize,
}
pub type ConnectionResult<T> = Result<Connection<Base64Cursor, T, ConnectionFields, EmptyFields>>;
/// Raw Relay-style pagination arguments, forwarded unchanged to [`query`].
#[derive(Debug, Clone, Default)]
pub struct Params {
    // Encoded cursor; only items strictly after it are returned.
    after: Option<String>,
    // Encoded cursor; only items strictly before it are returned.
    before: Option<String>,
    // Maximum number of items to take from the front of the window.
    first: Option<i32>,
    // Maximum number of items to take from the back of the window.
    last: Option<i32>,
}
impl Params {
pub const fn new(
after: Option<String>,
before: Option<String>,
first: Option<i32>,
last: Option<i32>,
) -> Self {
Self {
after,
before,
first,
last,
}
}
}
/// Builds one page of a GraphQL connection from `iter` using the arguments in `p`.
///
/// Cursors encode zero-based item positions. `after` and `before` are exclusive
/// bounds. When both `first` and `last` are supplied, `first` takes precedence.
/// When neither is supplied, up to `default_page_size` items starting just after
/// `after` are returned. `total_count` is always the full length of `iter`.
///
/// # Errors
///
/// Returns whatever error [`connection::query`] reports while validating the
/// arguments, such as a cursor rejected by [`Base64Cursor`]'s decoder. The
/// page-building closure itself never fails.
pub async fn query<T: async_graphql::OutputType, I: ExactSizeIterator<Item = T>>(
    iter: I,
    p: Params,
    default_page_size: usize,
) -> ConnectionResult<T> {
    connection::query::<_, _, Base64Cursor, _, _, ConnectionFields, _, _, _, Infallible>(
        p.after,
        p.before,
        p.first,
        p.last,
        |after, before, first, last| async move {
            let iter_len = iter.len();
            let (start, end) = page_bounds(
                // `after` is exclusive, so the page starts one past it.
                after.map(|a| a.increment()),
                before.map(usize::from),
                first,
                last,
                iter_len,
                default_page_size,
            );
            let mut connection = Connection::with_additional_fields(
                start > 0,
                end < iter_len,
                ConnectionFields {
                    total_count: iter_len,
                },
            );
            connection.edges.extend(
                (start..end)
                    .zip(iter.skip(start))
                    .map(|(index, node)| Edge::new(Base64Cursor::new(index), node)),
            );
            Ok(connection)
        },
    )
    .await
}

/// Computes the half-open index range `[start, end)` of the page to return.
///
/// `lower` is the first allowed index (already one past the `after` cursor), and
/// `upper` is the exclusive `before` bound clamped to `len`. If `lower > upper`,
/// the result is an empty range (`start >= end`).
fn page_bounds(
    lower: Option<usize>,
    upper: Option<usize>,
    first: Option<usize>,
    last: Option<usize>,
    len: usize,
    default_page_size: usize,
) -> (usize, usize) {
    let lower = lower.unwrap_or(0);
    // Clamp so that an out-of-range `before` cursor cannot push a `last` window
    // past the end of the data and produce an empty page.
    let upper = upper.map_or(len, |b| b.min(len));
    match (first, last) {
        (Some(first), _) => (lower, lower.saturating_add(first).min(upper)),
        (None, Some(last)) => (upper.saturating_sub(last).max(lower), upper),
        // The default page is measured from `lower`, not from index 0.
        (None, None) => (lower, lower.saturating_add(default_page_size).min(upper)),
    }
}