vector_api_client/
subscription.rs1use std::{
2 collections::HashMap,
3 pin::Pin,
4 sync::{Arc, Mutex},
5};
6
7use futures::SinkExt;
8use graphql_client::GraphQLQuery;
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11use tokio::sync::{
12 broadcast::{self, Sender},
13 mpsc, oneshot,
14};
15use tokio_stream::{wrappers::BroadcastStream, Stream, StreamExt};
16use tokio_tungstenite::{connect_async, tungstenite::Message};
17use url::Url;
18use uuid::Uuid;
19
20pub type BoxedSubscription<T> = Pin<
22 Box<
23 dyn Stream<Item = Option<graphql_client::Response<<T as GraphQLQuery>::ResponseData>>>
24 + Send
25 + Sync,
26 >,
27>;
28
29#[derive(Serialize, Deserialize, Debug, Clone)]
34pub struct Payload {
35 id: Uuid,
36 #[serde(rename = "type")]
37 payload_type: String,
38 payload: serde_json::Value,
39}
40
41impl Payload {
42 pub fn start<T: GraphQLQuery + Send + Sync>(
44 id: Uuid,
45 payload: &graphql_client::QueryBody<T::Variables>,
46 ) -> Self {
47 Self {
48 id,
49 payload_type: "start".to_owned(),
50 payload: json!(payload),
51 }
52 }
53
54 fn stop(id: Uuid) -> Self {
56 Self {
57 id,
58 payload_type: "stop".to_owned(),
59 payload: serde_json::Value::Null,
60 }
61 }
62
63 fn response<T: GraphQLQuery + Send + Sync>(
66 &self,
67 ) -> Option<graphql_client::Response<T::ResponseData>> {
68 serde_json::from_value::<graphql_client::Response<T::ResponseData>>(self.payload.clone())
69 .ok()
70 }
71}
72
73#[derive(Debug)]
75pub struct SubscriptionClient {
76 tx: mpsc::UnboundedSender<Payload>,
77 subscriptions: Arc<Mutex<HashMap<Uuid, Sender<Payload>>>>,
78 _shutdown_tx: oneshot::Sender<()>,
79}
80
81impl SubscriptionClient {
82 fn new(tx: mpsc::UnboundedSender<Payload>, mut rx: mpsc::UnboundedReceiver<Payload>) -> Self {
85 let (_shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
87
88 let subscriptions = Arc::new(Mutex::new(HashMap::new()));
89 let subscriptions_clone = Arc::clone(&subscriptions);
90
91 let tx_clone = tx.clone();
94 tokio::spawn(async move {
95 loop {
96 tokio::select! {
97 _ = &mut shutdown_rx => {
100 let subscriptions = subscriptions_clone.lock().unwrap();
101 for id in subscriptions.keys() {
102 _ = tx_clone.send(Payload::stop(*id));
103 }
104 break
105 },
106
107 message = rx.recv() => {
109 match message {
110 Some(p) => {
111 let subscriptions = subscriptions_clone.lock().unwrap();
112 let s: Option<&Sender<Payload>> = subscriptions.get::<Uuid>(&p.id);
113 if let Some(s) = s {
114 _ = s.send(p);
115 }
116 }
117 None => {
118 subscriptions_clone.lock().unwrap().clear();
119 break;
120 },
121 }
122 }
123 }
124 }
125 });
126
127 Self {
128 tx,
129 subscriptions,
130 _shutdown_tx,
131 }
132 }
133
134 pub fn start<T>(
136 &self,
137 request_body: &graphql_client::QueryBody<T::Variables>,
138 ) -> BoxedSubscription<T>
139 where
140 T: GraphQLQuery + Send + Sync,
141 <T as GraphQLQuery>::ResponseData: Unpin + Send + Sync + 'static,
142 {
143 let id = Uuid::new_v4();
147
148 let (tx, rx) = broadcast::channel::<Payload>(100);
149
150 self.subscriptions.lock().unwrap().insert(id, tx);
151
152 _ = self.tx.send(Payload::start::<T>(id, request_body));
154
155 Box::pin(
156 BroadcastStream::new(rx)
157 .filter(Result::is_ok)
158 .map(|p| p.unwrap().response::<T>()),
159 )
160 }
161}
162
163pub async fn connect_subscription_client(
167 url: Url,
168) -> Result<SubscriptionClient, tokio_tungstenite::tungstenite::Error> {
169 let (ws, _) = connect_async(url).await?;
170 let (mut ws_tx, mut ws_rx) = futures::StreamExt::split(ws);
171
172 let (send_tx, mut send_rx) = mpsc::unbounded_channel::<Payload>();
173 let (recv_tx, recv_rx) = mpsc::unbounded_channel::<Payload>();
174
175 _ = ws_tx
177 .send(Message::Text(r#"{"type":"connection_init"}"#.to_string()))
178 .await;
179
180 tokio::spawn(async move {
182 while let Some(p) = send_rx.recv().await {
183 _ = ws_tx
184 .send(Message::Text(serde_json::to_string(&p).unwrap()))
185 .await;
186 }
187 });
188
189 tokio::spawn(async move {
191 while let Some(Ok(Message::Text(m))) = ws_rx.next().await {
192 if let Ok(p) = serde_json::from_str::<Payload>(&m) {
193 _ = recv_tx.send(p);
194 }
195 }
196 });
197
198 Ok(SubscriptionClient::new(send_tx, recv_rx))
199}