1use std::{
2 convert::Infallible,
3 net::SocketAddr,
4 sync::{atomic::AtomicBool, Arc},
5};
6
7use async_graphql::{
8 http::{playground_source, GraphQLPlaygroundConfig, WebSocketProtocols},
9 Data, Request, Schema,
10};
11use async_graphql_warp::{graphql_protocol, GraphQLResponse, GraphQLWebSocket};
12use hyper::{server::conn::AddrIncoming, service::make_service_fn, Server as HyperServer};
13use tokio::runtime::Handle;
14use tokio::sync::oneshot;
15use tower::ServiceBuilder;
16use tracing::Span;
17use vector_lib::tap::topology;
18use warp::{filters::BoxedFilter, http::Response, ws::Ws, Filter, Reply};
19
20use super::{handler, schema};
21use crate::{
22 config::{self, api},
23 http::build_http_trace_layer,
24 internal_events::{SocketBindError, SocketMode},
25};
26
27pub struct Server {
28 _shutdown: oneshot::Sender<()>,
29 addr: SocketAddr,
30}
31
32impl Server {
33 pub fn start(
36 config: &config::Config,
37 watch_rx: topology::WatchRx,
38 running: Arc<AtomicBool>,
39 handle: &Handle,
40 ) -> crate::Result<Self> {
41 let routes = make_routes(config.api, watch_rx, running);
42
43 let (_shutdown, rx) = oneshot::channel();
44 let _guard = handle.enter();
46
47 let addr = config.api.address.expect("No socket address");
48 let incoming = AddrIncoming::bind(&addr).inspect_err(|error| {
49 emit!(SocketBindError {
50 mode: SocketMode::Tcp,
51 error,
52 });
53 })?;
54
55 let span = Span::current();
56 let make_svc = make_service_fn(move |_conn| {
57 let svc = ServiceBuilder::new()
58 .layer(build_http_trace_layer(span.clone()))
59 .service(warp::service(routes.clone()));
60 futures_util::future::ok::<_, Infallible>(svc)
61 });
62
63 let server = async move {
64 HyperServer::builder(incoming)
65 .serve(make_svc)
66 .with_graceful_shutdown(async {
67 rx.await.ok();
68 })
69 .await
70 .map_err(|err| {
71 error!("An error occurred: {:?}.", err);
72 })
73 };
74
75 schema::components::update_config(config);
77
78 handle.spawn(server);
80
81 Ok(Self { _shutdown, addr })
82 }
83
84 pub const fn addr(&self) -> SocketAddr {
86 self.addr
87 }
88
89 pub fn update_config(&self, config: &config::Config) {
93 schema::components::update_config(config)
94 }
95}
96
97fn make_routes(
98 api: api::Options,
99 watch_tx: topology::WatchRx,
100 running: Arc<AtomicBool>,
101) -> BoxedFilter<(impl Reply,)> {
102 let health = warp::path("health")
106 .and(with_shared(running))
107 .and_then(handler::health);
108
109 let not_found_graphql = warp::any().and_then(|| async { Err(warp::reject::not_found()) });
111 let not_found = warp::any().and_then(|| async { Err(warp::reject::not_found()) });
112
113 let graphql_subscription_handler =
118 warp::ws()
119 .and(graphql_protocol())
120 .map(move |ws: Ws, protocol: WebSocketProtocols| {
121 let schema = schema::build_schema().finish();
122 let watch_tx = watch_tx.clone();
123
124 let reply = ws.on_upgrade(move |socket| {
125 let mut data = Data::default();
126 data.insert(watch_tx);
127
128 GraphQLWebSocket::new(socket, schema, protocol)
129 .with_data(data)
130 .serve()
131 });
132
133 warp::reply::with_header(
134 reply,
135 "Sec-WebSocket-Protocol",
136 protocol.sec_websocket_protocol(),
137 )
138 });
139
140 let graphql_handler = if api.graphql {
144 warp::path("graphql")
145 .and(graphql_subscription_handler.or(
146 async_graphql_warp::graphql(schema::build_schema().finish()).and_then(
147 |(schema, request): (Schema<_, _, _>, Request)| async move {
148 Ok::<_, Infallible>(GraphQLResponse::from(schema.execute(request).await))
149 },
150 ),
151 ))
152 .boxed()
153 } else {
154 not_found_graphql.boxed()
155 };
156
157 let graphql_playground = if api.playground && api.graphql {
159 warp::path("playground")
160 .map(move || {
161 Response::builder()
162 .header("content-type", "text/html")
163 .body(playground_source(
164 GraphQLPlaygroundConfig::new("/graphql").subscription_endpoint("/graphql"),
165 ))
166 })
167 .boxed()
168 } else {
169 not_found.boxed()
170 };
171
172 health
175 .or(graphql_handler)
176 .or(graphql_playground)
177 .or(not_found)
178 .with(
179 warp::cors()
180 .allow_any_origin()
181 .allow_headers(vec![
182 "User-Agent",
183 "Sec-Fetch-Mode",
184 "Referer",
185 "Origin",
186 "Access-Control-Request-Method",
187 "Access-Control-Allow-Origin",
188 "Access-Control-Request-Headers",
189 "Content-Type",
190 "X-Apollo-Tracing", "Pragma",
192 "Host",
193 "Connection",
194 "Cache-Control",
195 ])
196 .allow_methods(vec!["POST", "GET"]),
197 )
198 .boxed()
199}
200
201fn with_shared(
202 shared: Arc<AtomicBool>,
203) -> impl Filter<Extract = (Arc<AtomicBool>,), Error = Infallible> + Clone {
204 warp::any().map(move || Arc::<AtomicBool>::clone(&shared))
205}