vector/sources/util/grpc/
decompression.rs

1use std::{
2    cmp,
3    io::Write,
4    mem,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use bytes::{Buf, BufMut, BytesMut};
10use flate2::write::GzDecoder;
11use futures_util::FutureExt;
12use http::{Request, Response};
13use hyper::{
14    body::{HttpBody, Sender},
15    Body,
16};
17use std::future::Future;
18use tokio::{pin, select};
19use tonic::{body::BoxBody, metadata::AsciiMetadataValue, Status};
20use tower::{Layer, Service};
21use vector_lib::internal_event::{
22    ByteSize, BytesReceived, InternalEventHandle as _, Protocol, Registered,
23};
24
25use crate::internal_events::{GrpcError, GrpcInvalidCompressionSchemeError};
26
// Every gRPC message is framed by a five byte header:
// - a compressed flag (u8): 0 when the payload is uncompressed, 1 when it is compressed
// - a length prefix (big-endian u32) giving the number of payload bytes that follow the header
const GRPC_MESSAGE_HEADER_LEN: usize = mem::size_of::<u8>() + mem::size_of::<u32>();
// Request header naming the compression scheme applied to the request's messages.
const GRPC_ENCODING_HEADER: &str = "grpc-encoding";
// Header used to advertise the compression schemes we accept.
const GRPC_ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
33
34enum CompressionScheme {
35    Gzip,
36}
37
38impl CompressionScheme {
39    fn from_encoding_header(req: &Request<Body>) -> Result<Option<Self>, Status> {
40        req.headers()
41            .get(GRPC_ENCODING_HEADER)
42            .map(|s| {
43                s.to_str().map(|s| s.to_string()).map_err(|_| {
44                    Status::unimplemented(format!(
45                        "`{GRPC_ENCODING_HEADER}` contains non-visible characters and is not a valid encoding"
46                    ))
47                })
48            })
49            .transpose()
50            .and_then(|value| match value {
51                None => Ok(None),
52                Some(scheme) => match scheme.as_str() {
53                    "gzip" => Ok(Some(CompressionScheme::Gzip)),
54                    other => Err(Status::unimplemented(format!(
55                        "compression scheme `{other}` is not supported"
56                    ))),
57                },
58            })
59            .map_err(|mut status| {
60                status.metadata_mut().insert(
61                    GRPC_ACCEPT_ENCODING_HEADER,
62                    AsciiMetadataValue::from_static("gzip,identity"),
63                );
64                status
65            })
66    }
67}
68
/// Where the framing decoder currently is within the gRPC message being read.
enum State {
    /// Not enough bytes have been buffered yet to read the five byte message header.
    WaitingForHeader,
    /// Uncompressed message: pass it through once `overall_len` bytes (header plus payload) are buffered.
    Forward { overall_len: usize },
    /// Compressed message: `remaining` payload bytes still need to be fed to the decompressor.
    Decompress { remaining: usize },
}

impl Default for State {
    /// A fresh decoder begins by expecting a message header.
    fn default() -> State {
        State::WaitingForHeader
    }
}
80
81fn new_decompressor() -> GzDecoder<Vec<u8>> {
82    // Create the backing buffer for the decompressor and set the compression flag to false (0) and pre-allocate
83    // the space for the length prefix, which we'll fill out once we've finalized the decompressor.
84    let buf = vec![0; GRPC_MESSAGE_HEADER_LEN];
85
86    GzDecoder::new(buf)
87}
88
89async fn drive_body_decompression(
90    mut source: Body,
91    mut destination: Sender,
92) -> Result<usize, Status> {
93    let mut state = State::default();
94    let mut buf = BytesMut::new();
95    let mut decompressor = None;
96    let mut bytes_received = 0;
97
98    // Drain all message chunks from the body first.
99    while let Some(result) = source.data().await {
100        let chunk = result.map_err(|_| Status::internal("failed to read from underlying body"))?;
101        buf.put(chunk);
102
103        let maybe_message = loop {
104            match state {
105                State::WaitingForHeader => {
106                    // If we don't have enough data yet to even read the gRPC message header, we can't do anything yet.
107                    if buf.len() < GRPC_MESSAGE_HEADER_LEN {
108                        break None;
109                    }
110
111                    // Extract the compressed flag and length prefix.
112                    let (is_compressed, message_len) = {
113                        let header = &buf[..GRPC_MESSAGE_HEADER_LEN];
114
115                        let message_len_raw: u32 = header[1..]
116                            .try_into()
117                            .map(u32::from_be_bytes)
118                            .expect("there must be four bytes remaining in the header slice");
119                        let message_len = message_len_raw
120                            .try_into()
121                            .expect("Vector does not support 16-bit platforms");
122
123                        (header[0] == 1, message_len)
124                    };
125
126                    // Now, if the message is not compressed, then put ourselves into forward mode, where we'll wait for
127                    // the rest of the message to come in -- decoding isn't streaming so there's no benefit there --
128                    // before we emit it.
129                    //
130                    // If the message _is_ compressed, we do roughly the same thing but we shove it into the
131                    // decompressor incrementally because there's no good reason to make both the internal buffer and
132                    // the decompressor buffer expand if we don't have to.
133                    if is_compressed {
134                        // We skip the header in the buffer because it doesn't matter to the decompressor and we
135                        // recreate it anyways.
136                        buf.advance(GRPC_MESSAGE_HEADER_LEN);
137
138                        state = State::Decompress {
139                            remaining: message_len,
140                        };
141                    } else {
142                        let overall_len = GRPC_MESSAGE_HEADER_LEN + message_len;
143                        state = State::Forward { overall_len };
144                    }
145                }
146                State::Forward { overall_len } => {
147                    // All we're doing at this point is waiting until we have all the bytes for the current gRPC message
148                    // before we emit them to the caller.
149                    if buf.len() < overall_len {
150                        break None;
151                    }
152
153                    // Now that we have all the bytes we need, slice them out of our internal buffer, reset our state,
154                    // and hand the message back to the caller.
155                    let message = buf.split_to(overall_len).freeze();
156                    state = State::WaitingForHeader;
157
158                    bytes_received += overall_len;
159
160                    break Some(message);
161                }
162                State::Decompress { ref mut remaining } => {
163                    if *remaining > 0 {
164                        // We're waiting for `remaining` more bytes to feed to the decompressor before we finalize it and
165                        // generate our new chunk of data. We might have data in our internal buffer, so try and drain that
166                        // first before polling the underlying body for more.
167                        let available = buf.len();
168                        if available > 0 {
169                            // Write the lesser of what the buffer has, or what is remaining for the current message, into
170                            // the decompressor. This is _technically_ synchronous but there's really no way to do it
171                            // asynchronously since we already have the data, and that's the only asynchronous part.
172                            let to_take = cmp::min(available, *remaining);
173                            let decompressor = decompressor.get_or_insert_with(new_decompressor);
174                            if decompressor.write_all(&buf[..to_take]).is_err() {
175                                return Err(Status::internal("failed to write to decompressor"));
176                            }
177
178                            *remaining -= to_take;
179                            buf.advance(to_take);
180                        } else {
181                            break None;
182                        }
183                    } else {
184                        // We don't need any more data, so consume the decompressor, finalize it by updating the length
185                        // prefix, and then pass it back to the caller.
186                        let result = decompressor
187                            .take()
188                            .expect("consumed decompressor when no decompressor was present")
189                            .finish();
190
191                        // The only I/O errors that occur during `finish` should be I/O errors from writing to the internal
192                        // buffer, but `Vec<T>` is infallible in this regard, so this should be impossible without having
193                        // first panicked due to memory exhaustion.
194                        let mut buf = result.map_err(|_| {
195                            Status::internal(
196                                "reached impossible error during decompressor finalization",
197                            )
198                        })?;
199                        bytes_received += buf.len();
200
201                        // Write the length of our decompressed message in the pre-allocated slot for the message's length prefix.
202                        let message_len_actual = buf.len() - GRPC_MESSAGE_HEADER_LEN;
203                        let message_len = u32::try_from(message_len_actual).map_err(|_| {
204                            Status::out_of_range("messages greater than 4GB are not supported")
205                        })?;
206
207                        let message_len_bytes = message_len.to_be_bytes();
208                        let message_len_slot = &mut buf[1..GRPC_MESSAGE_HEADER_LEN];
209                        message_len_slot.copy_from_slice(&message_len_bytes[..]);
210
211                        // Reset our state before returning the decompressed message.
212                        state = State::WaitingForHeader;
213
214                        break Some(buf.into());
215                    }
216                }
217            }
218        };
219
220        if let Some(message) = maybe_message {
221            // We got a decompressed (or passthrough) message chunk, so just forward it to the destination.
222            if destination.send_data(message).await.is_err() {
223                return Err(Status::internal("destination body abnormally closed"));
224            }
225        }
226    }
227
228    // When we've exhausted all the message chunks, we try sending any trailers that came in on the underlying body.
229    let result = source.trailers().await;
230    let maybe_trailers =
231        result.map_err(|_| Status::internal("error reading trailers from underlying body"))?;
232    if let Some(trailers) = maybe_trailers {
233        if destination.send_trailers(trailers).await.is_err() {
234            return Err(Status::internal("destination body abnormally closed"));
235        }
236    }
237
238    Ok(bytes_received)
239}
240
241async fn drive_request<F, E>(
242    source: Body,
243    destination: Sender,
244    inner: F,
245    bytes_received: Registered<BytesReceived>,
246) -> Result<Response<BoxBody>, E>
247where
248    F: Future<Output = Result<Response<BoxBody>, E>>,
249    E: std::fmt::Display,
250{
251    let body_decompression = drive_body_decompression(source, destination);
252
253    pin!(inner);
254    pin!(body_decompression);
255
256    let mut body_eof = false;
257    let mut body_bytes_received = 0;
258
259    let result = loop {
260        select! {
261            biased;
262
263            // Drive the inner future, as this will be consuming the message chunks we give it.
264            result = &mut inner => break result,
265
266            // Drive the core decompression loop, reading chunks from the underlying body, decompressing them if needed,
267            // and eventually handling trailers at the end, if they're present.
268            result = &mut body_decompression, if !body_eof => match result {
269                Err(e) => break Ok(e.to_http()),
270                Ok(bytes_received) => {
271                    body_bytes_received = bytes_received;
272                    body_eof = true;
273                },
274            }
275        }
276    };
277
278    // If the response indicates success, then emit the necessary metrics
279    // otherwise emit the error.
280    match &result {
281        Ok(res) if res.status().is_success() => {
282            bytes_received.emit(ByteSize(body_bytes_received));
283        }
284        Ok(res) => {
285            emit!(GrpcError {
286                error: format!("Received {}", res.status())
287            });
288        }
289        Err(error) => {
290            emit!(GrpcError { error: &error });
291        }
292    };
293
294    result
295}
296
297#[derive(Clone)]
298pub struct DecompressionAndMetrics<S> {
299    inner: S,
300    bytes_received: Registered<BytesReceived>,
301}
302
303impl<S> Service<Request<Body>> for DecompressionAndMetrics<S>
304where
305    S: Service<Request<Body>, Response = Response<BoxBody>> + Clone + Send + 'static,
306    S::Future: Send + 'static,
307    S::Error: std::fmt::Display,
308{
309    type Response = Response<BoxBody>;
310    type Error = S::Error;
311    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
312
313    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
314        self.inner.poll_ready(cx)
315    }
316
317    fn call(&mut self, req: Request<Body>) -> Self::Future {
318        match CompressionScheme::from_encoding_header(&req) {
319            // There was a header for the encoding, but it was either invalid data or a scheme we don't support.
320            Err(status) => {
321                emit!(GrpcInvalidCompressionSchemeError { status: &status });
322                Box::pin(async move { Ok(status.to_http()) })
323            }
324
325            // The request either isn't using compression, or it has indicated compression may be used and we know we
326            // can support decompression based on the indicated compression scheme... so wrap the body to decompress, if
327            // need be, and then track the bytes that flowed through.
328            //
329            // TODO: Actually use the scheme given back to us to support other compression schemes.
330            Ok(_) => {
331                let (destination, decompressed_body) = Body::channel();
332                let (req_parts, req_body) = req.into_parts();
333                let mapped_req = Request::from_parts(req_parts, decompressed_body);
334
335                let inner = self.inner.call(mapped_req);
336
337                drive_request(req_body, destination, inner, self.bytes_received.clone()).boxed()
338            }
339        }
340    }
341}
342
343/// A layer for decompressing Tonic request payloads and emitting telemetry for the payload sizes.
344///
345/// In some cases, we configure `tonic` to use compression on requests to save CPU and throughput when sending those
346/// large requests. In the case of Vector-to-Vector communication, this means the Vector v2 source may deal with
347/// compressed requests. The code already transparently handles decompression, but as part of our component
348/// specification, we have specific goals around what event representations we pay attention to.
349///
350/// In the case of tracking bytes sent/received, we always want to track the number of bytes received _after_
351/// decompression to faithfully represent the amount of data being processed by Vector. This poses a problem with the
352/// out-of-the-box `tonic` codegen as there is no hook whatsoever to inspect the raw request payload (after
353/// decompression, if it was compressed at all) prior to the payload being decoded as a Protocol Buffers payload.
354///
355/// This layer wraps the incoming body in our own body type, which allows us to do two things: decompress the payload
356/// before it enters the decoding phase, and emit metrics based on the decompressed payload.
357///
358/// Since we can see the decompressed bytes, and also know if the underlying service responded successfully -- i.e. the
359/// request was valid, and was processed -- we can now report the number of bytes (after decompression) that were
360/// received _and_ processed correctly.
361///
362/// The only supported compression scheme is gzip, which is also the only supported compression scheme in `tonic` itself.
363#[derive(Clone, Default)]
364pub struct DecompressionAndMetricsLayer;
365
366impl<S> Layer<S> for DecompressionAndMetricsLayer {
367    type Service = DecompressionAndMetrics<S>;
368
369    fn layer(&self, inner: S) -> Self::Service {
370        DecompressionAndMetrics {
371            inner,
372            bytes_received: register!(BytesReceived::from(Protocol::from("grpc"))),
373        }
374    }
375}