1use std::{
2    convert::Infallible,
3    net::SocketAddr,
4    sync::{Arc, atomic::AtomicBool},
5};
6
7use async_graphql::{
8    Data, Request, Schema,
9    http::{GraphQLPlaygroundConfig, WebSocketProtocols, playground_source},
10};
11use async_graphql_warp::{GraphQLResponse, GraphQLWebSocket, graphql_protocol};
12use hyper::{Server as HyperServer, server::conn::AddrIncoming, service::make_service_fn};
13use tokio::{runtime::Handle, sync::oneshot};
14use tower::ServiceBuilder;
15use tracing::Span;
16use vector_lib::tap::topology;
17use warp::{Filter, Reply, filters::BoxedFilter, http::Response, ws::Ws};
18
19use super::{handler, schema};
20use crate::{
21    config::{self, api},
22    http::build_http_trace_layer,
23    internal_events::{SocketBindError, SocketMode},
24};
25
26pub struct Server {
27    _shutdown: oneshot::Sender<()>,
28    addr: SocketAddr,
29}
30
31impl Server {
32    pub fn start(
35        config: &config::Config,
36        watch_rx: topology::WatchRx,
37        running: Arc<AtomicBool>,
38        handle: &Handle,
39    ) -> crate::Result<Self> {
40        let routes = make_routes(config.api, watch_rx, running);
41
42        let (_shutdown, rx) = oneshot::channel();
43        let _guard = handle.enter();
45
46        let addr = config.api.address.expect("No socket address");
47        let incoming = AddrIncoming::bind(&addr).inspect_err(|error| {
48            emit!(SocketBindError {
49                mode: SocketMode::Tcp,
50                error,
51            });
52        })?;
53
54        let span = Span::current();
55        let make_svc = make_service_fn(move |_conn| {
56            let svc = ServiceBuilder::new()
57                .layer(build_http_trace_layer(span.clone()))
58                .service(warp::service(routes.clone()));
59            futures_util::future::ok::<_, Infallible>(svc)
60        });
61
62        let server = async move {
63            HyperServer::builder(incoming)
64                .serve(make_svc)
65                .with_graceful_shutdown(async {
66                    rx.await.ok();
67                })
68                .await
69                .map_err(|err| {
70                    error!("An error occurred: {:?}.", err);
71                })
72        };
73
74        schema::components::update_config(config);
76
77        handle.spawn(server);
79
80        Ok(Self { _shutdown, addr })
81    }
82
83    pub const fn addr(&self) -> SocketAddr {
85        self.addr
86    }
87
88    pub fn update_config(&self, config: &config::Config) {
92        schema::components::update_config(config)
93    }
94}
95
96fn make_routes(
97    api: api::Options,
98    watch_tx: topology::WatchRx,
99    running: Arc<AtomicBool>,
100) -> BoxedFilter<(impl Reply,)> {
101    let health = warp::path("health")
105        .and(with_shared(running))
106        .and_then(handler::health);
107
108    let not_found_graphql = warp::any().and_then(|| async { Err(warp::reject::not_found()) });
110    let not_found = warp::any().and_then(|| async { Err(warp::reject::not_found()) });
111
112    let graphql_subscription_handler =
117        warp::ws()
118            .and(graphql_protocol())
119            .map(move |ws: Ws, protocol: WebSocketProtocols| {
120                let schema = schema::build_schema().finish();
121                let watch_tx = watch_tx.clone();
122
123                let reply = ws.on_upgrade(move |socket| {
124                    let mut data = Data::default();
125                    data.insert(watch_tx);
126
127                    GraphQLWebSocket::new(socket, schema, protocol)
128                        .with_data(data)
129                        .serve()
130                });
131
132                warp::reply::with_header(
133                    reply,
134                    "Sec-WebSocket-Protocol",
135                    protocol.sec_websocket_protocol(),
136                )
137            });
138
139    let graphql_handler = if api.graphql {
143        warp::path("graphql")
144            .and(graphql_subscription_handler.or(
145                async_graphql_warp::graphql(schema::build_schema().finish()).and_then(
146                    |(schema, request): (Schema<_, _, _>, Request)| async move {
147                        Ok::<_, Infallible>(GraphQLResponse::from(schema.execute(request).await))
148                    },
149                ),
150            ))
151            .boxed()
152    } else {
153        not_found_graphql.boxed()
154    };
155
156    let graphql_playground = if api.playground && api.graphql {
158        warp::path("playground")
159            .map(move || {
160                Response::builder()
161                    .header("content-type", "text/html")
162                    .body(playground_source(
163                        GraphQLPlaygroundConfig::new("/graphql").subscription_endpoint("/graphql"),
164                    ))
165            })
166            .boxed()
167    } else {
168        not_found.boxed()
169    };
170
171    health
174        .or(graphql_handler)
175        .or(graphql_playground)
176        .or(not_found)
177        .with(
178            warp::cors()
179                .allow_any_origin()
180                .allow_headers(vec![
181                    "User-Agent",
182                    "Sec-Fetch-Mode",
183                    "Referer",
184                    "Origin",
185                    "Access-Control-Request-Method",
186                    "Access-Control-Allow-Origin",
187                    "Access-Control-Request-Headers",
188                    "Content-Type",
189                    "X-Apollo-Tracing", "Pragma",
191                    "Host",
192                    "Connection",
193                    "Cache-Control",
194                ])
195                .allow_methods(vec!["POST", "GET"]),
196        )
197        .boxed()
198}
199
200fn with_shared(
201    shared: Arc<AtomicBool>,
202) -> impl Filter<Extract = (Arc<AtomicBool>,), Error = Infallible> + Clone {
203    warp::any().map(move || Arc::<AtomicBool>::clone(&shared))
204}