1#![allow(missing_docs)]
2use std::{
3 net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs},
4 task::{Context, Poll},
5};
6
7use futures::{future::BoxFuture, FutureExt};
8use hyper::client::connect::dns::Name;
9use snafu::ResultExt;
10use tokio::task::spawn_blocking;
11use tower::Service;
12
/// The addresses produced by a single name lookup.
///
/// Wraps the socket addresses returned by the resolver. Iterating yields only
/// their IP part; the port carried by each address is discarded.
#[derive(Debug)]
pub struct LookupIp(std::vec::IntoIter<SocketAddr>);
14
15#[derive(Debug, Clone, Copy)]
16pub(super) struct Resolver;
17
18impl Resolver {
19 pub(crate) async fn lookup_ip(self, name: String) -> Result<LookupIp, DnsError> {
20 let dummy_port = 9;
26 if name == "localhost" {
28 Ok(LookupIp(
31 vec![SocketAddr::new(Ipv4Addr::LOCALHOST.into(), dummy_port)].into_iter(),
32 ))
33 } else {
34 spawn_blocking(move || {
35 let name_ref = match name.as_str() {
36 name if name.starts_with('[') && name.ends_with(']') => {
38 &name[1..name.len() - 1]
39 }
40 name => name,
41 };
42 (name_ref, dummy_port).to_socket_addrs()
43 })
44 .await
45 .context(JoinSnafu)?
46 .map(LookupIp)
47 .context(UnableLookupSnafu)
48 }
49 }
50}
51
52impl Iterator for LookupIp {
53 type Item = IpAddr;
54
55 fn next(&mut self) -> Option<Self::Item> {
56 self.0.next().map(|address| address.ip())
57 }
58}
59
60impl Service<Name> for Resolver {
61 type Response = LookupIp;
62 type Error = DnsError;
63 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
64
65 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
66 Ok(()).into()
67 }
68
69 fn call(&mut self, name: Name) -> Self::Future {
70 self.lookup_ip(name.as_str().to_owned()).boxed()
71 }
72}
73
74#[derive(Debug, snafu::Snafu)]
75pub enum DnsError {
76 #[snafu(display("Unable to resolve name: {}", source))]
77 UnableLookup { source: tokio::io::Error },
78 #[snafu(display("Failed to join with resolving future: {}", source))]
79 JoinError { source: tokio::task::JoinError },
80}
81
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use super::Resolver;

    /// Returns `true` if resolving `name` succeeds.
    async fn resolve(name: &str) -> bool {
        let resolver = Resolver;
        resolver.lookup_ip(name.to_owned()).await.is_ok()
    }

    /// Resolves `name` and collects every returned IP, panicking if the lookup fails.
    async fn resolve_all(name: &str) -> Vec<IpAddr> {
        Resolver
            .lookup_ip(name.to_owned())
            .await
            .expect("lookup should succeed")
            .collect()
    }

    // NOTE(review): this presumably needs a working system DNS resolver and
    // network access, unlike the other tests here.
    #[tokio::test]
    async fn resolve_example() {
        assert!(resolve("example.com").await);
    }

    #[tokio::test]
    async fn resolve_localhost() {
        assert!(resolve("localhost").await);
    }

    /// `localhost` is short-circuited to exactly the IPv4 loopback address.
    #[tokio::test]
    async fn resolve_localhost_is_ipv4_loopback_only() {
        assert_eq!(
            resolve_all("localhost").await,
            vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]
        );
    }

    #[tokio::test]
    async fn resolve_ipv4() {
        assert!(resolve("10.0.4.0").await);
    }

    #[tokio::test]
    async fn resolve_ipv6() {
        assert!(resolve("::1").await);
    }

    /// Square brackets around an IPv6 literal (URI form) are stripped before lookup.
    #[tokio::test]
    async fn resolve_bracketed_ipv6() {
        assert_eq!(
            resolve_all("[::1]").await,
            vec![IpAddr::V6(Ipv6Addr::LOCALHOST)]
        );
    }
}