1use ipnet::IpNet;
2use std::{
3 collections::HashMap,
4 future::Future,
5 net::SocketAddr,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9};
10
11use futures::{future::BoxFuture, stream, FutureExt, Stream};
12use openssl::ssl::{Ssl, SslAcceptor, SslMethod};
13use openssl::x509::X509;
14use snafu::ResultExt;
15use tokio::sync::{OwnedSemaphorePermit, Semaphore};
16use tokio::{
17 io::{self, AsyncRead, AsyncWrite, ReadBuf},
18 net::{TcpListener, TcpStream},
19};
20use tokio_openssl::SslStream;
21use tonic::transport::{server::Connected, Certificate};
22
23use super::{
24 CreateAcceptorSnafu, HandshakeSnafu, IncomingListenerSnafu, MaybeTlsSettings, MaybeTlsStream,
25 SslBuildSnafu, TcpBindSnafu, TlsError, TlsSettings,
26};
27use crate::tcp::{self, TcpKeepaliveConfig};
28
29impl TlsSettings {
30 pub fn acceptor(&self) -> crate::tls::Result<SslAcceptor> {
31 if self.identity.is_none() {
32 Err(TlsError::MissingRequiredIdentity)
33 } else {
34 let mut acceptor =
35 SslAcceptor::mozilla_intermediate(SslMethod::tls()).context(CreateAcceptorSnafu)?;
36 self.apply_context_base(&mut acceptor, true)?;
37 Ok(acceptor.build())
38 }
39 }
40}
41
42impl MaybeTlsSettings {
43 pub async fn bind(&self, addr: &SocketAddr) -> crate::tls::Result<MaybeTlsListener> {
44 let listener = TcpListener::bind(addr).await.context(TcpBindSnafu)?;
45
46 let acceptor = match self {
47 Self::Tls(tls) => Some(tls.acceptor()?),
48 Self::Raw(()) => None,
49 };
50
51 Ok(MaybeTlsListener {
52 listener,
53 acceptor,
54 origin_filter: None,
55 })
56 }
57
58 pub async fn bind_with_allowlist(
59 &self,
60 addr: &SocketAddr,
61 allow_origin: Vec<IpNet>,
62 ) -> crate::tls::Result<MaybeTlsListener> {
63 let listener = TcpListener::bind(addr).await.context(TcpBindSnafu)?;
64
65 let acceptor = match self {
66 Self::Tls(tls) => Some(tls.acceptor()?),
67 Self::Raw(()) => None,
68 };
69
70 Ok(MaybeTlsListener {
71 listener,
72 acceptor,
73 origin_filter: Some(allow_origin),
74 })
75 }
76}
77
78pub struct MaybeTlsListener {
79 listener: TcpListener,
80 acceptor: Option<SslAcceptor>,
81 origin_filter: Option<Vec<IpNet>>,
82}
83
84impl MaybeTlsListener {
85 pub async fn accept(&mut self) -> crate::tls::Result<MaybeTlsIncomingStream<TcpStream>> {
86 let listener = self
87 .listener
88 .accept()
89 .await
90 .map(|(stream, peer_addr)| {
91 MaybeTlsIncomingStream::new(stream, peer_addr, self.acceptor.clone())
92 })
93 .context(IncomingListenerSnafu)?;
94
95 if let Some(origin_filter) = &self.origin_filter {
96 if origin_filter
97 .iter()
98 .any(|net| net.contains(&listener.peer_addr().ip()))
99 {
100 Ok(listener)
101 } else {
102 Err(TlsError::Connect {
103 source: std::io::ErrorKind::ConnectionRefused.into(),
104 })
105 }
106 } else {
107 Ok(listener)
108 }
109 }
110
111 async fn into_accept(
112 mut self,
113 ) -> (crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>, Self) {
114 (self.accept().await, self)
115 }
116
117 pub fn accept_stream(
118 self,
119 ) -> impl Stream<Item = crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>> {
120 let mut accept = Box::pin(self.into_accept());
121 stream::poll_fn(move |context| match accept.as_mut().poll(context) {
122 Poll::Ready((item, this)) => {
123 accept.set(this.into_accept());
124 Poll::Ready(Some(item))
125 }
126 Poll::Pending => Poll::Pending,
127 })
128 }
129
130 pub fn accept_stream_limited(
131 self,
132 max_connections: Option<u32>,
133 ) -> impl Stream<
134 Item = (
135 crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>,
136 Option<OwnedSemaphorePermit>,
137 ),
138 > {
139 let mut connection_semaphore_future = max_connections.map(|max| {
140 let semaphore = Arc::new(Semaphore::new(max as usize));
141 let future = Box::pin(semaphore.clone().acquire_owned());
142 (semaphore, future)
143 });
144
145 let mut accept = Box::pin(self.into_accept());
146 stream::poll_fn(move |context| {
147 let permit = match connection_semaphore_future.as_mut() {
148 Some((semaphore, future)) => match future.as_mut().poll(context) {
149 Poll::Ready(permit) => {
150 future.set(semaphore.clone().acquire_owned());
151 permit.ok()
152 }
153 Poll::Pending => return Poll::Pending,
154 },
155 None => None,
156 };
157 match accept.as_mut().poll(context) {
158 Poll::Ready((item, this)) => {
159 accept.set(this.into_accept());
160 Poll::Ready(Some((item, permit)))
161 }
162 Poll::Pending => Poll::Pending,
163 }
164 })
165 }
166
167 pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
168 self.listener.local_addr()
169 }
170
171 #[must_use]
172 pub fn with_allowlist(mut self, allowlist: Option<Vec<IpNet>>) -> Self {
173 self.origin_filter = allowlist;
174 self
175 }
176}
177
178impl From<TcpListener> for MaybeTlsListener {
179 fn from(listener: TcpListener) -> Self {
180 Self {
181 listener,
182 acceptor: None,
183 origin_filter: None,
184 }
185 }
186}
187
188pub struct MaybeTlsIncomingStream<S> {
189 state: StreamState<S>,
190 peer_addr: SocketAddr,
194}
195
196enum StreamState<S> {
197 Accepted(MaybeTlsStream<S>),
198 Accepting(BoxFuture<'static, Result<SslStream<S>, TlsError>>),
199 AcceptError(String),
200 Closed,
201}
202
203impl<S> MaybeTlsIncomingStream<S> {
204 pub const fn peer_addr(&self) -> SocketAddr {
205 self.peer_addr
206 }
207
208 pub fn get_ref(&self) -> Option<&S> {
210 use super::MaybeTls;
211
212 match &self.state {
213 StreamState::Accepted(stream) => Some(match stream {
214 MaybeTls::Raw(s) => s,
215 MaybeTls::Tls(s) => s.get_ref(),
216 }),
217 StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
218 }
219 }
220
221 pub const fn ssl_stream(&self) -> Option<&SslStream<S>> {
222 use super::MaybeTls;
223
224 match &self.state {
225 StreamState::Accepted(stream) => match stream {
226 MaybeTls::Raw(_) => None,
227 MaybeTls::Tls(s) => Some(s),
228 },
229 StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
230 }
231 }
232
233 pub fn get_mut(&mut self) -> Option<&mut S> {
234 use super::MaybeTls;
235
236 match &mut self.state {
237 StreamState::Accepted(ref mut stream) => Some(match stream {
238 MaybeTls::Raw(ref mut s) => s,
239 MaybeTls::Tls(s) => s.get_mut(),
240 }),
241 StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
242 }
243 }
244}
245
246impl MaybeTlsIncomingStream<TcpStream> {
247 pub(super) fn new(
248 stream: TcpStream,
249 peer_addr: SocketAddr,
250 acceptor: Option<SslAcceptor>,
251 ) -> Self {
252 let state = match acceptor {
253 Some(acceptor) => StreamState::Accepting(
254 async move {
255 let ssl = Ssl::new(acceptor.context()).context(SslBuildSnafu)?;
256 let mut stream = SslStream::new(ssl, stream).context(SslBuildSnafu)?;
257 Pin::new(&mut stream)
258 .accept()
259 .await
260 .context(HandshakeSnafu)?;
261 Ok(stream)
262 }
263 .boxed(),
264 ),
265 None => StreamState::Accepted(MaybeTlsStream::Raw(stream)),
266 };
267 Self { state, peer_addr }
268 }
269
270 pub async fn handshake(&mut self) -> crate::tls::Result<()> {
272 if let StreamState::Accepting(fut) = &mut self.state {
273 let stream = fut.await?;
274 self.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
275 }
276
277 Ok(())
278 }
279
280 pub fn set_keepalive(&mut self, keepalive: TcpKeepaliveConfig) -> io::Result<()> {
281 let stream = self.get_ref().ok_or_else(|| {
282 io::Error::new(
283 io::ErrorKind::NotConnected,
284 "Can't set keepalive on connection that has not been accepted yet.",
285 )
286 })?;
287
288 if let Some(time_secs) = keepalive.time_secs {
289 let config =
290 socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(time_secs));
291
292 tcp::set_keepalive(stream, &config)?;
293 }
294
295 Ok(())
296 }
297
298 pub fn set_receive_buffer_bytes(&mut self, bytes: usize) -> std::io::Result<()> {
299 let stream = self.get_ref().ok_or_else(|| {
300 io::Error::new(
301 io::ErrorKind::NotConnected,
302 "Can't set receive buffer size on connection that has not been accepted yet.",
303 )
304 })?;
305
306 tcp::set_receive_buffer_size(stream, bytes)
307 }
308
309 fn poll_io<T, F>(self: Pin<&mut Self>, cx: &mut Context, poll_fn: F) -> Poll<io::Result<T>>
310 where
311 F: FnOnce(Pin<&mut MaybeTlsStream<TcpStream>>, &mut Context) -> Poll<io::Result<T>>,
312 {
313 let this = self.get_mut();
314 loop {
315 return match &mut this.state {
316 StreamState::Accepted(stream) => poll_fn(Pin::new(stream), cx),
317 StreamState::Accepting(fut) => match std::task::ready!(fut.as_mut().poll(cx)) {
318 Ok(stream) => {
319 this.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
320 continue;
321 }
322 Err(error) => {
323 let error = io::Error::other(error);
324 this.state = StreamState::AcceptError(error.to_string());
325 Poll::Ready(Err(error))
326 }
327 },
328 StreamState::AcceptError(error) => {
329 Poll::Ready(Err(io::Error::other(error.clone())))
330 }
331 StreamState::Closed => Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())),
332 };
333 }
334 }
335}
336
337impl AsyncRead for MaybeTlsIncomingStream<TcpStream> {
338 fn poll_read(
339 self: Pin<&mut Self>,
340 cx: &mut Context,
341 buf: &mut ReadBuf<'_>,
342 ) -> Poll<io::Result<()>> {
343 self.poll_io(cx, |s, cx| s.poll_read(cx, buf))
344 }
345}
346
347impl AsyncWrite for MaybeTlsIncomingStream<TcpStream> {
348 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
349 self.poll_io(cx, |s, cx| s.poll_write(cx, buf))
350 }
351
352 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
353 self.poll_io(cx, AsyncWrite::poll_flush)
354 }
355
356 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
357 let this = self.get_mut();
358 match &mut this.state {
359 StreamState::Accepted(stream) => match Pin::new(stream).poll_shutdown(cx) {
360 Poll::Ready(Ok(())) => {
361 this.state = StreamState::Closed;
362 Poll::Ready(Ok(()))
363 }
364 poll_result => poll_result,
365 },
366 StreamState::Accepting(fut) => match std::task::ready!(fut.as_mut().poll(cx)) {
367 Ok(stream) => {
368 this.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
369 Poll::Pending
370 }
371 Err(error) => {
372 let error = io::Error::other(error);
373 this.state = StreamState::AcceptError(error.to_string());
374 Poll::Ready(Err(error))
375 }
376 },
377 StreamState::AcceptError(error) => Poll::Ready(Err(io::Error::other(error.clone()))),
378 StreamState::Closed => Poll::Ready(Ok(())),
379 }
380 }
381}
382
/// Subject attributes extracted from an X.509 certificate. Each field is `None` when the
/// certificate's subject does not carry that attribute.
#[derive(Debug)]
pub struct CertificateMetadata {
    pub country_name: Option<String>,
    pub state_or_province_name: Option<String>,
    pub locality_name: Option<String>,
    pub organization_name: Option<String>,
    pub organizational_unit_name: Option<String>,
    pub common_name: Option<String>,
}

