vector_core/tls/
incoming.rs

1use std::{
2    collections::HashMap,
3    future::Future,
4    net::SocketAddr,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8};
9
10use futures::{FutureExt, Stream, future::BoxFuture, stream};
11use ipnet::IpNet;
12use openssl::{
13    ssl::{Ssl, SslAcceptor, SslMethod},
14    x509::X509,
15};
16use snafu::ResultExt;
17use tokio::{
18    io::{self, AsyncRead, AsyncWrite, ReadBuf},
19    net::{TcpListener, TcpStream},
20    sync::{OwnedSemaphorePermit, Semaphore},
21};
22use tokio_openssl::SslStream;
23use tonic::transport::{Certificate, server::Connected};
24
25use super::{
26    CreateAcceptorSnafu, HandshakeSnafu, IncomingListenerSnafu, MaybeTlsSettings, MaybeTlsStream,
27    SslBuildSnafu, TcpBindSnafu, TlsError, TlsSettings,
28};
29use crate::tcp::{self, TcpKeepaliveConfig};
30
31impl TlsSettings {
32    pub fn acceptor(&self) -> crate::tls::Result<SslAcceptor> {
33        if self.identity.is_none() {
34            Err(TlsError::MissingRequiredIdentity)
35        } else {
36            let mut acceptor =
37                SslAcceptor::mozilla_intermediate(SslMethod::tls()).context(CreateAcceptorSnafu)?;
38            self.apply_context_base(&mut acceptor, true)?;
39            Ok(acceptor.build())
40        }
41    }
42}
43
44impl MaybeTlsSettings {
45    pub async fn bind(&self, addr: &SocketAddr) -> crate::tls::Result<MaybeTlsListener> {
46        let listener = TcpListener::bind(addr).await.context(TcpBindSnafu)?;
47
48        let acceptor = match self {
49            Self::Tls(tls) => Some(tls.acceptor()?),
50            Self::Raw(()) => None,
51        };
52
53        Ok(MaybeTlsListener {
54            listener,
55            acceptor,
56            origin_filter: None,
57        })
58    }
59
60    pub async fn bind_with_allowlist(
61        &self,
62        addr: &SocketAddr,
63        allow_origin: Vec<IpNet>,
64    ) -> crate::tls::Result<MaybeTlsListener> {
65        let listener = TcpListener::bind(addr).await.context(TcpBindSnafu)?;
66
67        let acceptor = match self {
68            Self::Tls(tls) => Some(tls.acceptor()?),
69            Self::Raw(()) => None,
70        };
71
72        Ok(MaybeTlsListener {
73            listener,
74            acceptor,
75            origin_filter: Some(allow_origin),
76        })
77    }
78}
79
80pub struct MaybeTlsListener {
81    listener: TcpListener,
82    acceptor: Option<SslAcceptor>,
83    origin_filter: Option<Vec<IpNet>>,
84}
85
86impl MaybeTlsListener {
87    pub async fn accept(&mut self) -> crate::tls::Result<MaybeTlsIncomingStream<TcpStream>> {
88        let listener = self
89            .listener
90            .accept()
91            .await
92            .map(|(stream, peer_addr)| {
93                MaybeTlsIncomingStream::new(stream, peer_addr, self.acceptor.clone())
94            })
95            .context(IncomingListenerSnafu)?;
96
97        if let Some(origin_filter) = &self.origin_filter {
98            if origin_filter
99                .iter()
100                .any(|net| net.contains(&listener.peer_addr().ip()))
101            {
102                Ok(listener)
103            } else {
104                Err(TlsError::Connect {
105                    source: std::io::ErrorKind::ConnectionRefused.into(),
106                })
107            }
108        } else {
109            Ok(listener)
110        }
111    }
112
113    async fn into_accept(
114        mut self,
115    ) -> (crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>, Self) {
116        (self.accept().await, self)
117    }
118
119    pub fn accept_stream(
120        self,
121    ) -> impl Stream<Item = crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>> {
122        let mut accept = Box::pin(self.into_accept());
123        stream::poll_fn(move |context| match accept.as_mut().poll(context) {
124            Poll::Ready((item, this)) => {
125                accept.set(this.into_accept());
126                Poll::Ready(Some(item))
127            }
128            Poll::Pending => Poll::Pending,
129        })
130    }
131
132    pub fn accept_stream_limited(
133        self,
134        max_connections: Option<u32>,
135    ) -> impl Stream<
136        Item = (
137            crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>,
138            Option<OwnedSemaphorePermit>,
139        ),
140    > {
141        let mut connection_semaphore_future = max_connections.map(|max| {
142            let semaphore = Arc::new(Semaphore::new(max as usize));
143            let future = Box::pin(semaphore.clone().acquire_owned());
144            (semaphore, future)
145        });
146
147        let mut accept = Box::pin(self.into_accept());
148        stream::poll_fn(move |context| {
149            let permit = match connection_semaphore_future.as_mut() {
150                Some((semaphore, future)) => match future.as_mut().poll(context) {
151                    Poll::Ready(permit) => {
152                        future.set(semaphore.clone().acquire_owned());
153                        permit.ok()
154                    }
155                    Poll::Pending => return Poll::Pending,
156                },
157                None => None,
158            };
159            match accept.as_mut().poll(context) {
160                Poll::Ready((item, this)) => {
161                    accept.set(this.into_accept());
162                    Poll::Ready(Some((item, permit)))
163                }
164                Poll::Pending => Poll::Pending,
165            }
166        })
167    }
168
169    pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
170        self.listener.local_addr()
171    }
172
173    #[must_use]
174    pub fn with_allowlist(mut self, allowlist: Option<Vec<IpNet>>) -> Self {
175        self.origin_filter = allowlist;
176        self
177    }
178}
179
180impl From<TcpListener> for MaybeTlsListener {
181    fn from(listener: TcpListener) -> Self {
182        Self {
183            listener,
184            acceptor: None,
185            origin_filter: None,
186        }
187    }
188}
189
190pub struct MaybeTlsIncomingStream<S> {
191    state: StreamState<S>,
192    // BoxFuture doesn't allow access to the inner stream, but users
193    // of MaybeTlsIncomingStream want access to the peer address while
194    // still handshaking, so we have to cache it here.
195    peer_addr: SocketAddr,
196}
197
198enum StreamState<S> {
199    Accepted(MaybeTlsStream<S>),
200    Accepting(BoxFuture<'static, Result<SslStream<S>, TlsError>>),
201    AcceptError(String),
202    Closed,
203}
204
205impl<S> MaybeTlsIncomingStream<S> {
206    pub const fn peer_addr(&self) -> SocketAddr {
207        self.peer_addr
208    }
209
210    /// None if connection still hasn't been established.
211    pub fn get_ref(&self) -> Option<&S> {
212        use super::MaybeTls;
213
214        match &self.state {
215            StreamState::Accepted(stream) => Some(match stream {
216                MaybeTls::Raw(s) => s,
217                MaybeTls::Tls(s) => s.get_ref(),
218            }),
219            StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
220        }
221    }
222
223    pub const fn ssl_stream(&self) -> Option<&SslStream<S>> {
224        use super::MaybeTls;
225
226        match &self.state {
227            StreamState::Accepted(stream) => match stream {
228                MaybeTls::Raw(_) => None,
229                MaybeTls::Tls(s) => Some(s),
230            },
231            StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
232        }
233    }
234
235    pub fn get_mut(&mut self) -> Option<&mut S> {
236        use super::MaybeTls;
237
238        match &mut self.state {
239            StreamState::Accepted(stream) => Some(match stream {
240                MaybeTls::Raw(s) => s,
241                MaybeTls::Tls(s) => s.get_mut(),
242            }),
243            StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
244        }
245    }
246}
247
248impl MaybeTlsIncomingStream<TcpStream> {
249    pub(super) fn new(
250        stream: TcpStream,
251        peer_addr: SocketAddr,
252        acceptor: Option<SslAcceptor>,
253    ) -> Self {
254        let state = match acceptor {
255            Some(acceptor) => StreamState::Accepting(
256                async move {
257                    let ssl = Ssl::new(acceptor.context()).context(SslBuildSnafu)?;
258                    let mut stream = SslStream::new(ssl, stream).context(SslBuildSnafu)?;
259                    Pin::new(&mut stream)
260                        .accept()
261                        .await
262                        .context(HandshakeSnafu)?;
263                    Ok(stream)
264                }
265                .boxed(),
266            ),
267            None => StreamState::Accepted(MaybeTlsStream::Raw(stream)),
268        };
269        Self { state, peer_addr }
270    }
271
272    // Explicit handshake method
273    pub async fn handshake(&mut self) -> crate::tls::Result<()> {
274        if let StreamState::Accepting(fut) = &mut self.state {
275            let stream = fut.await?;
276            self.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
277        }
278
279        Ok(())
280    }
281
282    pub fn set_keepalive(&mut self, keepalive: TcpKeepaliveConfig) -> io::Result<()> {
283        let stream = self.get_ref().ok_or_else(|| {
284            io::Error::new(
285                io::ErrorKind::NotConnected,
286                "Can't set keepalive on connection that has not been accepted yet.",
287            )
288        })?;
289
290        if let Some(time_secs) = keepalive.time_secs {
291            let config =
292                socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(time_secs));
293
294            tcp::set_keepalive(stream, &config)?;
295        }
296
297        Ok(())
298    }
299
300    pub fn set_receive_buffer_bytes(&mut self, bytes: usize) -> std::io::Result<()> {
301        let stream = self.get_ref().ok_or_else(|| {
302            io::Error::new(
303                io::ErrorKind::NotConnected,
304                "Can't set receive buffer size on connection that has not been accepted yet.",
305            )
306        })?;
307
308        tcp::set_receive_buffer_size(stream, bytes)
309    }
310
311    fn poll_io<T, F>(self: Pin<&mut Self>, cx: &mut Context, poll_fn: F) -> Poll<io::Result<T>>
312    where
313        F: FnOnce(Pin<&mut MaybeTlsStream<TcpStream>>, &mut Context) -> Poll<io::Result<T>>,
314    {
315        let this = self.get_mut();
316        loop {
317            return match &mut this.state {
318                StreamState::Accepted(stream) => poll_fn(Pin::new(stream), cx),
319                StreamState::Accepting(fut) => match std::task::ready!(fut.as_mut().poll(cx)) {
320                    Ok(stream) => {
321                        this.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
322                        continue;
323                    }
324                    Err(error) => {
325                        let error = io::Error::other(error);
326                        this.state = StreamState::AcceptError(error.to_string());
327                        Poll::Ready(Err(error))
328                    }
329                },
330                StreamState::AcceptError(error) => {
331                    Poll::Ready(Err(io::Error::other(error.clone())))
332                }
333                StreamState::Closed => Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())),
334            };
335        }
336    }
337}
338
339impl AsyncRead for MaybeTlsIncomingStream<TcpStream> {
340    fn poll_read(
341        self: Pin<&mut Self>,
342        cx: &mut Context,
343        buf: &mut ReadBuf<'_>,
344    ) -> Poll<io::Result<()>> {
345        self.poll_io(cx, |s, cx| s.poll_read(cx, buf))
346    }
347}
348
349impl AsyncWrite for MaybeTlsIncomingStream<TcpStream> {
350    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
351        self.poll_io(cx, |s, cx| s.poll_write(cx, buf))
352    }
353
354    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
355        self.poll_io(cx, AsyncWrite::poll_flush)
356    }
357
358    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
359        let this = self.get_mut();
360        match &mut this.state {
361            StreamState::Accepted(stream) => match Pin::new(stream).poll_shutdown(cx) {
362                Poll::Ready(Ok(())) => {
363                    this.state = StreamState::Closed;
364                    Poll::Ready(Ok(()))
365                }
366                poll_result => poll_result,
367            },
368            StreamState::Accepting(fut) => match std::task::ready!(fut.as_mut().poll(cx)) {
369                Ok(stream) => {
370                    this.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
371                    Poll::Pending
372                }
373                Err(error) => {
374                    let error = io::Error::other(error);
375                    this.state = StreamState::AcceptError(error.to_string());
376                    Poll::Ready(Err(error))
377                }
378            },
379            StreamState::AcceptError(error) => Poll::Ready(Err(io::Error::other(error.clone()))),
380            StreamState::Closed => Poll::Ready(Ok(())),
381        }
382    }
383}
384
/// Selected subject attributes of an X.509 certificate.
///
/// Each field is `None` when the corresponding attribute is absent.
#[derive(Debug)]
pub struct CertificateMetadata {
    /// Subject `countryName` (rendered as `C`).
    pub country_name: Option<String>,
    /// Subject `stateOrProvinceName` (rendered as `ST`).
    pub state_or_province_name: Option<String>,
    /// Subject `localityName` (rendered as `L`).
    pub locality_name: Option<String>,
    /// Subject `organizationName` (rendered as `O`).
    pub organization_name: Option<String>,
    /// Subject `organizationalUnitName` (rendered as `OU`).
    pub organizational_unit_name: Option<String>,
    /// Subject `commonName` (rendered as `CN`).
    pub common_name: Option<String>,
}

