scuffle_http/backend/hyper/
mod.rs1use std::fmt::Debug;
2use std::net::SocketAddr;
3
4use scuffle_context::ContextFutExt;
5#[cfg(feature = "tracing")]
6use tracing::Instrument;
7
8use crate::error::Error;
9use crate::service::{HttpService, HttpServiceFactory};
10
11mod handler;
12mod stream;
13mod utils;
14
15#[derive(Debug, Clone, bon::Builder)]
21pub struct HyperBackend<F> {
22 #[builder(default = scuffle_context::Context::global())]
24 ctx: scuffle_context::Context,
25 #[builder(default = 1)]
27 worker_tasks: usize,
28 service_factory: F,
30 bind: SocketAddr,
35 #[cfg(feature = "tls-rustls")]
40 #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
41 rustls_config: Option<rustls::ServerConfig>,
42 #[cfg(feature = "http1")]
44 #[cfg_attr(docsrs, doc(cfg(feature = "http1")))]
45 #[builder(default = true)]
46 http1_enabled: bool,
47 #[cfg(feature = "http2")]
49 #[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
50 #[builder(default = true)]
51 http2_enabled: bool,
52}
53
54impl<F> HyperBackend<F>
55where
56 F: HttpServiceFactory + Clone + Send + 'static,
57 F::Error: std::error::Error + Send,
58 F::Service: Clone + Send + 'static,
59 <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
60 <F::Service as HttpService>::ResBody: Send,
61 <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
62 <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
63{
64 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
68 #[allow(unused_mut)] pub async fn run(mut self) -> Result<(), Error<F>> {
70 #[cfg(feature = "tracing")]
71 tracing::debug!("starting server");
72
73 #[cfg(feature = "tls-rustls")]
76 if let Some(rustls_config) = self.rustls_config.as_mut() {
77 rustls_config.max_early_data_size = 0;
78 }
79
80 let std_listener = std::net::TcpListener::bind(self.bind)?;
82 std_listener.set_nonblocking(true)?;
85
86 #[cfg(feature = "tls-rustls")]
87 let tls_acceptor = self
88 .rustls_config
89 .map(|c| tokio_rustls::TlsAcceptor::from(std::sync::Arc::new(c)));
90
91 let (worker_ctx, worker_handler) = self.ctx.new_child();
93
94 let workers = (0..self.worker_tasks).map(|_n| {
95 let service_factory = self.service_factory.clone();
96 let ctx = worker_ctx.clone();
97 let std_listener = std_listener.try_clone().expect("failed to clone listener");
98 let listener = tokio::net::TcpListener::from_std(std_listener).expect("failed to create tokio listener");
99 #[cfg(feature = "tls-rustls")]
100 let tls_acceptor = tls_acceptor.clone();
101
102 let worker_fut = async move {
103 loop {
104 #[cfg(feature = "tracing")]
105 tracing::trace!("waiting for connections");
106
107 let (mut stream, addr) = match listener.accept().with_context(ctx.clone()).await {
108 Some(Ok((tcp_stream, addr))) => (stream::Stream::Tcp(tcp_stream), addr),
109 Some(Err(e)) if utils::is_fatal_tcp_error(&e) => {
110 #[cfg(feature = "tracing")]
111 tracing::error!(err = %e, "failed to accept tcp connection");
112 return Err(Error::<F>::from(e));
113 }
114 Some(Err(_)) => continue,
115 None => {
116 #[cfg(feature = "tracing")]
117 tracing::trace!("context done, stopping listener");
118 break;
119 }
120 };
121
122 #[cfg(feature = "tracing")]
123 tracing::trace!(addr = %addr, "accepted tcp connection");
124
125 let ctx = ctx.clone();
126 #[cfg(feature = "tls-rustls")]
127 let tls_acceptor = tls_acceptor.clone();
128 let mut service_factory = service_factory.clone();
129
130 let connection_fut = async move {
131 #[cfg(feature = "tls-rustls")]
133 if let Some(tls_acceptor) = tls_acceptor {
134 #[cfg(feature = "tracing")]
135 tracing::trace!("accepting tls connection");
136
137 stream = match stream.try_accept_tls(&tls_acceptor).with_context(&ctx).await {
138 Some(Ok(stream)) => stream,
139 Some(Err(_err)) => {
140 #[cfg(feature = "tracing")]
141 tracing::warn!(err = %_err, "failed to accept tls connection");
142 return;
143 }
144 None => {
145 #[cfg(feature = "tracing")]
146 tracing::trace!("context done, stopping tls acceptor");
147 return;
148 }
149 };
150
151 #[cfg(feature = "tracing")]
152 tracing::trace!("accepted tls connection");
153 }
154
155 let http_service = match service_factory.new_service(addr).await {
157 Ok(service) => service,
158 Err(_e) => {
159 #[cfg(feature = "tracing")]
160 tracing::warn!(err = %_e, "failed to create service");
161 return;
162 }
163 };
164
165 #[cfg(feature = "tracing")]
166 tracing::trace!("handling connection");
167
168 #[cfg(feature = "http1")]
169 let http1 = self.http1_enabled;
170 #[cfg(not(feature = "http1"))]
171 let http1 = false;
172
173 #[cfg(feature = "http2")]
174 let http2 = self.http2_enabled;
175 #[cfg(not(feature = "http2"))]
176 let http2 = false;
177
178 let _res = handler::handle_connection::<F, _, _>(ctx, http_service, stream, http1, http2).await;
179
180 #[cfg(feature = "tracing")]
181 if let Err(e) = _res {
182 tracing::warn!(err = %e, "error handling connection");
183 }
184
185 #[cfg(feature = "tracing")]
186 tracing::trace!("connection closed");
187 };
188
189 #[cfg(feature = "tracing")]
190 let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
191
192 tokio::spawn(connection_fut);
193 }
194
195 #[cfg(feature = "tracing")]
196 tracing::trace!("listener closed");
197
198 Ok(())
199 };
200
201 #[cfg(feature = "tracing")]
202 let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
203
204 tokio::spawn(worker_fut)
205 });
206
207 match futures::future::try_join_all(workers).await {
208 Ok(res) => {
209 for r in res {
210 if let Err(e) = r {
211 drop(worker_ctx);
212 worker_handler.shutdown().await;
213 return Err(e);
214 }
215 }
216 }
217 Err(_e) => {
218 #[cfg(feature = "tracing")]
219 tracing::error!(err = %_e, "error running workers");
220 }
221 }
222
223 drop(worker_ctx);
224 worker_handler.shutdown().await;
225
226 #[cfg(feature = "tracing")]
227 tracing::debug!("all workers finished");
228
229 Ok(())
230 }
231}