scuffle_h3_webtransport/
server.rs

1//! Provides the server side WebTransport session
2
3use std::collections::HashMap;
4use std::future::poll_fn;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll};
7
8use bytes::Buf;
9use futures_util::future::{BoxFuture, Either};
10use h3::error::ErrorLevel;
11use h3::ext::{Datagram, Protocol};
12use h3::frame::FrameStream;
13use h3::proto::frame::Frame;
14use h3::quic::{self, RecvStream as _, SendStream};
15use h3::server::RequestStream;
16use h3::stream::BufRecvStream;
17use h3::webtransport::SessionId;
18use http::{Method, Request};
19use tokio::sync::{mpsc, oneshot};
20
21use crate::stream::{BidiStream, RecvStream};
22
/// Marker stored in a request's extensions to signal that the request may be
/// upgraded to a WebTransport session.
///
/// The connection driver inserts this value when a request passes
/// `validate_wt_connect`.
pub(crate) struct WebTransportCanUpgrade<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Session id derived from the id of the request stream.
    pub session_id: SessionId,
    /// Sender used to submit [`WebTransportRequest`]s to the connection driver.
    pub webtransport_request_tx: mpsc::Sender<WebTransportRequest<C, B>>,
}
32
33impl<C, B> Clone for WebTransportCanUpgrade<C, B>
34where
35    C: quic::Connection<B>,
36    B: Buf,
37{
38    fn clone(&self) -> Self {
39        Self {
40            session_id: self.session_id,
41            webtransport_request_tx: self.webtransport_request_tx.clone(),
42        }
43    }
44}
45
/// One-shot callback that finishes a WebTransport upgrade.
///
/// It consumes the request stream and returns a boxed future resolving to the
/// outcome of the upgrade. Note that the first type parameter is forwarded
/// directly to [`RequestStream`]; within this file the alias is instantiated
/// with `C::BidiStream`, i.e. a stream type rather than a connection type.
type OnUpgrade<C, B> =
    Box<dyn FnOnce(RequestStream<C, B>) -> BoxFuture<'static, Result<(), h3::Error>> + Send + Sync + 'static>;
