1use std::{
2 collections::HashMap,
3 future::Future,
4 net::SocketAddr,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use futures::{FutureExt, Stream, future::BoxFuture, stream};
11use ipnet::IpNet;
12use openssl::{
13 ssl::{Ssl, SslAcceptor, SslMethod},
14 x509::X509,
15};
16use snafu::ResultExt;
17use tokio::{
18 io::{self, AsyncRead, AsyncWrite, ReadBuf},
19 net::{TcpListener, TcpStream},
20 sync::{OwnedSemaphorePermit, Semaphore},
21};
22use tokio_openssl::SslStream;
23use tonic::transport::{Certificate, server::Connected};
24
25use super::{
26 CreateAcceptorSnafu, HandshakeSnafu, IncomingListenerSnafu, MaybeTlsSettings, MaybeTlsStream,
27 SslBuildSnafu, TcpBindSnafu, TlsError, TlsSettings,
28};
29use crate::tcp::{self, TcpKeepaliveConfig};
30
31impl TlsSettings {
32 pub fn acceptor(&self) -> crate::tls::Result<SslAcceptor> {
33 if self.identity.is_none() {
34 Err(TlsError::MissingRequiredIdentity)
35 } else {
36 let mut acceptor =
37 SslAcceptor::mozilla_intermediate(SslMethod::tls()).context(CreateAcceptorSnafu)?;
38 self.apply_context_base(&mut acceptor, true)?;
39 Ok(acceptor.build())
40 }
41 }
42}
43
44impl MaybeTlsSettings {
45 pub async fn bind(&self, addr: &SocketAddr) -> crate::tls::Result<MaybeTlsListener> {
46 let listener = TcpListener::bind(addr).await.context(TcpBindSnafu)?;
47
48 let acceptor = match self {
49 Self::Tls(tls) => Some(tls.acceptor()?),
50 Self::Raw(()) => None,
51 };
52
53 Ok(MaybeTlsListener {
54 listener,
55 acceptor,
56 origin_filter: None,
57 })
58 }
59
60 pub async fn bind_with_allowlist(
61 &self,
62 addr: &SocketAddr,
63 allow_origin: Vec<IpNet>,
64 ) -> crate::tls::Result<MaybeTlsListener> {
65 let listener = TcpListener::bind(addr).await.context(TcpBindSnafu)?;
66
67 let acceptor = match self {
68 Self::Tls(tls) => Some(tls.acceptor()?),
69 Self::Raw(()) => None,
70 };
71
72 Ok(MaybeTlsListener {
73 listener,
74 acceptor,
75 origin_filter: Some(allow_origin),
76 })
77 }
78}
79
80pub struct MaybeTlsListener {
81 listener: TcpListener,
82 acceptor: Option<SslAcceptor>,
83 origin_filter: Option<Vec<IpNet>>,
84}
85
86impl MaybeTlsListener {
87 pub async fn accept(&mut self) -> crate::tls::Result<MaybeTlsIncomingStream<TcpStream>> {
88 let listener = self
89 .listener
90 .accept()
91 .await
92 .map(|(stream, peer_addr)| {
93 MaybeTlsIncomingStream::new(stream, peer_addr, self.acceptor.clone())
94 })
95 .context(IncomingListenerSnafu)?;
96
97 if let Some(origin_filter) = &self.origin_filter {
98 if origin_filter
99 .iter()
100 .any(|net| net.contains(&listener.peer_addr().ip()))
101 {
102 Ok(listener)
103 } else {
104 Err(TlsError::Connect {
105 source: std::io::ErrorKind::ConnectionRefused.into(),
106 })
107 }
108 } else {
109 Ok(listener)
110 }
111 }
112
113 async fn into_accept(
114 mut self,
115 ) -> (crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>, Self) {
116 (self.accept().await, self)
117 }
118
119 pub fn accept_stream(
120 self,
121 ) -> impl Stream<Item = crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>> {
122 let mut accept = Box::pin(self.into_accept());
123 stream::poll_fn(move |context| match accept.as_mut().poll(context) {
124 Poll::Ready((item, this)) => {
125 accept.set(this.into_accept());
126 Poll::Ready(Some(item))
127 }
128 Poll::Pending => Poll::Pending,
129 })
130 }
131
132 pub fn accept_stream_limited(
133 self,
134 max_connections: Option<u32>,
135 ) -> impl Stream<
136 Item = (
137 crate::tls::Result<MaybeTlsIncomingStream<TcpStream>>,
138 Option<OwnedSemaphorePermit>,
139 ),
140 > {
141 let mut connection_semaphore_future = max_connections.map(|max| {
142 let semaphore = Arc::new(Semaphore::new(max as usize));
143 let future = Box::pin(semaphore.clone().acquire_owned());
144 (semaphore, future)
145 });
146
147 let mut accept = Box::pin(self.into_accept());
148 stream::poll_fn(move |context| {
149 let permit = match connection_semaphore_future.as_mut() {
150 Some((semaphore, future)) => match future.as_mut().poll(context) {
151 Poll::Ready(permit) => {
152 future.set(semaphore.clone().acquire_owned());
153 permit.ok()
154 }
155 Poll::Pending => return Poll::Pending,
156 },
157 None => None,
158 };
159 match accept.as_mut().poll(context) {
160 Poll::Ready((item, this)) => {
161 accept.set(this.into_accept());
162 Poll::Ready(Some((item, permit)))
163 }
164 Poll::Pending => Poll::Pending,
165 }
166 })
167 }
168
169 pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
170 self.listener.local_addr()
171 }
172
173 #[must_use]
174 pub fn with_allowlist(mut self, allowlist: Option<Vec<IpNet>>) -> Self {
175 self.origin_filter = allowlist;
176 self
177 }
178}
179
180impl From<TcpListener> for MaybeTlsListener {
181 fn from(listener: TcpListener) -> Self {
182 Self {
183 listener,
184 acceptor: None,
185 origin_filter: None,
186 }
187 }
188}
189
190pub struct MaybeTlsIncomingStream<S> {
191 state: StreamState<S>,
192 peer_addr: SocketAddr,
196}
197
198enum StreamState<S> {
199 Accepted(MaybeTlsStream<S>),
200 Accepting(BoxFuture<'static, Result<SslStream<S>, TlsError>>),
201 AcceptError(String),
202 Closed,
203}
204
205impl<S> MaybeTlsIncomingStream<S> {
206 pub const fn peer_addr(&self) -> SocketAddr {
207 self.peer_addr
208 }
209
210 pub fn get_ref(&self) -> Option<&S> {
212 use super::MaybeTls;
213
214 match &self.state {
215 StreamState::Accepted(stream) => Some(match stream {
216 MaybeTls::Raw(s) => s,
217 MaybeTls::Tls(s) => s.get_ref(),
218 }),
219 StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
220 }
221 }
222
223 pub const fn ssl_stream(&self) -> Option<&SslStream<S>> {
224 use super::MaybeTls;
225
226 match &self.state {
227 StreamState::Accepted(stream) => match stream {
228 MaybeTls::Raw(_) => None,
229 MaybeTls::Tls(s) => Some(s),
230 },
231 StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
232 }
233 }
234
235 pub fn get_mut(&mut self) -> Option<&mut S> {
236 use super::MaybeTls;
237
238 match &mut self.state {
239 StreamState::Accepted(stream) => Some(match stream {
240 MaybeTls::Raw(s) => s,
241 MaybeTls::Tls(s) => s.get_mut(),
242 }),
243 StreamState::Accepting(_) | StreamState::AcceptError(_) | StreamState::Closed => None,
244 }
245 }
246}
247
248impl MaybeTlsIncomingStream<TcpStream> {
249 pub(super) fn new(
250 stream: TcpStream,
251 peer_addr: SocketAddr,
252 acceptor: Option<SslAcceptor>,
253 ) -> Self {
254 let state = match acceptor {
255 Some(acceptor) => StreamState::Accepting(
256 async move {
257 let ssl = Ssl::new(acceptor.context()).context(SslBuildSnafu)?;
258 let mut stream = SslStream::new(ssl, stream).context(SslBuildSnafu)?;
259 Pin::new(&mut stream)
260 .accept()
261 .await
262 .context(HandshakeSnafu)?;
263 Ok(stream)
264 }
265 .boxed(),
266 ),
267 None => StreamState::Accepted(MaybeTlsStream::Raw(stream)),
268 };
269 Self { state, peer_addr }
270 }
271
272 pub async fn handshake(&mut self) -> crate::tls::Result<()> {
274 if let StreamState::Accepting(fut) = &mut self.state {
275 let stream = fut.await?;
276 self.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
277 }
278
279 Ok(())
280 }
281
282 pub fn set_keepalive(&mut self, keepalive: TcpKeepaliveConfig) -> io::Result<()> {
283 let stream = self.get_ref().ok_or_else(|| {
284 io::Error::new(
285 io::ErrorKind::NotConnected,
286 "Can't set keepalive on connection that has not been accepted yet.",
287 )
288 })?;
289
290 if let Some(time_secs) = keepalive.time_secs {
291 let config =
292 socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(time_secs));
293
294 tcp::set_keepalive(stream, &config)?;
295 }
296
297 Ok(())
298 }
299
300 pub fn set_receive_buffer_bytes(&mut self, bytes: usize) -> std::io::Result<()> {
301 let stream = self.get_ref().ok_or_else(|| {
302 io::Error::new(
303 io::ErrorKind::NotConnected,
304 "Can't set receive buffer size on connection that has not been accepted yet.",
305 )
306 })?;
307
308 tcp::set_receive_buffer_size(stream, bytes)
309 }
310
311 fn poll_io<T, F>(self: Pin<&mut Self>, cx: &mut Context, poll_fn: F) -> Poll<io::Result<T>>
312 where
313 F: FnOnce(Pin<&mut MaybeTlsStream<TcpStream>>, &mut Context) -> Poll<io::Result<T>>,
314 {
315 let this = self.get_mut();
316 loop {
317 return match &mut this.state {
318 StreamState::Accepted(stream) => poll_fn(Pin::new(stream), cx),
319 StreamState::Accepting(fut) => match std::task::ready!(fut.as_mut().poll(cx)) {
320 Ok(stream) => {
321 this.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
322 continue;
323 }
324 Err(error) => {
325 let error = io::Error::other(error);
326 this.state = StreamState::AcceptError(error.to_string());
327 Poll::Ready(Err(error))
328 }
329 },
330 StreamState::AcceptError(error) => {
331 Poll::Ready(Err(io::Error::other(error.clone())))
332 }
333 StreamState::Closed => Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())),
334 };
335 }
336 }
337}
338
339impl AsyncRead for MaybeTlsIncomingStream<TcpStream> {
340 fn poll_read(
341 self: Pin<&mut Self>,
342 cx: &mut Context,
343 buf: &mut ReadBuf<'_>,
344 ) -> Poll<io::Result<()>> {
345 self.poll_io(cx, |s, cx| s.poll_read(cx, buf))
346 }
347}
348
349impl AsyncWrite for MaybeTlsIncomingStream<TcpStream> {
350 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
351 self.poll_io(cx, |s, cx| s.poll_write(cx, buf))
352 }
353
354 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
355 self.poll_io(cx, AsyncWrite::poll_flush)
356 }
357
358 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
359 let this = self.get_mut();
360 match &mut this.state {
361 StreamState::Accepted(stream) => match Pin::new(stream).poll_shutdown(cx) {
362 Poll::Ready(Ok(())) => {
363 this.state = StreamState::Closed;
364 Poll::Ready(Ok(()))
365 }
366 poll_result => poll_result,
367 },
368 StreamState::Accepting(fut) => match std::task::ready!(fut.as_mut().poll(cx)) {
369 Ok(stream) => {
370 this.state = StreamState::Accepted(MaybeTlsStream::Tls(stream));
371 Poll::Pending
372 }
373 Err(error) => {
374 let error = io::Error::other(error);
375 this.state = StreamState::AcceptError(error.to_string());
376 Poll::Ready(Err(error))
377 }
378 },
379 StreamState::AcceptError(error) => Poll::Ready(Err(io::Error::other(error.clone()))),
380 StreamState::Closed => Poll::Ready(Ok(())),
381 }
382 }
383}
384
/// Selected subject attributes of an X.509 certificate.
#[derive(Debug)]
pub struct CertificateMetadata {
    /// `C` (countryName).
    pub country_name: Option<String>,
    /// `ST` (stateOrProvinceName).
    pub state_or_province_name: Option<String>,
    /// `L` (localityName).
    pub locality_name: Option<String>,
    /// `O` (organizationName).
    pub organization_name: Option<String>,
    /// `OU` (organizationalUnitName).
    pub organizational_unit_name: Option<String>,
    /// `CN` (commonName).
    pub common_name: Option<String>,
}

