vector_api_client/
subscription.rs

1use std::{
2    collections::HashMap,
3    pin::Pin,
4    sync::{Arc, Mutex},
5};
6
7use futures::SinkExt;
8use graphql_client::GraphQLQuery;
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11use tokio::sync::{
12    broadcast::{self, Sender},
13    mpsc, oneshot,
14};
15use tokio_stream::{wrappers::BroadcastStream, Stream, StreamExt};
16use tokio_tungstenite::{connect_async, tungstenite::Message};
17use url::Url;
18use uuid::Uuid;
19
20/// Subscription GraphQL response, returned from an active stream.
21pub type BoxedSubscription<T> = Pin<
22    Box<
23        dyn Stream<Item = Option<graphql_client::Response<<T as GraphQLQuery>::ResponseData>>>
24            + Send
25            + Sync,
26    >,
27>;
28
29/// Payload contains the raw data received back from a GraphQL subscription. At the point
30/// of receiving data, the only known fields are { id, type }; what's contained inside the
31/// `payload` field is unknown until we attempt to deserialize it against a generated
32/// GraphQLQuery::ResponseData later.
33#[derive(Serialize, Deserialize, Debug, Clone)]
34pub struct Payload {
35    id: Uuid,
36    #[serde(rename = "type")]
37    payload_type: String,
38    payload: serde_json::Value,
39}
40
41impl Payload {
42    /// Returns a "start" payload necessary for starting a new subscription.
43    pub fn start<T: GraphQLQuery + Send + Sync>(
44        id: Uuid,
45        payload: &graphql_client::QueryBody<T::Variables>,
46    ) -> Self {
47        Self {
48            id,
49            payload_type: "start".to_owned(),
50            payload: json!(payload),
51        }
52    }
53
54    /// Returns a "stop" payload for terminating the subscription in the GraphQL server.
55    fn stop(id: Uuid) -> Self {
56        Self {
57            id,
58            payload_type: "stop".to_owned(),
59            payload: serde_json::Value::Null,
60        }
61    }
62
63    /// Attempts to return a definitive ResponseData on the `payload` field, matched against
64    /// a generated `GraphQLQuery`.
65    fn response<T: GraphQLQuery + Send + Sync>(
66        &self,
67    ) -> Option<graphql_client::Response<T::ResponseData>> {
68        serde_json::from_value::<graphql_client::Response<T::ResponseData>>(self.payload.clone())
69            .ok()
70    }
71}
72
73/// A single `SubscriptionClient` enables subscription multiplexing.
74#[derive(Debug)]
75pub struct SubscriptionClient {
76    tx: mpsc::UnboundedSender<Payload>,
77    subscriptions: Arc<Mutex<HashMap<Uuid, Sender<Payload>>>>,
78    _shutdown_tx: oneshot::Sender<()>,
79}
80
81impl SubscriptionClient {
82    /// Create a new subscription client. `tx` is a channel for sending `Payload`s to the
83    /// GraphQL server; `rx` is a channel for `Payload` back.
84    fn new(tx: mpsc::UnboundedSender<Payload>, mut rx: mpsc::UnboundedReceiver<Payload>) -> Self {
85        // Oneshot channel for cancelling the listener if SubscriptionClient is dropped
86        let (_shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
87
88        let subscriptions = Arc::new(Mutex::new(HashMap::new()));
89        let subscriptions_clone = Arc::clone(&subscriptions);
90
91        // Spawn a handler for shutdown, and relaying received `Payload`s back to the relevant
92        // subscription.
93        let tx_clone = tx.clone();
94        tokio::spawn(async move {
95            loop {
96                tokio::select! {
97                    // Break the loop if shutdown is triggered. This happens implicitly once
98                    // the client goes out of scope
99                    _ = &mut shutdown_rx => {
100                        let subscriptions = subscriptions_clone.lock().unwrap();
101                        for id in subscriptions.keys() {
102                            _ = tx_clone.send(Payload::stop(*id));
103                        }
104                        break
105                    },
106
107                    // Handle receiving payloads back _from_ the server
108                    message = rx.recv() => {
109                        match message {
110                            Some(p) => {
111                                let subscriptions = subscriptions_clone.lock().unwrap();
112                                let s: Option<&Sender<Payload>> = subscriptions.get::<Uuid>(&p.id);
113                                if let Some(s) = s {
114                                    _ = s.send(p);
115                                }
116                            }
117                            None => {
118                                subscriptions_clone.lock().unwrap().clear();
119                                break;
120                            },
121                        }
122                    }
123                }
124            }
125        });
126
127        Self {
128            tx,
129            subscriptions,
130            _shutdown_tx,
131        }
132    }
133
134    /// Start a new subscription request.
135    pub fn start<T>(
136        &self,
137        request_body: &graphql_client::QueryBody<T::Variables>,
138    ) -> BoxedSubscription<T>
139    where
140        T: GraphQLQuery + Send + Sync,
141        <T as GraphQLQuery>::ResponseData: Unpin + Send + Sync + 'static,
142    {
143        // Generate a unique ID for the subscription. Subscriptions can be multiplexed
144        // over a single connection, so we'll keep a copy of this against the client to
145        // handling routing responses back to the relevant subscriber.
146        let id = Uuid::new_v4();
147
148        let (tx, rx) = broadcast::channel::<Payload>(100);
149
150        self.subscriptions.lock().unwrap().insert(id, tx);
151
152        // Send start subscription command with the relevant control messages.
153        _ = self.tx.send(Payload::start::<T>(id, request_body));
154
155        Box::pin(
156            BroadcastStream::new(rx)
157                .filter(Result::is_ok)
158                .map(|p| p.unwrap().response::<T>()),
159        )
160    }
161}
162
163/// Connect to a new WebSocket GraphQL server endpoint, and return a `SubscriptionClient`.
164/// This method will a) connect to a ws(s):// endpoint, and perform the initial handshake, and b)
165/// set up channel forwarding to expose just the returned `Payload`s to the client.
166pub async fn connect_subscription_client(
167    url: Url,
168) -> Result<SubscriptionClient, tokio_tungstenite::tungstenite::Error> {
169    let (ws, _) = connect_async(url).await?;
170    let (mut ws_tx, mut ws_rx) = futures::StreamExt::split(ws);
171
172    let (send_tx, mut send_rx) = mpsc::unbounded_channel::<Payload>();
173    let (recv_tx, recv_rx) = mpsc::unbounded_channel::<Payload>();
174
175    // Initialize the connection
176    _ = ws_tx
177        .send(Message::Text(r#"{"type":"connection_init"}"#.to_string()))
178        .await;
179
180    // Forwarded received messages back upstream to the GraphQL server
181    tokio::spawn(async move {
182        while let Some(p) = send_rx.recv().await {
183            _ = ws_tx
184                .send(Message::Text(serde_json::to_string(&p).unwrap()))
185                .await;
186        }
187    });
188
189    // Forward received messages to the receiver channel.
190    tokio::spawn(async move {
191        while let Some(Ok(Message::Text(m))) = ws_rx.next().await {
192            if let Ok(p) = serde_json::from_str::<Payload>(&m) {
193                _ = recv_tx.send(p);
194            }
195        }
196    });
197
198    Ok(SubscriptionClient::new(send_tx, recv_rx))
199}