1use std::collections::HashMap;
4use std::future::poll_fn;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll};
7
8use bytes::Buf;
9use futures_util::future::{BoxFuture, Either};
10use h3::error::ErrorLevel;
11use h3::ext::{Datagram, Protocol};
12use h3::frame::FrameStream;
13use h3::proto::frame::Frame;
14use h3::quic::{self, RecvStream as _, SendStream};
15use h3::server::RequestStream;
16use h3::stream::BufRecvStream;
17use h3::webtransport::SessionId;
18use http::{Method, Request};
19use tokio::sync::{mpsc, oneshot};
20
21use crate::stream::{BidiStream, RecvStream};
22
23pub(crate) struct WebTransportCanUpgrade<C, B>
25where
26 C: quic::Connection<B>,
27 B: Buf,
28{
29 pub session_id: SessionId,
30 pub webtransport_request_tx: mpsc::Sender<WebTransportRequest<C, B>>,
31}
32
33impl<C, B> Clone for WebTransportCanUpgrade<C, B>
34where
35 C: quic::Connection<B>,
36 B: Buf,
37{
38 fn clone(&self) -> Self {
39 Self {
40 session_id: self.session_id,
41 webtransport_request_tx: self.webtransport_request_tx.clone(),
42 }
43 }
44}
45
/// One-shot callback that completes a WebTransport upgrade.
///
/// It takes ownership of the request stream and returns a boxed future that
/// reports the outcome. `C` is the stream's transport type; it is instantiated
/// with `C::BidiStream` in `WebTransportUpgradePending`.
type OnUpgrade<C, B> =
    Box<dyn FnOnce(RequestStream<C, B>) -> BoxFuture<'static, Result<(), h3::Error>> + Send + Sync + 'static>;
