vector_core/tls/
incoming.rs

1use ipnet::IpNet;
2use std::{
3    collections::HashMap,
4    future::Future,
5    net::SocketAddr,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9};
10
11use futures::{future::BoxFuture, stream, FutureExt, Stream};
12use openssl::ssl::{Ssl, SslAcceptor, SslMethod};
13use openssl::x509::X509;
14use snafu::ResultExt;
15use tokio::sync::{OwnedSemaphorePermit, Semaphore};
16use tokio::{
17    io::{self, AsyncRead, AsyncWrite, ReadBuf},
18    net::{TcpListener, TcpStream},
19};
20use tokio_openssl::SslStream;
21use tonic::transport::{server::Connected, Certificate};
22
23use super::{
24    CreateAcceptorSnafu, HandshakeSnafu, IncomingListenerSnafu, MaybeTlsSettings, MaybeTlsStream,
25    SslBuildSnafu, TcpBindSnafu, TlsError, TlsSettings,
26};
27use crate::tcp::{self, TcpKeepaliveConfig};
28
29impl TlsSettings {
30    pub fn acceptor(&self) -> crate::tls::Result<SslAcceptor> {
31        if self.identity.is_none() {
32            Err(TlsError::MissingRequiredIdentity)
33        } else {
34            let mut acceptor =
35                SslAcceptor::mozilla_intermediate(SslMethod::tls()).context(CreateAcceptorSnafu)?;
36            self.apply_context_base(&mut acceptor, true)?;
37            Ok(acceptor.build())
38        }
39    }
40}
41
42impl MaybeTlsSettings {
43    pub async fn bind(&self, addr: &SocketAddr) -> crate::tls::Result<MaybeTlsListener> {
44        let listener = TcpListener::bind(addr).await.context(TcpBindSnafu)?;
45
46        let acceptor = match self {
47            Self::Tls(tls) => Some(tls.acceptor()?),
48            Self::Raw(()) => None,
49        };
50
51        Ok(MaybeTlsListener {
52            listener,
53            acceptor,
54            origin_filter: None,
55        })
56    }
57
58    pub async fn bind_with_allowlist(
59        &self,
60        addr: &SocketAddr,
61        allow_origin: Vec<IpNet>,
62    ) -> crate::tls::Result<MaybeTlsListener> {
63        let listener = TcpListener::bind(addr).await.context(TcpBindSnafu)?;
64
65        let acceptor = match self {
66            Self::Tls(tls) => Some(tls.acceptor()?),
67            Self::Raw(()) => None,
68        };
69
70        Ok(MaybeTlsListener {
71            listener,
72            acceptor,
73            origin_filter: Some(allow_origin),
74        })
75    }
76}
77
78pub struct MaybeTlsListener {
79    listener: TcpListener,
80    acceptor: Option<SslAcceptor>,
81    origin_filter: Option<Vec<IpNet>>,
82}
83
84impl MaybeTlsListener {
85    pub async fn accept(&mut self) -> crate::tls::Result<MaybeTlsIncomingStream<TcpStream>> {
86        let listener = self
87            .listener
88            .accept()
89            .await
90            .map(|(stream, peer_addr)| {
91                MaybeTlsIncomingStream::new(stream, peer_addr, self.acceptor.clone())
92            })
93            .context(IncomingListenerSnafu)?;
94
95        if let Some(origin_filter) = &self.origin_filter {
96            if origin_filter
97                .iter()
98                .any(|net| net.contains(&listener.peer_addr().ip()))
99            {
100                Ok(listener)
101            } else {
102                Err(TlsError::Connect {
103                    source: std::io::ErrorKind::ConnectionRefused.into(),
104                })
105            }
106        } else {
107            Ok(listener)
108        }
109    }
110
111    async fn into_accept(
112        mut self,
113    ) -> (crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>, Self) {
114        (self.accept().await, self)
115    }
116
117    pub fn accept_stream(
118        self,
119    ) -> impl Stream<Item = crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>> {
120        let mut accept = Box::pin(self.into_accept());
121        stream::poll_fn(move |context| match accept.as_mut().poll(context) {
122            Poll::Ready((item, this)) => {
123                accept.set(this.into_accept());
124                Poll::Ready(Some(item))
125            }
126            Poll::Pending => Poll::Pending,
127        })
128    }
129
130    pub fn accept_stream_limited(
131        self,
132        max_connections: Option<u32>,
133    ) -> impl Stream<
134        Item = (
135            crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>,
136            Option<OwnedSemaphorePermit>,
137        ),
138    > {
139        let mut connection_semaphore_future = max_connections.map(|max| {
140            let semaphore = Arc::new(Semaphore::new(max as usize));
141            let future = Box::pin(semaphore.clone().acquire_owned());
142            (semaphore, future)
143        });
144
145        let mut accept = Box::pin(self.into_accept());
146        stream::poll_fn(move |context| {
147            let permit = match connection_semaphore_future.as_mut() {
148                Some((semaphore, future)) => match future.as_mut().poll(context) {
149                    Poll::Ready(permit) => {
150                        future.set(semaphore.clone().acquire_owned());
151                        permit.ok()
152                    }
153                    Poll::Pending => return Poll::Pending,
154                },
155                None => None,
156            };
157            match accept.as_mut().poll(context) {
158                Poll::Ready((item, this)) => {
159                    accept.set(this.into_accept());
160                    Poll::Ready(Some((item, permit)))
161                }
162                Poll::Pending => Poll::Pending,
163            }
164        })
165    }
166
167    pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
168        self.listener.local_addr()
169    }
170
171    #[must_use]
172    pub fn with_allowlist(mut self, allowlist: Option<Vec<IpNet>>) -> Self {
173        self.origin_filter = allowlist;
174        self
175    }
176}
177
178impl From<TcpListener> for MaybeTlsListener {
179    fn from(listener: TcpListener) -> Self {
180        Self {
181            listener,
182            acceptor: None,
183            origin_filter: None,
184        }
185    }
186}
187
188pub struct MaybeTlsIncomingStream<S> {
189    state: StreamState<S>,
190    // BoxFuture doesn't allow access to the inner stream, but users
191    // of MaybeTlsIncomingStream want access to the peer address while
192    // still handshaking, so we have to cache it here.
193    peer_addr: SocketAddr,
194}
195
196enum StreamState<S> {
197    Accepted(MaybeTlsStream<S>),
198    Accepting(BoxFuture<'static, Result<SslStream<S>, TlsError>>),
199    AcceptError(String),
200    Closed,
201}
202
203impl<S> MaybeTlsIncomingStream<S> {
204    pub const fn peer_addr(&self) -> SocketAddr {
205        self.peer_addr
206    }
207
208    /// None if connection still hasn't been established.
209    pub fn get_ref(&self) -> Option<&S> {
210        use super::MaybeTls;
211
212        match &self.state {
213            StreamState::Accepted(stream) => Some(match stream {
214                MaybeTls::Raw(s) => s,
215                MaybeTls::Tls(s) => s.get_ref(),
216            }),
217            StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
218        }
219    }
220
221    pub const fn ssl_stream(&self) -> Option<&SslStream<S>> {
222        use super::MaybeTls;
223
224        match &self.state {
225            StreamState::Accepted(stream) => match stream {
226                MaybeTls::Raw(_) => None,
227                MaybeTls::Tls(s) => Some(s),
228            },
229            StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
230        }
231    }
232
233    pub fn get_mut(&mut self) -> Option<&mut S> {
234        use super::MaybeTls;
235
236        match &mut self.state {
237            StreamState::Accepted(ref mut stream) => Some(match stream {
238                MaybeTls::Raw(ref mut s) => s,
239                MaybeTls::Tls(s) => s.get_mut(),
240            }),
241            StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
242        }
243    }
244}
245
246impl MaybeTlsIncomingStream<TcpStream> {
247    pub(super) fn new(
248        stream: TcpStream,
249        peer_addr: SocketAddr,
250        acceptor: Option<SslAcceptor>,
251    ) -> Self {
252        let state = match acceptor {
253            Some(acceptor) => StreamState::Accepting(
254                async move {
255                    let ssl = Ssl::new(acceptor.context()).context(SslBuildSnafu)?;
256                    let mut stream = SslStream::new(ssl, stream).context(SslBuildSnafu)?;
257                    Pin::new(&mut stream)
258                        .accept()
259                        .await
260                        .context(HandshakeSnafu)?;
261                    Ok(stream)
262                }
263                .boxed(),
264            ),
265            None => StreamState::Accepted(MaybeTlsStream::Raw(stream)),
266        };
267        Self { state, peer_addr }
268    }
269
270    // Explicit handshake method
271    pub async fn handshake(&mut self) -> crate::tls::Result<()> {
272        if let StreamState::Accepting(fut) = &mut self.state {
273            let stream = fut.await?;
274            self.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
275        }
276
277        Ok(())
278    }
279
280    pub fn set_keepalive(&mut self, keepalive: TcpKeepaliveConfig) -> io::Result<()> {
281        let stream = self.get_ref().ok_or_else(|| {
282            io::Error::new(
283                io::ErrorKind::NotConnected,
284                "Can't set keepalive on connection that has not been accepted yet.",
285            )
286        })?;
287
288        if let Some(time_secs) = keepalive.time_secs {
289            let config =
290                socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(time_secs));
291
292            tcp::set_keepalive(stream, &config)?;
293        }
294
295        Ok(())
296    }
297
298    pub fn set_receive_buffer_bytes(&mut self, bytes: usize) -> std::io::Result<()> {
299        let stream = self.get_ref().ok_or_else(|| {
300            io::Error::new(
301                io::ErrorKind::NotConnected,
302                "Can't set receive buffer size on connection that has not been accepted yet.",
303            )
304        })?;
305
306        tcp::set_receive_buffer_size(stream, bytes)
307    }
308
309    fn poll_io<T, F>(self: Pin<&mut Self>, cx: &mut Context, poll_fn: F) -> Poll<io::Result<T>>
310    where
311        F: FnOnce(Pin<&mut MaybeTlsStream<TcpStream>>, &mut Context) -> Poll<io::Result<T>>,
312    {
313        let this = self.get_mut();
314        loop {
315            return match &mut this.state {
316                StreamState::Accepted(stream) => poll_fn(Pin::new(stream), cx),
317                StreamState::Accepting(fut) => match std::task::ready!(fut.as_mut().poll(cx)) {
318                    Ok(stream) => {
319                        this.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
320                        continue;
321                    }
322                    Err(error) => {
323                        let error = io::Error::other(error);
324                        this.state = StreamState::AcceptError(error.to_string());
325                        Poll::Ready(Err(error))
326                    }
327                },
328                StreamState::AcceptError(error) => {
329                    Poll::Ready(Err(io::Error::other(error.clone())))
330                }
331                StreamState::Closed => Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())),
332            };
333        }
334    }
335}
336
337impl AsyncRead for MaybeTlsIncomingStream<TcpStream> {
338    fn poll_read(
339        self: Pin<&mut Self>,
340        cx: &mut Context,
341        buf: &mut ReadBuf<'_>,
342    ) -> Poll<io::Result<()>> {
343        self.poll_io(cx, |s, cx| s.poll_read(cx, buf))
344    }
345}
346
347impl AsyncWrite for MaybeTlsIncomingStream<TcpStream> {
348    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
349        self.poll_io(cx, |s, cx| s.poll_write(cx, buf))
350    }
351
352    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
353        self.poll_io(cx, AsyncWrite::poll_flush)
354    }
355
356    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
357        let this = self.get_mut();
358        match &mut this.state {
359            StreamState::Accepted(stream) => match Pin::new(stream).poll_shutdown(cx) {
360                Poll::Ready(Ok(())) => {
361                    this.state = StreamState::Closed;
362                    Poll::Ready(Ok(()))
363                }
364                poll_result => poll_result,
365            },
366            StreamState::Accepting(fut) => match std::task::ready!(fut.as_mut().poll(cx)) {
367                Ok(stream) => {
368                    this.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
369                    Poll::Pending
370                }
371                Err(error) => {
372                    let error = io::Error::other(error);
373                    this.state = StreamState::AcceptError(error.to_string());
374                    Poll::Ready(Err(error))
375                }
376            },
377            StreamState::AcceptError(error) => Poll::Ready(Err(io::Error::other(error.clone()))),
378            StreamState::Closed => Poll::Ready(Ok(())),
379        }
380    }
381}
382
/// Selected subject name components of an X.509 certificate.
#[derive(Debug)]
pub struct CertificateMetadata {
    pub country_name: Option<String>,
    pub state_or_province_name: Option<String>,
    pub locality_name: Option<String>,
    pub organization_name: Option<String>,
    pub organizational_unit_name: Option<String>,
    pub common_name: Option<String>,
}