impl CertificateMetadata {
    /// Renders the present attributes as a comma-separated `KEY=value` list.
    ///
    /// The order is always `CN`, `OU`, `O`, `L`, `ST`, `C`; absent attributes
    /// are skipped. Values are inserted verbatim, without any escaping. The
    /// result is an empty string when no attribute is set.
    pub fn subject(&self) -> String {
        let attributes = [
            ("CN", self.common_name.as_deref()),
            ("OU", self.organizational_unit_name.as_deref()),
            ("O", self.organization_name.as_deref()),
            ("L", self.locality_name.as_deref()),
            ("ST", self.state_or_province_name.as_deref()),
            ("C", self.country_name.as_deref()),
        ];

        attributes
            .iter()
            .filter_map(|&(label, value)| value.map(|value| format!("{label}={value}")))
            .collect::<Vec<_>>()
            .join(",")
    }
}
419
420impl From<X509> for CertificateMetadata {
421    fn from(cert: X509) -> Self {
422        let mut subject_metadata: HashMap<String, String> = HashMap::new();
423        for entry in cert.subject_name().entries() {
424            let data_string = match entry.data().as_utf8() {
425                Ok(data) => data.to_string(),
426                Err(_) => String::new(),
427            };
428            subject_metadata.insert(entry.object().to_string(), data_string);
429        }
430        Self {
431            country_name: subject_metadata.get("countryName").cloned(),
432            state_or_province_name: subject_metadata.get("stateOrProvinceName").cloned(),
433            locality_name: subject_metadata.get("localityName").cloned(),
434            organization_name: subject_metadata.get("organizationName").cloned(),
435            organizational_unit_name: subject_metadata.get("organizationalUnitName").cloned(),
436            common_name: subject_metadata.get("commonName").cloned(),
437        }
438    }
439}
440
441#[derive(Clone)]
442pub struct MaybeTlsConnectInfo {
443    pub remote_addr: SocketAddr,
444    pub peer_certs: Option<Vec<Certificate>>,
445}
446
447impl Connected for MaybeTlsIncomingStream<TcpStream> {
448    type ConnectInfo = MaybeTlsConnectInfo;
449
450    fn connect_info(&self) -> Self::ConnectInfo {
451        MaybeTlsConnectInfo {
452            remote_addr: self.peer_addr(),
453            peer_certs: self
454                .ssl_stream()
455                .and_then(|s| s.ssl().peer_cert_chain())
456                .map(|s| {
457                    s.into_iter()
458                        .filter_map(|c| c.to_pem().ok())
459                        .map(Certificate::from_pem)
460                        .collect()
461                }),
462        }
463    }
464}
465
#[cfg(test)]
mod test {
    use super::*;

    /// With every attribute present, all six components appear in fixed order.
    #[test]
    fn certificate_metadata_full() {
        let metadata = CertificateMetadata {
            common_name: Some("common".to_owned()),
            country_name: Some("country".to_owned()),
            locality_name: Some("locality".to_owned()),
            organization_name: Some("organization".to_owned()),
            organizational_unit_name: Some("org_unit".to_owned()),
            state_or_province_name: Some("state".to_owned()),
        };

        assert_eq!(
            "CN=common,OU=org_unit,O=organization,L=locality,ST=state,C=country",
            metadata.subject()
        );
    }

    /// Absent attributes are skipped without leaving empty components behind.
    #[test]
    fn certificate_metadata_partial() {
        let metadata = CertificateMetadata {
            common_name: Some("common".to_owned()),
            country_name: Some("country".to_owned()),
            locality_name: None,
            organization_name: Some("organization".to_owned()),
            organizational_unit_name: Some("org_unit".to_owned()),
            state_or_province_name: None,
        };

        assert_eq!(
            "CN=common,OU=org_unit,O=organization,C=country",
            metadata.subject()
        );
    }
}