scuffle_http/backend/h3/
mod.rs1use std::fmt::Debug;
2use std::io;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use body::QuicIncomingBody;
7use scuffle_context::ContextFutExt;
8#[cfg(feature = "tracing")]
9use tracing::Instrument;
10use utils::copy_response_body;
11
12use crate::error::Error;
13use crate::service::{HttpService, HttpServiceFactory};
14
15pub mod body;
16mod utils;
17
18#[derive(bon::Builder, Debug, Clone)]
24pub struct Http3Backend<F> {
25 #[builder(default = scuffle_context::Context::global())]
27 ctx: scuffle_context::Context,
28 #[builder(default = 1)]
30 worker_tasks: usize,
31 service_factory: F,
33 bind: SocketAddr,
38 rustls_config: rustls::ServerConfig,
43}
44
45impl<F> Http3Backend<F>
46where
47 F: HttpServiceFactory + Clone + Send + 'static,
48 F::Error: std::error::Error + Send,
49 F::Service: Clone + Send + 'static,
50 <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
51 <F::Service as HttpService>::ResBody: Send,
52 <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
53 <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
54{
55 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
59 pub async fn run(mut self) -> Result<(), Error<F>> {
60 #[cfg(feature = "tracing")]
61 tracing::debug!("starting server");
62
63 self.rustls_config.max_early_data_size = u32::MAX;
65 let crypto = h3_quinn::quinn::crypto::rustls::QuicServerConfig::try_from(self.rustls_config)?;
66 let server_config = h3_quinn::quinn::ServerConfig::with_crypto(Arc::new(crypto));
67
68 let socket = std::net::UdpSocket::bind(self.bind)?;
70
71 let runtime = h3_quinn::quinn::default_runtime()
73 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no async runtime found"))?;
74
75 let (worker_ctx, worker_handler) = self.ctx.new_child();
77
78 let workers = (0..self.worker_tasks).map(|_n| {
79 let ctx = worker_ctx.clone();
80 let service_factory = self.service_factory.clone();
81 let server_config = server_config.clone();
82 let socket = socket.try_clone().expect("failed to clone socket");
83 let runtime = Arc::clone(&runtime);
84
85 let worker_fut = async move {
86 let endpoint = h3_quinn::quinn::Endpoint::new(
87 h3_quinn::quinn::EndpointConfig::default(),
88 Some(server_config),
89 socket,
90 runtime,
91 )?;
92
93 #[cfg(feature = "tracing")]
94 tracing::trace!("waiting for connections");
95
96 while let Some(Some(new_conn)) = endpoint.accept().with_context(&ctx).await {
97 let mut service_factory = service_factory.clone();
98 let ctx = ctx.clone();
99
100 tokio::spawn(async move {
101 let _res: Result<_, Error<F>> = async move {
102 let Some(conn) = new_conn.with_context(&ctx).await.transpose()? else {
103 #[cfg(feature = "tracing")]
104 tracing::trace!("context done while accepting connection");
105 return Ok(());
106 };
107 let addr = conn.remote_address();
108
109 #[cfg(feature = "tracing")]
110 tracing::debug!(addr = %addr, "accepted quic connection");
111
112 let connection_fut = async move {
113 let Some(mut h3_conn) = h3::server::Connection::new(h3_quinn::Connection::new(conn))
114 .with_context(&ctx)
115 .await
116 .transpose()?
117 else {
118 #[cfg(feature = "tracing")]
119 tracing::trace!("context done while establishing connection");
120 return Ok(());
121 };
122
123 let http_service = service_factory
125 .new_service(addr)
126 .await
127 .map_err(|e| Error::ServiceFactoryError(e))?;
128
129 loop {
130 match h3_conn.accept().with_context(&ctx).await {
131 Some(Ok(Some((req, stream)))) => {
132 #[cfg(feature = "tracing")]
133 tracing::debug!(method = %req.method(), uri = %req.uri(), "received request");
134
135 let (mut send, recv) = stream.split();
136
137 let size_hint = req
138 .headers()
139 .get(http::header::CONTENT_LENGTH)
140 .and_then(|len| len.to_str().ok().and_then(|x| x.parse().ok()));
141 let body = QuicIncomingBody::new(recv, size_hint);
142 let req = req.map(|_| crate::body::IncomingBody::from(body));
143
144 let ctx = ctx.clone();
145 let mut http_service = http_service.clone();
146 tokio::spawn(async move {
147 let _res: Result<_, Error<F>> = async move {
148 let resp =
149 http_service.call(req).await.map_err(|e| Error::ServiceError(e))?;
150 let (parts, body) = resp.into_parts();
151
152 send.send_response(http::Response::from_parts(parts, ())).await?;
153 copy_response_body(send, body).await?;
154
155 Ok(())
156 }
157 .await;
158
159 #[cfg(feature = "tracing")]
160 if let Err(e) = _res {
161 tracing::warn!(err = %e, "error handling request");
162 }
163
164 drop(ctx);
166 });
167 }
168 Some(Ok(None)) => {
170 break;
171 }
172 Some(Err(err)) => match err.get_error_level() {
173 h3::error::ErrorLevel::ConnectionError => return Err(err.into()),
174 h3::error::ErrorLevel::StreamError => {
175 #[cfg(feature = "tracing")]
176 tracing::warn!("error on accept: {}", err);
177 continue;
178 }
179 },
180 None => {
182 #[cfg(feature = "tracing")]
183 tracing::trace!("context done, stopping connection loop");
184 break;
185 }
186 }
187 }
188
189 #[cfg(feature = "tracing")]
190 tracing::trace!("connection closed");
191
192 Ok(())
193 };
194
195 #[cfg(feature = "tracing")]
196 let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
197
198 connection_fut.await
199 }
200 .await;
201
202 #[cfg(feature = "tracing")]
203 if let Err(err) = _res {
204 tracing::warn!(err = %err, "error handling connection");
205 }
206 });
207 }
208
209 endpoint.wait_idle().await;
212
213 Ok::<_, crate::error::Error<F>>(())
214 };
215
216 #[cfg(feature = "tracing")]
217 let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
218
219 tokio::spawn(worker_fut)
220 });
221
222 if let Err(_e) = futures::future::try_join_all(workers).await {
223 #[cfg(feature = "tracing")]
224 tracing::error!(err = %_e, "error running workers");
225 }
226
227 drop(worker_ctx);
228 worker_handler.shutdown().await;
229
230 #[cfg(feature = "tracing")]
231 tracing::debug!("all workers finished");
232
233 Ok(())
234 }
235}