// scuffle_http/backend/h3/mod.rs

1use std::fmt::Debug;
2use std::io;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use body::QuicIncomingBody;
7use scuffle_context::ContextFutExt;
8#[cfg(feature = "tracing")]
9use tracing::Instrument;
10use utils::copy_response_body;
11
12use crate::error::Error;
13use crate::service::{HttpService, HttpServiceFactory};
14
15pub mod body;
16mod utils;
17
18/// A backend that handles incoming HTTP3 connections.
19///
20/// This is used internally by the [`HttpServer`](crate::server::HttpServer) but can be used directly if preferred.
21///
22/// Call [`run`](Http3Backend::run) to start the server.
23#[derive(bon::Builder, Debug, Clone)]
24pub struct Http3Backend<F> {
25    /// The [`scuffle_context::Context`] this server will live by.
26    #[builder(default = scuffle_context::Context::global())]
27    ctx: scuffle_context::Context,
28    /// The number of worker tasks to spawn for each server backend.
29    #[builder(default = 1)]
30    worker_tasks: usize,
31    /// The service factory that will be used to create new services.
32    service_factory: F,
33    /// The address to bind to.
34    ///
35    /// Use `[::]` for a dual-stack listener.
36    /// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6.
37    bind: SocketAddr,
38    /// rustls config.
39    ///
40    /// Use this field to set the server into TLS mode.
41    /// It will only accept TLS connections when this is set.
42    rustls_config: rustls::ServerConfig,
43}
44
45impl<F> Http3Backend<F>
46where
47    F: HttpServiceFactory + Clone + Send + 'static,
48    F::Error: std::error::Error + Send,
49    F::Service: Clone + Send + 'static,
50    <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
51    <F::Service as HttpService>::ResBody: Send,
52    <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
53    <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
54{
55    /// Run the HTTP3 server
56    ///
57    /// This function will bind to the address specified in `bind`, listen for incoming connections and handle requests.
58    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
59    pub async fn run(mut self) -> Result<(), Error<F>> {
60        #[cfg(feature = "tracing")]
61        tracing::debug!("starting server");
62
63        // not quite sure why this is necessary but it is
64        self.rustls_config.max_early_data_size = u32::MAX;
65        let crypto = h3_quinn::quinn::crypto::rustls::QuicServerConfig::try_from(self.rustls_config)?;
66        let server_config = h3_quinn::quinn::ServerConfig::with_crypto(Arc::new(crypto));
67
68        // Bind the UDP socket
69        let socket = std::net::UdpSocket::bind(self.bind)?;
70
71        // Runtime for the quinn endpoint
72        let runtime = h3_quinn::quinn::default_runtime()
73            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no async runtime found"))?;
74
75        // Create a child context for the workers so we can shut them down if one of them fails without shutting down the main context
76        let (worker_ctx, worker_handler) = self.ctx.new_child();
77
78        let workers = (0..self.worker_tasks).map(|_n| {
79            let ctx = worker_ctx.clone();
80            let service_factory = self.service_factory.clone();
81            let server_config = server_config.clone();
82            let socket = socket.try_clone().expect("failed to clone socket");
83            let runtime = Arc::clone(&runtime);
84
85            let worker_fut = async move {
86                let endpoint = h3_quinn::quinn::Endpoint::new(
87                    h3_quinn::quinn::EndpointConfig::default(),
88                    Some(server_config),
89                    socket,
90                    runtime,
91                )?;
92
93                #[cfg(feature = "tracing")]
94                tracing::trace!("waiting for connections");
95
96                while let Some(Some(new_conn)) = endpoint.accept().with_context(&ctx).await {
97                    let mut service_factory = service_factory.clone();
98                    let ctx = ctx.clone();
99
100                    tokio::spawn(async move {
101                        let _res: Result<_, Error<F>> = async move {
102                            let Some(conn) = new_conn.with_context(&ctx).await.transpose()? else {
103                                #[cfg(feature = "tracing")]
104                                tracing::trace!("context done while accepting connection");
105                                return Ok(());
106                            };
107                            let addr = conn.remote_address();
108
109                            #[cfg(feature = "tracing")]
110                            tracing::debug!(addr = %addr, "accepted quic connection");
111
112                            let connection_fut = async move {
113                                let Some(mut h3_conn) = h3::server::Connection::new(h3_quinn::Connection::new(conn))
114                                    .with_context(&ctx)
115                                    .await
116                                    .transpose()?
117                                else {
118                                    #[cfg(feature = "tracing")]
119                                    tracing::trace!("context done while establishing connection");
120                                    return Ok(());
121                                };
122
123                                // make a new service for this connection
124                                let http_service = service_factory
125                                    .new_service(addr)
126                                    .await
127                                    .map_err(|e| Error::ServiceFactoryError(e))?;
128
129                                loop {
130                                    match h3_conn.accept().with_context(&ctx).await {
131                                        Some(Ok(Some((req, stream)))) => {
132                                            #[cfg(feature = "tracing")]
133                                            tracing::debug!(method = %req.method(), uri = %req.uri(), "received request");
134
135                                            let (mut send, recv) = stream.split();
136
137                                            let size_hint = req
138                                                .headers()
139                                                .get(http::header::CONTENT_LENGTH)
140                                                .and_then(|len| len.to_str().ok().and_then(|x| x.parse().ok()));
141                                            let body = QuicIncomingBody::new(recv, size_hint);
142                                            let req = req.map(|_| crate::body::IncomingBody::from(body));
143
144                                            let ctx = ctx.clone();
145                                            let mut http_service = http_service.clone();
146                                            tokio::spawn(async move {
147                                                let _res: Result<_, Error<F>> = async move {
148                                                    let resp =
149                                                        http_service.call(req).await.map_err(|e| Error::ServiceError(e))?;
150                                                    let (parts, body) = resp.into_parts();
151
152                                                    send.send_response(http::Response::from_parts(parts, ())).await?;
153                                                    copy_response_body(send, body).await?;
154
155                                                    Ok(())
156                                                }
157                                                .await;
158
159                                                #[cfg(feature = "tracing")]
160                                                if let Err(e) = _res {
161                                                    tracing::warn!(err = %e, "error handling request");
162                                                }
163
164                                                // This moves the context into the async block because it is dropped here
165                                                drop(ctx);
166                                            });
167                                        }
168                                        // indicating no more streams to be received
169                                        Some(Ok(None)) => {
170                                            break;
171                                        }
172                                        Some(Err(err)) => match err.get_error_level() {
173                                            h3::error::ErrorLevel::ConnectionError => return Err(err.into()),
174                                            h3::error::ErrorLevel::StreamError => {
175                                                #[cfg(feature = "tracing")]
176                                                tracing::warn!("error on accept: {}", err);
177                                                continue;
178                                            }
179                                        },
180                                        // context is done
181                                        None => {
182                                            #[cfg(feature = "tracing")]
183                                            tracing::trace!("context done, stopping connection loop");
184                                            break;
185                                        }
186                                    }
187                                }
188
189                                #[cfg(feature = "tracing")]
190                                tracing::trace!("connection closed");
191
192                                Ok(())
193                            };
194
195                            #[cfg(feature = "tracing")]
196                            let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
197
198                            connection_fut.await
199                        }
200                        .await;
201
202                        #[cfg(feature = "tracing")]
203                        if let Err(err) = _res {
204                            tracing::warn!(err = %err, "error handling connection");
205                        }
206                    });
207                }
208
209                // shut down gracefully
210                // wait for connections to be closed before exiting
211                endpoint.wait_idle().await;
212
213                Ok::<_, crate::error::Error<F>>(())
214            };
215
216            #[cfg(feature = "tracing")]
217            let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
218
219            tokio::spawn(worker_fut)
220        });
221
222        if let Err(_e) = futures::future::try_join_all(workers).await {
223            #[cfg(feature = "tracing")]
224            tracing::error!(err = %_e, "error running workers");
225        }
226
227        drop(worker_ctx);
228        worker_handler.shutdown().await;
229
230        #[cfg(feature = "tracing")]
231        tracing::debug!("all workers finished");
232
233        Ok(())
234    }
235}