impl CertificateMetadata {
    /// Renders the present components as `KEY=value` pairs joined by commas.
    ///
    /// Components appear in the fixed order CN, OU, O, L, ST, C, and absent ones
    /// are omitted. Values are emitted verbatim (no escaping).
    pub fn subject(&self) -> String {
        let ordered: [(&str, &Option<String>); 6] = [
            ("CN", &self.common_name),
            ("OU", &self.organizational_unit_name),
            ("O", &self.organization_name),
            ("L", &self.locality_name),
            ("ST", &self.state_or_province_name),
            ("C", &self.country_name),
        ];

        ordered
            .iter()
            .filter_map(|&(key, value)| value.as_ref().map(|v| format!("{key}={v}")))
            .collect::<Vec<String>>()
            .join(",")
    }
}
417
418impl From<X509> for CertificateMetadata {
419    fn from(cert: X509) -> Self {
420        let mut subject_metadata: HashMap<String, String> = HashMap::new();
421        for entry in cert.subject_name().entries() {
422            let data_string = match entry.data().as_utf8() {
423                Ok(data) => data.to_string(),
424                Err(_) => String::new(),
425            };
426            subject_metadata.insert(entry.object().to_string(), data_string);
427        }
428        Self {
429            country_name: subject_metadata.get("countryName").cloned(),
430            state_or_province_name: subject_metadata.get("stateOrProvinceName").cloned(),
431            locality_name: subject_metadata.get("localityName").cloned(),
432            organization_name: subject_metadata.get("organizationName").cloned(),
433            organizational_unit_name: subject_metadata.get("organizationalUnitName").cloned(),
434            common_name: subject_metadata.get("commonName").cloned(),
435        }
436    }
437}
438
439#[derive(Clone)]
440pub struct MaybeTlsConnectInfo {
441    pub remote_addr: SocketAddr,
442    pub peer_certs: Option<Vec<Certificate>>,
443}
444
445impl Connected for MaybeTlsIncomingStream<TcpStream> {
446    type ConnectInfo = MaybeTlsConnectInfo;
447
448    fn connect_info(&self) -> Self::ConnectInfo {
449        MaybeTlsConnectInfo {
450            remote_addr: self.peer_addr(),
451            peer_certs: self
452                .ssl_stream()
453                .and_then(|s| s.ssl().peer_cert_chain())
454                .map(|s| {
455                    s.into_iter()
456                        .filter_map(|c| c.to_pem().ok())
457                        .map(Certificate::from_pem)
458                        .collect()
459                }),
460        }
461    }
462}
463
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn certificate_metadata_full() {
        let meta = CertificateMetadata {
            country_name: Some("country".to_owned()),
            state_or_province_name: Some("state".to_owned()),
            locality_name: Some("locality".to_owned()),
            organization_name: Some("organization".to_owned()),
            organizational_unit_name: Some("org_unit".to_owned()),
            common_name: Some("common".to_owned()),
        };

        // All six components, in the order CN, OU, O, L, ST, C.
        assert_eq!(
            meta.subject(),
            "CN=common,OU=org_unit,O=organization,L=locality,ST=state,C=country"
        );
    }

    #[test]
    fn certificate_metadata_partial() {
        let meta = CertificateMetadata {
            country_name: Some("country".to_owned()),
            state_or_province_name: None,
            locality_name: None,
            organization_name: Some("organization".to_owned()),
            organizational_unit_name: Some("org_unit".to_owned()),
            common_name: Some("common".to_owned()),
        };

        // Missing components are skipped without leaving empty separators.
        assert_eq!(
            meta.subject(),
            "CN=common,OU=org_unit,O=organization,C=country"
        );
    }
}