impl CertificateMetadata {
    /// Renders the subject as comma-separated `KEY=value` pairs.
    ///
    /// Order is fixed: `CN`, `OU`, `O`, `L`, `ST`, `C`. Absent attributes are omitted.
    /// When every field is `None`, the result is the empty string. Values are inserted
    /// verbatim; commas or `=` inside values are not escaped.
    pub fn subject(&self) -> String {
        let ordered: [(&str, Option<&String>); 6] = [
            ("CN", self.common_name.as_ref()),
            ("OU", self.organizational_unit_name.as_ref()),
            ("O", self.organization_name.as_ref()),
            ("L", self.locality_name.as_ref()),
            ("ST", self.state_or_province_name.as_ref()),
            ("C", self.country_name.as_ref()),
        ];

        ordered
            .iter()
            .filter_map(|&(key, value)| value.map(|value| format!("{key}={value}")))
            .collect::<Vec<String>>()
            .join(",")
    }
}
417
418impl From<X509> for CertificateMetadata {
419 fn from(cert: X509) -> Self {
420 let mut subject_metadata: HashMap<String, String> = HashMap::new();
421 for entry in cert.subject_name().entries() {
422 let data_string = match entry.data().as_utf8() {
423 Ok(data) => data.to_string(),
424 Err(_) => String::new(),
425 };
426 subject_metadata.insert(entry.object().to_string(), data_string);
427 }
428 Self {
429 country_name: subject_metadata.get("countryName").cloned(),
430 state_or_province_name: subject_metadata.get("stateOrProvinceName").cloned(),
431 locality_name: subject_metadata.get("localityName").cloned(),
432 organization_name: subject_metadata.get("organizationName").cloned(),
433 organizational_unit_name: subject_metadata.get("organizationalUnitName").cloned(),
434 common_name: subject_metadata.get("commonName").cloned(),
435 }
436 }
437}
438
439#[derive(Clone)]
440pub struct MaybeTlsConnectInfo {
441 pub remote_addr: SocketAddr,
442 pub peer_certs: Option<Vec<Certificate>>,
443}
444
445impl Connected for MaybeTlsIncomingStream<TcpStream> {
446 type ConnectInfo = MaybeTlsConnectInfo;
447
448 fn connect_info(&self) -> Self::ConnectInfo {
449 MaybeTlsConnectInfo {
450 remote_addr: self.peer_addr(),
451 peer_certs: self
452 .ssl_stream()
453 .and_then(|s| s.ssl().peer_cert_chain())
454 .map(|s| {
455 s.into_iter()
456 .filter_map(|c| c.to_pem().ok())
457 .map(Certificate::from_pem)
458 .collect()
459 }),
460 }
461 }
462}
463
#[cfg(test)]
mod test {
    use super::*;

    /// Metadata with every subject attribute populated.
    fn full_metadata() -> CertificateMetadata {
        CertificateMetadata {
            common_name: Some("common".to_owned()),
            country_name: Some("country".to_owned()),
            locality_name: Some("locality".to_owned()),
            organization_name: Some("organization".to_owned()),
            organizational_unit_name: Some("org_unit".to_owned()),
            state_or_province_name: Some("state".to_owned()),
        }
    }

    #[test]
    fn certificate_metadata_full() {
        assert_eq!(
            "CN=common,OU=org_unit,O=organization,L=locality,ST=state,C=country",
            full_metadata().subject()
        );
    }

    #[test]
    fn certificate_metadata_partial() {
        let example_meta = CertificateMetadata {
            locality_name: None,
            state_or_province_name: None,
            ..full_metadata()
        };

        assert_eq!(
            "CN=common,OU=org_unit,O=organization,C=country",
            example_meta.subject()
        );
    }
}