48
/// Handle to a WebTransport upgrade that has been prepared but not yet
/// completed.
///
/// Clones share the same underlying callback slot, so the upgrade can be
/// completed at most once across all clones (see
/// [`WebTransportUpgradePending::upgrade`]).
pub struct WebTransportUpgradePending<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Shared slot holding the completion callback; it becomes `None` once the
    /// callback has been taken.
    #[allow(clippy::type_complexity)]
    pub(crate) complete_upgrade: Arc<Mutex<Option<OnUpgrade<C::BidiStream, B>>>>,
}
58
59impl<C, B> WebTransportUpgradePending<C, B>
60where
61    C: quic::Connection<B>,
62    B: Buf,
63{
64    /// Completes the upgrade to a WebTransport session
65    #[allow(clippy::type_complexity)]
66    pub fn upgrade(
67        &self,
68        stream: RequestStream<C::BidiStream, B>,
69    ) -> Result<BoxFuture<'static, Result<(), h3::Error>>, RequestStream<C::BidiStream, B>> {
70        let Some(result) = Option::take(&mut *self.complete_upgrade.lock().unwrap()) else {
71            return Err(stream);
72        };
73
74        Ok(result(stream))
75    }
76}
77
78impl<C, B> Clone for WebTransportUpgradePending<C, B>
79where
80    C: quic::Connection<B>,
81    B: Buf,
82{
83    fn clone(&self) -> Self {
84        Self {
85            complete_upgrade: self.complete_upgrade.clone(),
86        }
87    }
88}
89
/// Per-session channel endpoints kept by the connection driver.
///
/// The driver uses these senders to forward incoming streams and datagrams
/// that belong to a registered WebTransport session.
struct WebTransportSession<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Receives incoming bidirectional streams addressed to the session.
    bidi: mpsc::Sender<BidiStream<C::BidiStream, B>>,
    /// Receives incoming unidirectional streams addressed to the session.
    uni: mpsc::Sender<RecvStream<C::RecvStream, B>>,
    /// Receives the payloads of incoming datagrams addressed to the session.
    datagram: mpsc::Sender<B>,
}
99
/// Commands sent to the connection driver over its WebTransport request
/// channel.
pub(crate) enum WebTransportRequest<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Registers a new WebTransport session with the driver.
    Upgrade {
        /// Id under which the session is stored in the driver's session map.
        session_id: SessionId,
        /// Destination for incoming bidirectional streams of this session.
        bidi_request: mpsc::Sender<BidiStream<C::BidiStream, B>>,
        /// Destination for incoming unidirectional streams of this session.
        uni_request: mpsc::Sender<RecvStream<C::RecvStream, B>>,
        /// Destination for incoming datagram payloads of this session.
        datagram_request: mpsc::Sender<B>,
        /// Reply channel: the driver answers with a stream opener for the
        /// connection and a sender on which a session id can be posted to have
        /// the driver remove that session again.
        response: oneshot::Sender<(C::OpenStreams, mpsc::UnboundedSender<SessionId>)>,
    },
    /// Asks the driver to send a datagram on behalf of a session.
    SendDatagram {
        /// Session the datagram belongs to.
        session_id: SessionId,
        /// Payload to send.
        datagram: B,
        /// Reply channel carrying the result of the send attempt.
        resp: oneshot::Sender<Result<(), h3::Error>>,
    },
}
118
/// A WebTransport server connection that allows incoming requests to be
/// upgraded to WebTransport sessions.
///
/// The [`Connection`] struct manages a connection from the side of the
/// HTTP/3 server. It bundles the request acceptor ([`Incoming`]) with the
/// [`ConnectionDriver`] that feeds it.
///
/// Create a new instance with [`Connection::new()`].
/// Accept incoming requests with [`Connection::accept()`].
/// And close a connection with [`Connection::close()`].
pub struct Connection<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Receiving side for requests produced by the driver.
    pub(crate) incoming: Incoming<C, B>,
    /// Driver that polls the underlying HTTP/3 connection.
    pub(crate) driver: ConnectionDriver<C, B>,
}
136
/// The driver for the WebTransport connection.
///
/// It owns the underlying HTTP/3 connection together with the channels used to
/// hand requests to [`Incoming`] and to communicate with WebTransport sessions.
pub struct ConnectionDriver<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Sessions registered through `WebTransportRequest::Upgrade`, keyed by
    /// session id.
    webtransport_session_map: HashMap<SessionId, WebTransportSession<C, B>>,
    /// Sends accepted requests to the paired [`Incoming`].
    #[allow(clippy::type_complexity)]
    request_sender: mpsc::Sender<(Request<()>, RequestStream<C::BidiStream, B>)>,
    /// Receives commands addressed to the driver.
    webtransport_request_rx: mpsc::Receiver<WebTransportRequest<C, B>>,
    /// Sender side of `webtransport_request_rx`; cloned into the extensions of
    /// requests that may be upgraded.
    webtransport_request_tx: mpsc::Sender<WebTransportRequest<C, B>>,
    /// Receives ids of sessions that should be removed from the session map.
    session_close_rx: mpsc::UnboundedReceiver<SessionId>,
    /// Sender side of `session_close_rx`; a clone is handed out in the reply
    /// to every `Upgrade` command.
    session_close_tx: mpsc::UnboundedSender<SessionId>,
    /// The underlying HTTP/3 server connection.
    inner: h3::server::Connection<C, B>,
}
152
impl<C, B, E, E2> ConnectionDriver<C, B>
where
    C: quic::Connection<B> + quic::SendDatagramExt<B, Error = E> + quic::RecvDatagramExt<Buf = B, Error = E2> + 'static,
    B: Buf + 'static + Send + Sync,
    C::AcceptError: Send + Sync,
    C::BidiStream: Send + Sync,
    C::RecvStream: Send + Sync,
    C::OpenStreams: Send + Sync,
    E: Into<h3::Error>,
    E2: Into<h3::Error>,
{
    /// Drives the server, accepting requests from the underlying HTTP/3
    /// connection, and forwarding streams and datagrams to the webtransport
    /// sessions.
    ///
    /// Each loop iteration waits for one event and handles it:
    ///
    /// * a session close notification removes the session from the map,
    /// * a [`WebTransportRequest`] either registers a session or sends a
    ///   datagram,
    /// * an incoming uni stream or datagram is forwarded to its session,
    /// * an incoming request stream is either forwarded to its session (when
    ///   its first frame is a WebTransport stream frame) or resolved as an
    ///   HTTP/3 request and sent to the paired [`Incoming`].
    ///
    /// Returns `Ok(())` when the connection stops yielding requests or
    /// datagrams, when an error of kind `Closed` is observed, or when a request
    /// stream ends before its first frame.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying connection. For
    /// connection-level application errors the connection is closed with the
    /// same code and reason first. If the request channel to [`Incoming`] is
    /// closed, the connection is closed with `H3_INTERNAL_ERROR`.
    pub async fn drive(&mut self) -> Result<(), h3::Error> {
        // The kind of event that woke the driver up.
        enum Winner<R, U, D, W> {
            Request(R),
            Uni(U),
            Datagram(D),
            WebTransport(W),
            Close(SessionId),
        }

        // Polls every event source once, in a fixed priority order: session
        // close notifications, driver commands, HTTP/3 requests, uni streams
        // and finally datagrams.
        // Yields the first event that is ready, `None` when the connection is
        // done, or an error from the underlying connection.
        let poll_inner = |this: &mut Self, cx: &mut Context<'_>| {
            match this.session_close_rx.poll_recv(cx) {
                Poll::Ready(Some(session_id)) => return Poll::Ready(Some(Ok(Winner::Close(session_id)))),
                Poll::Ready(None) => {}
                Poll::Pending => {}
            }

            match this.webtransport_request_rx.poll_recv(cx) {
                Poll::Ready(Some(r)) => return Poll::Ready(Some(Ok(Winner::WebTransport(r)))),
                Poll::Ready(None) => {}
                Poll::Pending => {}
            }

            match this.inner.poll_accept_request(cx) {
                Poll::Ready(Ok(None)) => return Poll::Ready(None),
                Poll::Ready(Ok(Some(r))) => return Poll::Ready(Some(Ok(Winner::Request(r)))),
                Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
                Poll::Pending => {}
            }

            match this.inner.inner.poll_accept_recv(cx) {
                Ok(()) => {}
                Err(err) => return Poll::Ready(Some(Err(err))),
            }

            // Hand out at most one already-accepted WebTransport uni stream
            // per poll.
            let streams = this.inner.inner.accepted_streams_mut();
            if let Some((id, stream)) = streams.wt_uni_streams.pop() {
                return Poll::Ready(Some(Ok(Winner::Uni((id, RecvStream::new(stream))))));
            }

            match this.inner.inner.conn.poll_accept_datagram(cx) {
                Poll::Ready(Ok(Some(r))) => match Datagram::decode(r) {
                    Ok(d) => return Poll::Ready(Some(Ok(Winner::Datagram(d)))),
                    Err(err) => return Poll::Ready(Some(Err(err))),
                },
                Poll::Ready(Ok(None)) => return Poll::Ready(None),
                Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
                Poll::Pending => {}
            }

            Poll::Pending
        };

        loop {
            let Some(winner) = poll_fn(|cx| poll_inner(self, cx)).await else {
                return Ok(());
            };

            let winner = match winner {
                Ok(w) => w,
                Err(err) => {
                    match err.kind() {
                        // A closed connection is a regular way for the driver to finish.
                        h3::error::Kind::Closed => return Ok(()),
                        // Connection-level application errors close the connection
                        // with the same code and reason (empty when none was given).
                        h3::error::Kind::Application {
                            code,
                            reason,
                            level: ErrorLevel::ConnectionError,
                            ..
                        } => {
                            return Err(self
                                .inner
                                .close(code, reason.unwrap_or_else(|| String::into_boxed_str(String::from("")))))
                        }
                        _ => return Err(err),
                    };
                }
            };

            // Every arm except `Request` fully handles its event and continues
            // with the next loop iteration.
            let mut stream = match winner {
                Winner::Request(s) => FrameStream::new(BufRecvStream::new(s)),
                Winner::Uni((session_id, mut stream)) => {
                    if let Some(webtransport) = self.webtransport_session_map.get_mut(&session_id) {
                        // NOTE(review): `send` waits for channel capacity, which pauses
                        // the whole driver while a session is slow - confirm intended.
                        match webtransport.uni.send(stream).await {
                            Ok(_) => continue,
                            Err(err) => {
                                // The receiver is gone; take the stream back to reject it.
                                stream = err.0;
                            }
                        }
                    }

                    // We reject the stream because it is not for a valid webtransport session
                    stream.stop_sending(h3::error::Code::H3_REQUEST_REJECTED.value());
                    continue;
                }
                Winner::Datagram(d) => {
                    if let Some(webtransport) = self.webtransport_session_map.get_mut(&d.stream_id().into()) {
                        // We don't care about datagram drops because they do not have any state.
                        // NOTE(review): `send` still waits for channel capacity before
                        // returning - confirm that backpressure is wanted for datagrams.
                        webtransport.datagram.send(d.into_payload()).await.ok();
                    }
                    continue;
                }
                Winner::WebTransport(WebTransportRequest::SendDatagram {
                    session_id,
                    datagram,
                    resp,
                }) => {
                    // The requester may have stopped waiting; a failed reply is ignored.
                    resp.send(self.inner.send_datagram(session_id.into(), datagram)).ok();
                    continue;
                }
                Winner::WebTransport(WebTransportRequest::Upgrade {
                    session_id,
                    bidi_request,
                    uni_request,
                    datagram_request,
                    response,
                }) => {
                    // Only register the session if the requester is still there to
                    // receive the stream opener and the close notifier.
                    if response
                        .send((self.inner.inner.conn.opener(), self.session_close_tx.clone()))
                        .is_ok()
                    {
                        self.webtransport_session_map.insert(
                            session_id,
                            WebTransportSession {
                                bidi: bidi_request,
                                uni: uni_request,
                                datagram: datagram_request,
                            },
                        );
                    }
                    continue;
                }
                Winner::Close(session_id) => {
                    self.webtransport_session_map.remove(&session_id);
                    continue;
                }
            };

            // Read the first frame.
            //
            // This will determine if it is a webtransport bi-stream or a request stream
            let frame = poll_fn(|cx| stream.poll_next(cx)).await;

            match frame {
                // NOTE(review): a stream that ends before its first frame stops the
                // whole driver with `Ok(())` - confirm this is intended.
                Ok(None) => return Ok(()),
                Ok(Some(Frame::WebTransportStream(session_id))) => {
                    let mut stream = BidiStream::new(stream.into_inner());
                    if let Some(session) = self.webtransport_session_map.get_mut(&session_id) {
                        match session.bidi.send(stream).await {
                            Ok(_) => continue,
                            Err(err) => {
                                // The receiver is gone; take the stream back to reject it.
                                stream = err.0;
                            }
                        }
                    }

                    // We reject the stream because it is not for a valid webtransport session
                    stream.stop_sending(h3::error::Code::H3_REQUEST_REJECTED.value());
                    stream.reset(h3::error::Code::H3_REQUEST_REJECTED.value());
                    continue;
                }
                // Make the underlying HTTP/3 connection handle the rest
                frame => {
                    let Some(req) = self.inner.accept_with_frame(stream, frame)? else {
                        return Ok(());
                    };

                    let (mut req, resp) = req.resolve().await?;

                    // Mark WebTransport CONNECT requests as upgradable, using the
                    // request stream id as the session id.
                    if validate_wt_connect(&req) {
                        req.extensions_mut().insert(WebTransportCanUpgrade {
                            session_id: resp.id().into(),
                            webtransport_request_tx: self.webtransport_request_tx.clone(),
                        });
                    }

                    // Nobody is left to accept requests, so the connection is closed.
                    if self.request_sender.send((req, resp)).await.is_err() {
                        return Err(self
                            .inner
                            .close(h3::error::Code::H3_INTERNAL_ERROR, "request sender channel closed"));
                    }
                }
            }
        }
    }
}
352
impl<C, B> ConnectionDriver<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Closes the connection with a code and a reason.
    ///
    /// Delegates to the underlying HTTP/3 connection and returns the
    /// [`h3::Error`] it produces for the close.
    pub fn close(&mut self, code: h3::error::Code, reason: &str) -> h3::Error {
        self.inner.close(code, reason)
    }
}
363
/// Accepts incoming requests.
///
/// This is the receiving half of the channel that the [`ConnectionDriver`]
/// sends resolved requests into.
pub struct Incoming<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Receives each request together with its request stream.
    #[allow(clippy::type_complexity)]
    recv: mpsc::Receiver<(Request<()>, RequestStream<C::BidiStream, B>)>,
}
373
374impl<C, B, E, E2> Connection<C, B>
375where
376    C: quic::Connection<B> + quic::SendDatagramExt<B, Error = E> + quic::RecvDatagramExt<Buf = B, Error = E2> + 'static,
377    B: Buf + 'static + Send + Sync,
378    C::AcceptError: Send + Sync,
379    C::BidiStream: Send + Sync,
380    C::RecvStream: Send + Sync,
381    C::OpenStreams: Send + Sync,
382    E: Into<h3::Error>,
383    E2: Into<h3::Error>,
384{
385    /// Create a new `WebTransportServer`
386    pub fn new(inner: h3::server::Connection<C, B>) -> Self {
387        let (request_sender, request_recv) = mpsc::channel(128);
388        let (webtransport_request_tx, webtransport_request_rx) = mpsc::channel(128);
389        let (session_close_tx, session_close_rx) = mpsc::unbounded_channel();
390
391        Self {
392            driver: ConnectionDriver {
393                webtransport_session_map: HashMap::new(),
394                request_sender,
395                webtransport_request_rx,
396                webtransport_request_tx,
397                session_close_rx,
398                session_close_tx,
399                inner,
400            },
401            incoming: Incoming { recv: request_recv },
402        }
403    }
404
405    /// Take the request acceptor
406    pub fn split(self) -> (Incoming<C, B>, ConnectionDriver<C, B>) {
407        (self.incoming, self.driver)
408    }
409
410    /// Get a mutable reference to the driver
411    pub fn driver(&mut self) -> &mut ConnectionDriver<C, B> {
412        &mut self.driver
413    }
414
415    /// Accepts an incoming request
416    /// Internally this method will drive the server until an incoming request
417    /// is available And returns the request and a request stream.
418    pub async fn accept(&mut self) -> Result<Option<(Request<()>, RequestStream<C::BidiStream, B>)>, h3::Error> {
419        match futures_util::future::select(std::pin::pin!(self.incoming.accept()), std::pin::pin!(self.driver.drive())).await
420        {
421            Either::Left((accept, _)) => Ok(accept),
422            Either::Right((drive, _)) => drive.map(|_| None),
423        }
424    }
425}
426
impl<C, B> Connection<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Closes the connection with a code and a reason.
    ///
    /// Forwards to [`ConnectionDriver::close`] and returns the resulting
    /// [`h3::Error`].
    pub fn close(&mut self, code: h3::error::Code, reason: &str) -> h3::Error {
        self.driver.close(code, reason)
    }
}
437
438impl<C, B> Incoming<C, B>
439where
440    C: quic::Connection<B>,
441    B: Buf,
442{
443    /// Accept an incoming request
444    pub async fn accept(&mut self) -> Option<(Request<()>, RequestStream<C::BidiStream, B>)> {
445        self.recv.recv().await
446    }
447
448    /// Poll the request acceptor
449    #[allow(clippy::type_complexity)]
450    pub fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Option<(Request<()>, RequestStream<C::BidiStream, B>)>> {
451        self.recv.poll_recv(cx)
452    }
453}
454
455fn validate_wt_connect(request: &Request<()>) -> bool {
456    let protocol = request.extensions().get::<Protocol>();
457    matches!((request.method(), protocol), (&Method::CONNECT, Some(p)) if p == &Protocol::WEB_TRANSPORT)
458}