// scuffle_h3_webtransport/session.rs

1use std::future::poll_fn;
2use std::sync::{Arc, Mutex};
3use std::task::{Context, Poll};
4
5use bytes::Buf;
6use futures_util::future::BoxFuture;
7use h3::quic::{self, OpenStreams};
8use h3::server::RequestStream;
9use h3::stream::BufRecvStream;
10use h3::webtransport::SessionId;
11use http::{Request, Response, StatusCode};
12use tokio::sync::{mpsc, oneshot};
13
14use crate::server::{WebTransportCanUpgrade, WebTransportRequest, WebTransportUpgradePending};
15use crate::stream::{BidiStream, RecvStream};
16
17/// WebTransport session driver.
18///
19/// Maintains the session using the underlying HTTP/3 connection.
20///
21/// Similar to [`h3::server::Connection`](https://docs.rs/h3/latest/h3/server/struct.Connection.html) it is generic over the QUIC implementation and Buffer.
22pub struct WebTransportSession<C, B>
23where
24    C: quic::Connection<B>,
25    B: Buf,
26{
27    // See: https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-2-3
28    connect_stream: RequestStream<C::BidiStream, B>,
29    webtransport_request_tx: mpsc::Sender<WebTransportRequest<C, B>>,
30    session_close_tx: mpsc::UnboundedSender<SessionId>,
31    inner: Mutex<Inner<C, B>>,
32}
33
34struct Inner<C, B>
35where
36    C: quic::Connection<B>,
37    B: Buf,
38{
39    opener: C::OpenStreams,
40    bidi_request_rx: mpsc::Receiver<BidiStream<C::BidiStream, B>>,
41    uni_request_rx: mpsc::Receiver<RecvStream<C::RecvStream, B>>,
42    datagram_request_rx: mpsc::Receiver<B>,
43}
44
45impl<C, B> Drop for WebTransportSession<C, B>
46where
47    C: quic::Connection<B>,
48    B: Buf,
49{
50    fn drop(&mut self) {
51        self.session_close_tx.send(self.session_id()).ok();
52        self.connect_stream.stop_sending(h3::error::Code::H3_NO_ERROR);
53        self.connect_stream.stop_stream(h3::error::Code::H3_NO_ERROR);
54    }
55}
56
57impl<C, B> WebTransportSession<C, B>
58where
59    C: quic::Connection<B>,
60    B: Buf,
61{
62    async fn new(
63        mut stream: RequestStream<C::BidiStream, B>,
64        webtransport_request_tx: mpsc::Sender<WebTransportRequest<C, B>>,
65    ) -> Result<Option<Self>, h3::Error> {
66        let (bidi_tx, bidi_rx) = mpsc::channel(16);
67        let (uni_tx, uni_rx) = mpsc::channel(16);
68        let (datagram_tx, datagram_rx) = mpsc::channel(16);
69
70        let (tx, rx) = oneshot::channel();
71
72        if webtransport_request_tx
73            .send(WebTransportRequest::Upgrade {
74                session_id: stream.id().into(),
75                bidi_request: bidi_tx,
76                uni_request: uni_tx,
77                datagram_request: datagram_tx,
78                response: tx,
79            })
80            .await
81            .is_err()
82        {
83            stream
84                .send_response(
85                    http::Response::builder()
86                        .status(StatusCode::INTERNAL_SERVER_ERROR)
87                        .body(())
88                        .unwrap(),
89                )
90                .await?;
91            return Ok(None);
92        }
93
94        let Ok((opener, session_close_tx)) = rx.await else {
95            stream
96                .send_response(
97                    http::Response::builder()
98                        .status(StatusCode::INTERNAL_SERVER_ERROR)
99                        .body(())
100                        .unwrap(),
101                )
102                .await?;
103            return Ok(None);
104        };
105
106        stream
107            .send_response(
108                http::Response::builder()
109                    // This is the only header that chrome cares about.
110                    .header("sec-webtransport-http3-draft", "draft02")
111                    .status(StatusCode::OK)
112                    .body(())
113                    .unwrap(),
114            )
115            .await?;
116
117        Ok(Some(Self {
118            connect_stream: stream,
119            webtransport_request_tx,
120            session_close_tx,
121            inner: Mutex::new(Inner {
122                opener,
123                bidi_request_rx: bidi_rx,
124                uni_request_rx: uni_rx,
125                datagram_request_rx: datagram_rx,
126            }),
127        }))
128    }
129
130    /// Returns the session id
131    pub fn session_id(&self) -> SessionId {
132        self.connect_stream.id().into()
133    }
134
135    /// Accepts a bidi stream
136    pub async fn accept_bi(&self) -> Option<BidiStream<C::BidiStream, B>> {
137        poll_fn(|cx| self.poll_accept_bi(cx)).await
138    }
139
140    /// Polls for a bidi stream
141    pub fn poll_accept_bi(&self, cx: &mut Context<'_>) -> Poll<Option<BidiStream<C::BidiStream, B>>> {
142        self.inner.lock().unwrap().bidi_request_rx.poll_recv(cx)
143    }
144
145    /// Accepts a uni stream
146    pub async fn accept_uni(&self) -> Option<RecvStream<C::RecvStream, B>> {
147        poll_fn(|cx| self.poll_accept_uni(cx)).await
148    }
149
150    /// Polls for a uni stream
151    pub fn poll_accept_uni(&self, cx: &mut Context<'_>) -> Poll<Option<RecvStream<C::RecvStream, B>>> {
152        self.inner.lock().unwrap().uni_request_rx.poll_recv(cx)
153    }
154
155    /// Accepts a datagram
156    pub async fn accept_datagram(&self) -> Option<B> {
157        poll_fn(|cx| self.poll_accept_datagram(cx)).await
158    }
159
160    /// Polls for a datagram
161    pub fn poll_accept_datagram(&self, cx: &mut Context<'_>) -> Poll<Option<B>> {
162        self.inner.lock().unwrap().datagram_request_rx.poll_recv(cx)
163    }
164
165    /// Sends a datagram
166    pub async fn send_datagram(&self, datagram: B) -> Result<(), h3::Error> {
167        let (tx, rx) = oneshot::channel();
168        self.webtransport_request_tx
169            .send(WebTransportRequest::SendDatagram {
170                session_id: self.session_id(),
171                datagram,
172                resp: tx,
173            })
174            .await
175            .ok();
176
177        match rx.await {
178            Ok(Ok(())) => Ok(()),
179            Ok(Err(e)) => Err(e),
180            // If the channel is closed, we can ignore the error
181            Err(_) => Ok(()),
182        }
183    }
184
185    /// Opens a bidi stream
186    pub async fn open_bi(
187        &self,
188    ) -> Result<crate::stream::BidiStream<C::BidiStream, B>, <C::OpenStreams as OpenStreams<B>>::OpenError> {
189        poll_fn(|cx| self.poll_open_bi(cx)).await
190    }
191
192    /// Polls to open a bidi stream
193    #[allow(clippy::type_complexity)]
194    pub fn poll_open_bi(
195        &self,
196        cx: &mut Context<'_>,
197    ) -> Poll<Result<crate::stream::BidiStream<C::BidiStream, B>, <C::OpenStreams as OpenStreams<B>>::OpenError>> {
198        self.inner
199            .lock()
200            .unwrap()
201            .opener
202            .poll_open_bidi(cx)
203            .map(|res| res.map(|stream| crate::stream::BidiStream::new(BufRecvStream::new(stream))))
204    }
205
206    /// Opens a uni stream
207    pub async fn open_uni(
208        &self,
209    ) -> Result<crate::stream::SendStream<C::SendStream, B>, <C::OpenStreams as OpenStreams<B>>::OpenError> {
210        poll_fn(|cx| self.poll_open_uni(cx)).await
211    }
212
213    /// Polls to open a uni stream
214    #[allow(clippy::type_complexity)]
215    pub fn poll_open_uni(
216        &self,
217        cx: &mut Context<'_>,
218    ) -> Poll<Result<crate::stream::SendStream<C::SendStream, B>, <C::OpenStreams as OpenStreams<B>>::OpenError>> {
219        self.inner
220            .lock()
221            .unwrap()
222            .opener
223            .poll_open_send(cx)
224            .map(|res| res.map(|stream| crate::stream::SendStream::new(BufRecvStream::new(stream))))
225    }
226}
227
228impl<C, B> WebTransportSession<C, B>
229where
230    C: quic::Connection<B>,
231    B: Buf,
232{
233    /// Begin a WebTransport session upgrade
234    pub fn begin<B2, F, Fut>(request: &mut Request<B2>, on_upgrade: F) -> Option<http::Response<()>>
235    where
236        C: quic::Connection<B> + 'static,
237        B: Buf + 'static + Send + Sync,
238        C::AcceptError: Send + Sync,
239        C::BidiStream: Send + Sync,
240        C::RecvStream: Send + Sync,
241        C::OpenStreams: Send + Sync,
242        Fut: std::future::Future<Output = ()> + Send + Sync + 'static,
243        F: FnOnce(WebTransportSession<C, B>) -> Fut + Send + Sync + 'static,
244    {
245        let can_upgrade = request.extensions_mut().remove::<WebTransportCanUpgrade<C, B>>()?;
246
247        let resp = Response::builder()
248            .extension(WebTransportUpgradePending::<C, B> {
249                complete_upgrade: Arc::new(Mutex::new(Some(Box::new(move |stream| {
250                    Box::pin(async move {
251                        let Some(session) = WebTransportSession::new(stream, can_upgrade.webtransport_request_tx).await?
252                        else {
253                            return Ok(());
254                        };
255
256                        on_upgrade(session).await;
257                        Ok(())
258                    })
259                })))),
260            })
261            .status(StatusCode::BAD_REQUEST)
262            .body(())
263            .unwrap();
264
265        Some(resp)
266    }
267
268    /// Completes the WebTransport upgrade
269    #[allow(clippy::type_complexity)]
270    pub fn complete(
271        response: &mut Response<B>,
272        stream: RequestStream<C::BidiStream, B>,
273    ) -> Result<BoxFuture<'static, Result<(), h3::Error>>, RequestStream<C::BidiStream, B>>
274    where
275        C: quic::Connection<B> + 'static,
276        B: Buf + 'static + Send + Sync,
277    {
278        let Some(upgrade_pending) = response.extensions_mut().remove::<WebTransportUpgradePending<C, B>>() else {
279            return Err(stream);
280        };
281
282        upgrade_pending.upgrade(stream)
283    }
284
285    /// Accepts a WebTransport session from an incoming request
286    pub async fn accept<E, E2>(
287        request: &mut Request<()>,
288        mut stream: RequestStream<C::BidiStream, B>,
289    ) -> Result<Option<WebTransportSession<C, B>>, h3::Error>
290    where
291        C: quic::Connection<B> + quic::SendDatagramExt<B, Error = E> + quic::RecvDatagramExt<Buf = B, Error = E2> + 'static,
292        B: Buf + 'static + Send + Sync,
293        C::AcceptError: Send + Sync,
294        C::BidiStream: Send + Sync,
295        C::RecvStream: Send + Sync,
296        C::OpenStreams: Send + Sync,
297        E: Into<h3::Error>,
298        E2: Into<h3::Error>,
299    {
300        let Some(can_upgrade) = request.extensions_mut().remove::<WebTransportCanUpgrade<C, B>>() else {
301            stream
302                .send_response(http::Response::builder().status(StatusCode::BAD_REQUEST).body(()).unwrap())
303                .await?;
304            stream.finish().await?;
305            return Ok(None);
306        };
307
308        WebTransportSession::new(stream, can_upgrade.webtransport_request_tx).await
309    }
310}