impl CertificateMetadata {
    /// Renders the subject as comma-separated `KEY=value` pairs.
    ///
    /// Pairs appear in the fixed order CN, OU, O, L, ST, C, and absent attributes
    /// are omitted. Values are inserted verbatim; no escaping is applied. Returns an
    /// empty string when every attribute is absent.
    pub fn subject(&self) -> String {
        let ordered_fields = [
            ("CN", &self.common_name),
            ("OU", &self.organizational_unit_name),
            ("O", &self.organization_name),
            ("L", &self.locality_name),
            ("ST", &self.state_or_province_name),
            ("C", &self.country_name),
        ];

        ordered_fields
            .iter()
            .filter_map(|(key, value)| value.as_deref().map(|text| format!("{key}={text}")))
            .collect::<Vec<String>>()
            .join(",")
    }
}
419
420impl From<X509> for CertificateMetadata {
421 fn from(cert: X509) -> Self {
422 let mut subject_metadata: HashMap<String, String> = HashMap::new();
423 for entry in cert.subject_name().entries() {
424 let data_string = match entry.data().as_utf8() {
425 Ok(data) => data.to_string(),
426 Err(_) => String::new(),
427 };
428 subject_metadata.insert(entry.object().to_string(), data_string);
429 }
430 Self {
431 country_name: subject_metadata.get("countryName").cloned(),
432 state_or_province_name: subject_metadata.get("stateOrProvinceName").cloned(),
433 locality_name: subject_metadata.get("localityName").cloned(),
434 organization_name: subject_metadata.get("organizationName").cloned(),
435 organizational_unit_name: subject_metadata.get("organizationalUnitName").cloned(),
436 common_name: subject_metadata.get("commonName").cloned(),
437 }
438 }
439}
440
441#[derive(Clone)]
442pub struct MaybeTlsConnectInfo {
443 pub remote_addr: SocketAddr,
444 pub peer_certs: Option<Vec<Certificate>>,
445}
446
447impl Connected for MaybeTlsIncomingStream<TcpStream> {
448 type ConnectInfo = MaybeTlsConnectInfo;
449
450 fn connect_info(&self) -> Self::ConnectInfo {
451 MaybeTlsConnectInfo {
452 remote_addr: self.peer_addr(),
453 peer_certs: self
454 .ssl_stream()
455 .and_then(|s| s.ssl().peer_cert_chain())
456 .map(|s| {
457 s.into_iter()
458 .filter_map(|c| c.to_pem().ok())
459 .map(Certificate::from_pem)
460 .collect()
461 }),
462 }
463 }
464}
465
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn certificate_metadata_full() {
        let example_meta = CertificateMetadata {
            common_name: Some("common".into()),
            country_name: Some("country".into()),
            locality_name: Some("locality".into()),
            organization_name: Some("organization".into()),
            organizational_unit_name: Some("org_unit".into()),
            state_or_province_name: Some("state".into()),
        };

        // Every attribute present: rendered in CN, OU, O, L, ST, C order.
        assert_eq!(
            example_meta.subject(),
            "CN=common,OU=org_unit,O=organization,L=locality,ST=state,C=country"
        );
    }

    #[test]
    fn certificate_metadata_partial() {
        let example_meta = CertificateMetadata {
            common_name: Some("common".into()),
            country_name: Some("country".into()),
            locality_name: None,
            organization_name: Some("organization".into()),
            organizational_unit_name: Some("org_unit".into()),
            state_or_province_name: None,
        };

        // Missing attributes are skipped without leaving empty separators.
        assert_eq!(
            example_meta.subject(),
            "CN=common,OU=org_unit,O=organization,C=country"
        );
    }
}