1
use axum::serve::Listener;
2
use std::fs;
3
use std::net::SocketAddr;
4
use tokio::net::{TcpListener, TcpStream};
5
use tokio::sync::mpsc;
6
use tokio_native_tls::{
7
    TlsAcceptor,
8
    native_tls::{Identity, Protocol, TlsAcceptor as NativeTlsAcceptor},
9
};
10
use tracing::{error, warn};
11

            
12
type TlsStream = (tokio_native_tls::TlsStream<TcpStream>, SocketAddr);
13

            
14
/// Wrapper around a TcpListener that handles TLS encryption/decryption for incoming connections.
15
pub struct TlsListener {
16
    stream_rx: mpsc::Receiver<TlsStream>,
17
    addr: SocketAddr,
18
}
19

            
20
impl TlsListener {
21
    /// Read the key and cert files, bind to the given socket and handle decryption/encryption for incoming traffic.
22
    pub async fn bind(addr: SocketAddr, key: &str, cert: &str) -> Result<Self, TlsError> {
23
        let key = fs::read(key).map_err(TlsError::ReadKeyError)?;
24
        let cert = fs::read(cert).map_err(TlsError::ReadCertError)?;
25

            
26
        let id = Identity::from_pkcs8(&cert, &key).map_err(TlsError::CreateIdentityError)?;
27

            
28
        let tls_acceptor = NativeTlsAcceptor::builder(id)
29
            .min_protocol_version(Some(Protocol::Tlsv12))
30
            .build()
31
            .map_err(TlsError::CreateAcceptorError)?;
32

            
33
        let tls_acceptor = TlsAcceptor::from(tls_acceptor);
34

            
35
        let mut listener = TcpListener::bind(addr)
36
            .await
37
            .map_err(TlsError::FailedToBindListener)?;
38

            
39
        let addr = listener
40
            .local_addr()
41
            .map_err(TlsError::FailedToBindListener)?;
42

            
43
        let (stream_tx, stream_rx) = mpsc::channel(10);
44

            
45
        tokio::spawn(async move {
46
            loop {
47
                let tls_acceptor = tls_acceptor.clone();
48
                let stream_tx = stream_tx.clone();
49

            
50
                let (stream, addr) = Listener::accept(&mut listener).await;
51
                match tls_acceptor.accept(stream).await {
52
                    Ok(stream) => stream_tx.send((stream, addr)).await.unwrap_or_else(|e| {
53
                        error!("Failed to send stream to listener: {e}");
54
                    }),
55
                    Err(e) => {
56
                        warn!("Error during TLS handshake: {e}");
57
                    }
58
                };
59
            }
60
        });
61

            
62
        Ok(Self { stream_rx, addr })
63
    }
64
}
65

            
66
impl Listener for TlsListener {
67
    type Io = tokio_native_tls::TlsStream<TcpStream>;
68
    type Addr = SocketAddr;
69

            
70
    async fn accept(&mut self) -> TlsStream {
71
        self.stream_rx
72
            .recv()
73
            .await
74
            .expect("TlsListener channel should not close before shutdown")
75
    }
76

            
77
    fn local_addr(&self) -> tokio::io::Result<Self::Addr> {
78
        Ok(self.addr)
79
    }
80
}
81

            
82
#[derive(Debug)]
83
pub enum TlsError {
84
    ReadKeyError(std::io::Error),
85
    ReadCertError(std::io::Error),
86
    CreateIdentityError(tokio_native_tls::native_tls::Error),
87
    CreateAcceptorError(tokio_native_tls::native_tls::Error),
88
    FailedToBindListener(std::io::Error),
89
}
90

            
91
impl std::fmt::Display for TlsError {
92
3
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93
3
        match self {
94
1
            TlsError::ReadKeyError(e) => write!(f, "Failed to read SSL key file: {e}"),
95
1
            TlsError::ReadCertError(e) => write!(f, "Failed to read SSL cert file: {e}"),
96
            TlsError::CreateIdentityError(e) => write!(f, "Failed to create SSL identity: {e}"),
97
            TlsError::CreateAcceptorError(e) => write!(f, "Failed to create SSL acceptor: {e}"),
98
1
            TlsError::FailedToBindListener(e) => write!(f, "Failed to bind listener: {e}"),
99
        }
100
3
    }
101
}
102

            
103
impl std::error::Error for TlsError {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    /// Runs `fut` to completion on a fresh current-thread Tokio runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build Tokio runtime")
            .block_on(fut)
    }

    #[test]
    fn test_tls_error_display_read_key_error() {
        let io_error = io::Error::new(io::ErrorKind::NotFound, "key file not found");
        let display_string = TlsError::ReadKeyError(io_error).to_string();
        assert!(display_string.contains("Failed to read SSL key file"));
        assert!(display_string.contains("key file not found"));
    }

    #[test]
    fn test_tls_error_display_read_cert_error() {
        let io_error = io::Error::new(io::ErrorKind::PermissionDenied, "permission denied");
        let display_string = TlsError::ReadCertError(io_error).to_string();
        assert!(display_string.contains("Failed to read SSL cert file"));
        assert!(display_string.contains("permission denied"));
    }

    #[test]
    fn test_tls_error_display_failed_to_bind_listener() {
        let io_error = io::Error::new(io::ErrorKind::AddrInUse, "address already in use");
        let display_string = TlsError::FailedToBindListener(io_error).to_string();
        assert!(display_string.contains("Failed to bind listener"));
        assert!(display_string.contains("address already in use"));
    }

    #[test]
    fn test_tls_error_implements_error_trait() {
        let error = TlsError::ReadKeyError(io::Error::new(io::ErrorKind::NotFound, "test"));
        // Compiles only if TlsError implements std::error::Error.
        let _: &dyn std::error::Error = &error;
    }

    #[test]
    fn test_tls_error_debug() {
        let error = TlsError::ReadKeyError(io::Error::new(io::ErrorKind::NotFound, "test error"));
        let debug_string = format!("{error:?}");
        assert!(debug_string.contains("ReadKeyError"));
    }

    #[test]
    fn test_bind_missing_key_file_returns_read_key_error() {
        let addr = "127.0.0.1:0".parse().expect("valid socket address");
        let result = block_on(TlsListener::bind(
            addr,
            "/nonexistent/tls_listener_key.pem",
            "/nonexistent/tls_listener_cert.pem",
        ));
        assert!(matches!(result, Err(TlsError::ReadKeyError(_))));
    }

    #[test]
    fn test_bind_missing_cert_file_returns_read_cert_error() {
        // The key only has to be readable: the cert read fails before it is ever parsed.
        let key_path = std::env::temp_dir().join(format!(
            "tls_listener_test_key_{}.pem",
            std::process::id()
        ));
        std::fs::write(&key_path, b"not a real key").expect("failed to write temporary key");

        let addr = "127.0.0.1:0".parse().expect("valid socket address");
        let result = block_on(TlsListener::bind(
            addr,
            key_path.to_str().expect("temp path is valid UTF-8"),
            "/nonexistent/tls_listener_cert.pem",
        ));
        let _ = std::fs::remove_file(&key_path);

        assert!(matches!(result, Err(TlsError::ReadCertError(_))));
    }
}