48
49pub struct WebTransportUpgradePending<C, B>
51where
52 C: quic::Connection<B>,
53 B: Buf,
54{
55 #[allow(clippy::type_complexity)]
56 pub(crate) complete_upgrade: Arc<Mutex<Option<OnUpgrade<C::BidiStream, B>>>>,
57}
58
59impl<C, B> WebTransportUpgradePending<C, B>
60where
61 C: quic::Connection<B>,
62 B: Buf,
63{
64 #[allow(clippy::type_complexity)]
66 pub fn upgrade(
67 &self,
68 stream: RequestStream<C::BidiStream, B>,
69 ) -> Result<BoxFuture<'static, Result<(), h3::Error>>, RequestStream<C::BidiStream, B>> {
70 let Some(result) = Option::take(&mut *self.complete_upgrade.lock().unwrap()) else {
71 return Err(stream);
72 };
73
74 Ok(result(stream))
75 }
76}
77
78impl<C, B> Clone for WebTransportUpgradePending<C, B>
79where
80 C: quic::Connection<B>,
81 B: Buf,
82{
83 fn clone(&self) -> Self {
84 Self {
85 complete_upgrade: self.complete_upgrade.clone(),
86 }
87 }
88}
89
90struct WebTransportSession<C, B>
91where
92 C: quic::Connection<B>,
93 B: Buf,
94{
95 bidi: mpsc::Sender<BidiStream<C::BidiStream, B>>,
96 uni: mpsc::Sender<RecvStream<C::RecvStream, B>>,
97 datagram: mpsc::Sender<B>,
98}
99
100pub(crate) enum WebTransportRequest<C, B>
101where
102 C: quic::Connection<B>,
103 B: Buf,
104{
105 Upgrade {
106 session_id: SessionId,
107 bidi_request: mpsc::Sender<BidiStream<C::BidiStream, B>>,
108 uni_request: mpsc::Sender<RecvStream<C::RecvStream, B>>,
109 datagram_request: mpsc::Sender<B>,
110 response: oneshot::Sender<(C::OpenStreams, mpsc::UnboundedSender<SessionId>)>,
111 },
112 SendDatagram {
113 session_id: SessionId,
114 datagram: B,
115 resp: oneshot::Sender<Result<(), h3::Error>>,
116 },
117}
118
119pub struct Connection<C, B>
129where
130 C: quic::Connection<B>,
131 B: Buf,
132{
133 pub(crate) incoming: Incoming<C, B>,
134 pub(crate) driver: ConnectionDriver<C, B>,
135}
136
137pub struct ConnectionDriver<C, B>
139where
140 C: quic::Connection<B>,
141 B: Buf,
142{
143 webtransport_session_map: HashMap<SessionId, WebTransportSession<C, B>>,
144 #[allow(clippy::type_complexity)]
145 request_sender: mpsc::Sender<(Request<()>, RequestStream<C::BidiStream, B>)>,
146 webtransport_request_rx: mpsc::Receiver<WebTransportRequest<C, B>>,
147 webtransport_request_tx: mpsc::Sender<WebTransportRequest<C, B>>,
148 session_close_rx: mpsc::UnboundedReceiver<SessionId>,
149 session_close_tx: mpsc::UnboundedSender<SessionId>,
150 inner: h3::server::Connection<C, B>,
151}
152
153impl<C, B, E, E2> ConnectionDriver<C, B>
154where
155 C: quic::Connection<B> + quic::SendDatagramExt<B, Error = E> + quic::RecvDatagramExt<Buf = B, Error = E2> + 'static,
156 B: Buf + 'static + Send + Sync,
157 C::AcceptError: Send + Sync,
158 C::BidiStream: Send + Sync,
159 C::RecvStream: Send + Sync,
160 C::OpenStreams: Send + Sync,
161 E: Into<h3::Error>,
162 E2: Into<h3::Error>,
163{
164 pub async fn drive(&mut self) -> Result<(), h3::Error> {
167 enum Winner<R, U, D, W> {
168 Request(R),
169 Uni(U),
170 Datagram(D),
171 WebTransport(W),
172 Close(SessionId),
173 }
174
175 let poll_inner = |this: &mut Self, cx: &mut Context<'_>| {
178 match this.session_close_rx.poll_recv(cx) {
179 Poll::Ready(Some(session_id)) => return Poll::Ready(Some(Ok(Winner::Close(session_id)))),
180 Poll::Ready(None) => {}
181 Poll::Pending => {}
182 }
183
184 match this.webtransport_request_rx.poll_recv(cx) {
185 Poll::Ready(Some(r)) => return Poll::Ready(Some(Ok(Winner::WebTransport(r)))),
186 Poll::Ready(None) => {}
187 Poll::Pending => {}
188 }
189
190 match this.inner.poll_accept_request(cx) {
191 Poll::Ready(Ok(None)) => return Poll::Ready(None),
192 Poll::Ready(Ok(Some(r))) => return Poll::Ready(Some(Ok(Winner::Request(r)))),
193 Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
194 Poll::Pending => {}
195 }
196
197 match this.inner.inner.poll_accept_recv(cx) {
198 Ok(()) => {}
199 Err(err) => return Poll::Ready(Some(Err(err))),
200 }
201
202 let streams = this.inner.inner.accepted_streams_mut();
203 if let Some((id, stream)) = streams.wt_uni_streams.pop() {
204 return Poll::Ready(Some(Ok(Winner::Uni((id, RecvStream::new(stream))))));
205 }
206
207 match this.inner.inner.conn.poll_accept_datagram(cx) {
208 Poll::Ready(Ok(Some(r))) => match Datagram::decode(r) {
209 Ok(d) => return Poll::Ready(Some(Ok(Winner::Datagram(d)))),
210 Err(err) => return Poll::Ready(Some(Err(err))),
211 },
212 Poll::Ready(Ok(None)) => return Poll::Ready(None),
213 Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
214 Poll::Pending => {}
215 }
216
217 Poll::Pending
218 };
219
220 loop {
221 let Some(winner) = poll_fn(|cx| poll_inner(self, cx)).await else {
222 return Ok(());
223 };
224
225 let winner = match winner {
226 Ok(w) => w,
227 Err(err) => {
228 match err.kind() {
229 h3::error::Kind::Closed => return Ok(()),
230 h3::error::Kind::Application {
231 code,
232 reason,
233 level: ErrorLevel::ConnectionError,
234 ..
235 } => {
236 return Err(self
237 .inner
238 .close(code, reason.unwrap_or_else(|| String::into_boxed_str(String::from("")))))
239 }
240 _ => return Err(err),
241 };
242 }
243 };
244
245 let mut stream = match winner {
246 Winner::Request(s) => FrameStream::new(BufRecvStream::new(s)),
247 Winner::Uni((session_id, mut stream)) => {
248 if let Some(webtransport) = self.webtransport_session_map.get_mut(&session_id) {
249 match webtransport.uni.send(stream).await {
250 Ok(_) => continue,
251 Err(err) => {
252 stream = err.0;
253 }
254 }
255 }
256
257 stream.stop_sending(h3::error::Code::H3_REQUEST_REJECTED.value());
259 continue;
260 }
261 Winner::Datagram(d) => {
262 if let Some(webtransport) = self.webtransport_session_map.get_mut(&d.stream_id().into()) {
263 webtransport.datagram.send(d.into_payload()).await.ok();
265 }
266 continue;
267 }
268 Winner::WebTransport(WebTransportRequest::SendDatagram {
269 session_id,
270 datagram,
271 resp,
272 }) => {
273 resp.send(self.inner.send_datagram(session_id.into(), datagram)).ok();
274 continue;
275 }
276 Winner::WebTransport(WebTransportRequest::Upgrade {
277 session_id,
278 bidi_request,
279 uni_request,
280 datagram_request,
281 response,
282 }) => {
283 if response
284 .send((self.inner.inner.conn.opener(), self.session_close_tx.clone()))
285 .is_ok()
286 {
287 self.webtransport_session_map.insert(
288 session_id,
289 WebTransportSession {
290 bidi: bidi_request,
291 uni: uni_request,
292 datagram: datagram_request,
293 },
294 );
295 }
296 continue;
297 }
298 Winner::Close(session_id) => {
299 self.webtransport_session_map.remove(&session_id);
300 continue;
301 }
302 };
303
304 let frame = poll_fn(|cx| stream.poll_next(cx)).await;
308
309 match frame {
310 Ok(None) => return Ok(()),
311 Ok(Some(Frame::WebTransportStream(session_id))) => {
312 let mut stream = BidiStream::new(stream.into_inner());
313 if let Some(session) = self.webtransport_session_map.get_mut(&session_id) {
314 match session.bidi.send(stream).await {
315 Ok(_) => continue,
316 Err(err) => {
317 stream = err.0;
318 }
319 }
320 }
321
322 stream.stop_sending(h3::error::Code::H3_REQUEST_REJECTED.value());
324 stream.reset(h3::error::Code::H3_REQUEST_REJECTED.value());
325 continue;
326 }
327 frame => {
329 let Some(req) = self.inner.accept_with_frame(stream, frame)? else {
330 return Ok(());
331 };
332
333 let (mut req, resp) = req.resolve().await?;
334
335 if validate_wt_connect(&req) {
336 req.extensions_mut().insert(WebTransportCanUpgrade {
337 session_id: resp.id().into(),
338 webtransport_request_tx: self.webtransport_request_tx.clone(),
339 });
340 }
341
342 if self.request_sender.send((req, resp)).await.is_err() {
343 return Err(self
344 .inner
345 .close(h3::error::Code::H3_INTERNAL_ERROR, "request sender channel closed"));
346 }
347 }
348 }
349 }
350 }
351}
352
353impl<C, B> ConnectionDriver<C, B>
354where
355 C: quic::Connection<B>,
356 B: Buf,
357{
358 pub fn close(&mut self, code: h3::error::Code, reason: &str) -> h3::Error {
360 self.inner.close(code, reason)
361 }
362}
363
364pub struct Incoming<C, B>
366where
367 C: quic::Connection<B>,
368 B: Buf,
369{
370 #[allow(clippy::type_complexity)]
371 recv: mpsc::Receiver<(Request<()>, RequestStream<C::BidiStream, B>)>,
372}
373
374impl<C, B, E, E2> Connection<C, B>
375where
376 C: quic::Connection<B> + quic::SendDatagramExt<B, Error = E> + quic::RecvDatagramExt<Buf = B, Error = E2> + 'static,
377 B: Buf + 'static + Send + Sync,
378 C::AcceptError: Send + Sync,
379 C::BidiStream: Send + Sync,
380 C::RecvStream: Send + Sync,
381 C::OpenStreams: Send + Sync,
382 E: Into<h3::Error>,
383 E2: Into<h3::Error>,
384{
385 pub fn new(inner: h3::server::Connection<C, B>) -> Self {
387 let (request_sender, request_recv) = mpsc::channel(128);
388 let (webtransport_request_tx, webtransport_request_rx) = mpsc::channel(128);
389 let (session_close_tx, session_close_rx) = mpsc::unbounded_channel();
390
391 Self {
392 driver: ConnectionDriver {
393 webtransport_session_map: HashMap::new(),
394 request_sender,
395 webtransport_request_rx,
396 webtransport_request_tx,
397 session_close_rx,
398 session_close_tx,
399 inner,
400 },
401 incoming: Incoming { recv: request_recv },
402 }
403 }
404
405 pub fn split(self) -> (Incoming<C, B>, ConnectionDriver<C, B>) {
407 (self.incoming, self.driver)
408 }
409
410 pub fn driver(&mut self) -> &mut ConnectionDriver<C, B> {
412 &mut self.driver
413 }
414
415 pub async fn accept(&mut self) -> Result<Option<(Request<()>, RequestStream<C::BidiStream, B>)>, h3::Error> {
419 match futures_util::future::select(std::pin::pin!(self.incoming.accept()), std::pin::pin!(self.driver.drive())).await
420 {
421 Either::Left((accept, _)) => Ok(accept),
422 Either::Right((drive, _)) => drive.map(|_| None),
423 }
424 }
425}
426
427impl<C, B> Connection<C, B>
428where
429 C: quic::Connection<B>,
430 B: Buf,
431{
432 pub fn close(&mut self, code: h3::error::Code, reason: &str) -> h3::Error {
434 self.driver.close(code, reason)
435 }
436}
437
438impl<C, B> Incoming<C, B>
439where
440 C: quic::Connection<B>,
441 B: Buf,
442{
443 pub async fn accept(&mut self) -> Option<(Request<()>, RequestStream<C::BidiStream, B>)> {
445 self.recv.recv().await
446 }
447
448 #[allow(clippy::type_complexity)]
450 pub fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Option<(Request<()>, RequestStream<C::BidiStream, B>)>> {
451 self.recv.poll_recv(cx)
452 }
453}
454
455fn validate_wt_connect(request: &Request<()>) -> bool {
456 let protocol = request.extensions().get::<Protocol>();
457 matches!((request.method(), protocol), (&Method::CONNECT, Some(p)) if p == &Protocol::WEB_TRANSPORT)
458}