scuffle_http/backend/hyper/mod.rs

1use std::fmt::Debug;
2use std::net::SocketAddr;
3
4use scuffle_context::ContextFutExt;
5#[cfg(feature = "tracing")]
6use tracing::Instrument;
7
8use crate::error::Error;
9use crate::service::{HttpService, HttpServiceFactory};
10
11mod handler;
12mod stream;
13mod utils;
14
/// A backend that handles incoming HTTP connections using a hyper backend.
///
/// This is used internally by the [`HttpServer`](crate::server::HttpServer) but can be used directly if preferred.
///
/// Construct it with the `bon`-generated builder, then call [`run`](HyperBackend::run) to start the server.
#[derive(Debug, Clone, bon::Builder)]
pub struct HyperBackend<F> {
    /// The [`scuffle_context::Context`] this server will live by.
    ///
    /// Defaults to [`scuffle_context::Context::global()`]. The workers run on a child of this
    /// context, so cancelling it makes them stop accepting new connections.
    #[builder(default = scuffle_context::Context::global())]
    ctx: scuffle_context::Context,
    /// The number of worker tasks to spawn for each server backend.
    ///
    /// Each worker runs its own accept loop on a clone of the same bound socket.
    /// Defaults to `1`. With `0`, no connections are accepted and `run` returns right after binding.
    #[builder(default = 1)]
    worker_tasks: usize,
    /// The service factory that will be used to create new services.
    ///
    /// A new service is requested for every accepted connection (after the TLS handshake,
    /// when TLS is enabled), and receives the peer address.
    service_factory: F,
    /// The address to bind to.
    ///
    /// Use `[::]` for a dual-stack listener.
    /// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6.
    // NOTE(review): the socket is bound with `std::net::TcpListener::bind`, which does not set
    // IPV6_V6ONLY explicitly, so whether `[::]` also accepts IPv4 depends on the platform default — confirm
    // for non-Linux targets.
    bind: SocketAddr,
    /// rustls config.
    ///
    /// Use this field to set the server into TLS mode.
    /// It will only accept TLS connections when this is set; connections whose TLS handshake
    /// fails are dropped.
    ///
    /// Note that `run` forces `max_early_data_size` to `0` on this config
    /// (see <https://github.com/hyperium/hyper/issues/3841>).
    #[cfg(feature = "tls-rustls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
    rustls_config: Option<rustls::ServerConfig>,
    /// Enable HTTP/1.1.
    ///
    /// Defaults to `true`. When the `http1` feature is disabled, HTTP/1.1 is always off.
    #[cfg(feature = "http1")]
    #[cfg_attr(docsrs, doc(cfg(feature = "http1")))]
    #[builder(default = true)]
    http1_enabled: bool,
    /// Enable HTTP/2.
    ///
    /// Defaults to `true`. When the `http2` feature is disabled, HTTP/2 is always off.
    #[cfg(feature = "http2")]
    #[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
    #[builder(default = true)]
    http2_enabled: bool,
}
53
54impl<F> HyperBackend<F>
55where
56    F: HttpServiceFactory + Clone + Send + 'static,
57    F::Error: std::error::Error + Send,
58    F::Service: Clone + Send + 'static,
59    <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
60    <F::Service as HttpService>::ResBody: Send,
61    <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
62    <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
63{
64    /// Run the HTTP server
65    ///
66    /// This function will bind to the address specified in `bind`, listen for incoming connections and handle requests.
67    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
68    #[allow(unused_mut)] // allow the unused `mut self`
69    pub async fn run(mut self) -> Result<(), Error<F>> {
70        #[cfg(feature = "tracing")]
71        tracing::debug!("starting server");
72
73        // reset to 0 because everything explodes if it's not
74        // https://github.com/hyperium/hyper/issues/3841
75        #[cfg(feature = "tls-rustls")]
76        if let Some(rustls_config) = self.rustls_config.as_mut() {
77            rustls_config.max_early_data_size = 0;
78        }
79
80        // We have to create an std listener first because the tokio listener isn't clonable
81        let std_listener = std::net::TcpListener::bind(self.bind)?;
82        // Set nonblocking so we can use it in the async runtime
83        // This should be the default when converting to a tokio listener
84        std_listener.set_nonblocking(true)?;
85
86        #[cfg(feature = "tls-rustls")]
87        let tls_acceptor = self
88            .rustls_config
89            .map(|c| tokio_rustls::TlsAcceptor::from(std::sync::Arc::new(c)));
90
91        // Create a child context for the workers so we can shut them down if one of them fails without shutting down the main context
92        let (worker_ctx, worker_handler) = self.ctx.new_child();
93
94        let workers = (0..self.worker_tasks).map(|_n| {
95            let service_factory = self.service_factory.clone();
96            let ctx = worker_ctx.clone();
97            let std_listener = std_listener.try_clone().expect("failed to clone listener");
98            let listener = tokio::net::TcpListener::from_std(std_listener).expect("failed to create tokio listener");
99            #[cfg(feature = "tls-rustls")]
100            let tls_acceptor = tls_acceptor.clone();
101
102            let worker_fut = async move {
103                loop {
104                    #[cfg(feature = "tracing")]
105                    tracing::trace!("waiting for connections");
106
107                    let (mut stream, addr) = match listener.accept().with_context(ctx.clone()).await {
108                        Some(Ok((tcp_stream, addr))) => (stream::Stream::Tcp(tcp_stream), addr),
109                        Some(Err(e)) if utils::is_fatal_tcp_error(&e) => {
110                            #[cfg(feature = "tracing")]
111                            tracing::error!(err = %e, "failed to accept tcp connection");
112                            return Err(Error::<F>::from(e));
113                        }
114                        Some(Err(_)) => continue,
115                        None => {
116                            #[cfg(feature = "tracing")]
117                            tracing::trace!("context done, stopping listener");
118                            break;
119                        }
120                    };
121
122                    #[cfg(feature = "tracing")]
123                    tracing::trace!(addr = %addr, "accepted tcp connection");
124
125                    let ctx = ctx.clone();
126                    #[cfg(feature = "tls-rustls")]
127                    let tls_acceptor = tls_acceptor.clone();
128                    let mut service_factory = service_factory.clone();
129
130                    let connection_fut = async move {
131                        // Perform the TLS handshake if the acceptor is set
132                        #[cfg(feature = "tls-rustls")]
133                        if let Some(tls_acceptor) = tls_acceptor {
134                            #[cfg(feature = "tracing")]
135                            tracing::trace!("accepting tls connection");
136
137                            stream = match stream.try_accept_tls(&tls_acceptor).with_context(&ctx).await {
138                                Some(Ok(stream)) => stream,
139                                Some(Err(_err)) => {
140                                    #[cfg(feature = "tracing")]
141                                    tracing::warn!(err = %_err, "failed to accept tls connection");
142                                    return;
143                                }
144                                None => {
145                                    #[cfg(feature = "tracing")]
146                                    tracing::trace!("context done, stopping tls acceptor");
147                                    return;
148                                }
149                            };
150
151                            #[cfg(feature = "tracing")]
152                            tracing::trace!("accepted tls connection");
153                        }
154
155                        // make a new service
156                        let http_service = match service_factory.new_service(addr).await {
157                            Ok(service) => service,
158                            Err(_e) => {
159                                #[cfg(feature = "tracing")]
160                                tracing::warn!(err = %_e, "failed to create service");
161                                return;
162                            }
163                        };
164
165                        #[cfg(feature = "tracing")]
166                        tracing::trace!("handling connection");
167
168                        #[cfg(feature = "http1")]
169                        let http1 = self.http1_enabled;
170                        #[cfg(not(feature = "http1"))]
171                        let http1 = false;
172
173                        #[cfg(feature = "http2")]
174                        let http2 = self.http2_enabled;
175                        #[cfg(not(feature = "http2"))]
176                        let http2 = false;
177
178                        let _res = handler::handle_connection::<F, _, _>(ctx, http_service, stream, http1, http2).await;
179
180                        #[cfg(feature = "tracing")]
181                        if let Err(e) = _res {
182                            tracing::warn!(err = %e, "error handling connection");
183                        }
184
185                        #[cfg(feature = "tracing")]
186                        tracing::trace!("connection closed");
187                    };
188
189                    #[cfg(feature = "tracing")]
190                    let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
191
192                    tokio::spawn(connection_fut);
193                }
194
195                #[cfg(feature = "tracing")]
196                tracing::trace!("listener closed");
197
198                Ok(())
199            };
200
201            #[cfg(feature = "tracing")]
202            let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
203
204            tokio::spawn(worker_fut)
205        });
206
207        match futures::future::try_join_all(workers).await {
208            Ok(res) => {
209                for r in res {
210                    if let Err(e) = r {
211                        drop(worker_ctx);
212                        worker_handler.shutdown().await;
213                        return Err(e);
214                    }
215                }
216            }
217            Err(_e) => {
218                #[cfg(feature = "tracing")]
219                tracing::error!(err = %_e, "error running workers");
220            }
221        }
222
223        drop(worker_ctx);
224        worker_handler.shutdown().await;
225
226        #[cfg(feature = "tracing")]
227        tracing::debug!("all workers finished");
228
229        Ok(())
230    }
231}