1use std::future::poll_fn;
2use std::sync::{Arc, Mutex};
3use std::task::{Context, Poll};
4
5use bytes::Buf;
6use futures_util::future::BoxFuture;
7use h3::quic::{self, OpenStreams};
8use h3::server::RequestStream;
9use h3::stream::BufRecvStream;
10use h3::webtransport::SessionId;
11use http::{Request, Response, StatusCode};
12use tokio::sync::{mpsc, oneshot};
13
14use crate::server::{WebTransportCanUpgrade, WebTransportRequest, WebTransportUpgradePending};
15use crate::stream::{BidiStream, RecvStream};
16
17pub struct WebTransportSession<C, B>
23where
24 C: quic::Connection<B>,
25 B: Buf,
26{
27 connect_stream: RequestStream<C::BidiStream, B>,
29 webtransport_request_tx: mpsc::Sender<WebTransportRequest<C, B>>,
30 session_close_tx: mpsc::UnboundedSender<SessionId>,
31 inner: Mutex<Inner<C, B>>,
32}
33
34struct Inner<C, B>
35where
36 C: quic::Connection<B>,
37 B: Buf,
38{
39 opener: C::OpenStreams,
40 bidi_request_rx: mpsc::Receiver<BidiStream<C::BidiStream, B>>,
41 uni_request_rx: mpsc::Receiver<RecvStream<C::RecvStream, B>>,
42 datagram_request_rx: mpsc::Receiver<B>,
43}
44
45impl<C, B> Drop for WebTransportSession<C, B>
46where
47 C: quic::Connection<B>,
48 B: Buf,
49{
50 fn drop(&mut self) {
51 self.session_close_tx.send(self.session_id()).ok();
52 self.connect_stream.stop_sending(h3::error::Code::H3_NO_ERROR);
53 self.connect_stream.stop_stream(h3::error::Code::H3_NO_ERROR);
54 }
55}
56
57impl<C, B> WebTransportSession<C, B>
58where
59 C: quic::Connection<B>,
60 B: Buf,
61{
62 async fn new(
63 mut stream: RequestStream<C::BidiStream, B>,
64 webtransport_request_tx: mpsc::Sender<WebTransportRequest<C, B>>,
65 ) -> Result<Option<Self>, h3::Error> {
66 let (bidi_tx, bidi_rx) = mpsc::channel(16);
67 let (uni_tx, uni_rx) = mpsc::channel(16);
68 let (datagram_tx, datagram_rx) = mpsc::channel(16);
69
70 let (tx, rx) = oneshot::channel();
71
72 if webtransport_request_tx
73 .send(WebTransportRequest::Upgrade {
74 session_id: stream.id().into(),
75 bidi_request: bidi_tx,
76 uni_request: uni_tx,
77 datagram_request: datagram_tx,
78 response: tx,
79 })
80 .await
81 .is_err()
82 {
83 stream
84 .send_response(
85 http::Response::builder()
86 .status(StatusCode::INTERNAL_SERVER_ERROR)
87 .body(())
88 .unwrap(),
89 )
90 .await?;
91 return Ok(None);
92 }
93
94 let Ok((opener, session_close_tx)) = rx.await else {
95 stream
96 .send_response(
97 http::Response::builder()
98 .status(StatusCode::INTERNAL_SERVER_ERROR)
99 .body(())
100 .unwrap(),
101 )
102 .await?;
103 return Ok(None);
104 };
105
106 stream
107 .send_response(
108 http::Response::builder()
109 .header("sec-webtransport-http3-draft", "draft02")
111 .status(StatusCode::OK)
112 .body(())
113 .unwrap(),
114 )
115 .await?;
116
117 Ok(Some(Self {
118 connect_stream: stream,
119 webtransport_request_tx,
120 session_close_tx,
121 inner: Mutex::new(Inner {
122 opener,
123 bidi_request_rx: bidi_rx,
124 uni_request_rx: uni_rx,
125 datagram_request_rx: datagram_rx,
126 }),
127 }))
128 }
129
130 pub fn session_id(&self) -> SessionId {
132 self.connect_stream.id().into()
133 }
134
135 pub async fn accept_bi(&self) -> Option<BidiStream<C::BidiStream, B>> {
137 poll_fn(|cx| self.poll_accept_bi(cx)).await
138 }
139
140 pub fn poll_accept_bi(&self, cx: &mut Context<'_>) -> Poll<Option<BidiStream<C::BidiStream, B>>> {
142 self.inner.lock().unwrap().bidi_request_rx.poll_recv(cx)
143 }
144
145 pub async fn accept_uni(&self) -> Option<RecvStream<C::RecvStream, B>> {
147 poll_fn(|cx| self.poll_accept_uni(cx)).await
148 }
149
150 pub fn poll_accept_uni(&self, cx: &mut Context<'_>) -> Poll<Option<RecvStream<C::RecvStream, B>>> {
152 self.inner.lock().unwrap().uni_request_rx.poll_recv(cx)
153 }
154
155 pub async fn accept_datagram(&self) -> Option<B> {
157 poll_fn(|cx| self.poll_accept_datagram(cx)).await
158 }
159
160 pub fn poll_accept_datagram(&self, cx: &mut Context<'_>) -> Poll<Option<B>> {
162 self.inner.lock().unwrap().datagram_request_rx.poll_recv(cx)
163 }
164
165 pub async fn send_datagram(&self, datagram: B) -> Result<(), h3::Error> {
167 let (tx, rx) = oneshot::channel();
168 self.webtransport_request_tx
169 .send(WebTransportRequest::SendDatagram {
170 session_id: self.session_id(),
171 datagram,
172 resp: tx,
173 })
174 .await
175 .ok();
176
177 match rx.await {
178 Ok(Ok(())) => Ok(()),
179 Ok(Err(e)) => Err(e),
180 Err(_) => Ok(()),
182 }
183 }
184
185 pub async fn open_bi(
187 &self,
188 ) -> Result<crate::stream::BidiStream<C::BidiStream, B>, <C::OpenStreams as OpenStreams<B>>::OpenError> {
189 poll_fn(|cx| self.poll_open_bi(cx)).await
190 }
191
192 #[allow(clippy::type_complexity)]
194 pub fn poll_open_bi(
195 &self,
196 cx: &mut Context<'_>,
197 ) -> Poll<Result<crate::stream::BidiStream<C::BidiStream, B>, <C::OpenStreams as OpenStreams<B>>::OpenError>> {
198 self.inner
199 .lock()
200 .unwrap()
201 .opener
202 .poll_open_bidi(cx)
203 .map(|res| res.map(|stream| crate::stream::BidiStream::new(BufRecvStream::new(stream))))
204 }
205
206 pub async fn open_uni(
208 &self,
209 ) -> Result<crate::stream::SendStream<C::SendStream, B>, <C::OpenStreams as OpenStreams<B>>::OpenError> {
210 poll_fn(|cx| self.poll_open_uni(cx)).await
211 }
212
213 #[allow(clippy::type_complexity)]
215 pub fn poll_open_uni(
216 &self,
217 cx: &mut Context<'_>,
218 ) -> Poll<Result<crate::stream::SendStream<C::SendStream, B>, <C::OpenStreams as OpenStreams<B>>::OpenError>> {
219 self.inner
220 .lock()
221 .unwrap()
222 .opener
223 .poll_open_send(cx)
224 .map(|res| res.map(|stream| crate::stream::SendStream::new(BufRecvStream::new(stream))))
225 }
226}
227
228impl<C, B> WebTransportSession<C, B>
229where
230 C: quic::Connection<B>,
231 B: Buf,
232{
233 pub fn begin<B2, F, Fut>(request: &mut Request<B2>, on_upgrade: F) -> Option<http::Response<()>>
235 where
236 C: quic::Connection<B> + 'static,
237 B: Buf + 'static + Send + Sync,
238 C::AcceptError: Send + Sync,
239 C::BidiStream: Send + Sync,
240 C::RecvStream: Send + Sync,
241 C::OpenStreams: Send + Sync,
242 Fut: std::future::Future<Output = ()> + Send + Sync + 'static,
243 F: FnOnce(WebTransportSession<C, B>) -> Fut + Send + Sync + 'static,
244 {
245 let can_upgrade = request.extensions_mut().remove::<WebTransportCanUpgrade<C, B>>()?;
246
247 let resp = Response::builder()
248 .extension(WebTransportUpgradePending::<C, B> {
249 complete_upgrade: Arc::new(Mutex::new(Some(Box::new(move |stream| {
250 Box::pin(async move {
251 let Some(session) = WebTransportSession::new(stream, can_upgrade.webtransport_request_tx).await?
252 else {
253 return Ok(());
254 };
255
256 on_upgrade(session).await;
257 Ok(())
258 })
259 })))),
260 })
261 .status(StatusCode::BAD_REQUEST)
262 .body(())
263 .unwrap();
264
265 Some(resp)
266 }
267
268 #[allow(clippy::type_complexity)]
270 pub fn complete(
271 response: &mut Response<B>,
272 stream: RequestStream<C::BidiStream, B>,
273 ) -> Result<BoxFuture<'static, Result<(), h3::Error>>, RequestStream<C::BidiStream, B>>
274 where
275 C: quic::Connection<B> + 'static,
276 B: Buf + 'static + Send + Sync,
277 {
278 let Some(upgrade_pending) = response.extensions_mut().remove::<WebTransportUpgradePending<C, B>>() else {
279 return Err(stream);
280 };
281
282 upgrade_pending.upgrade(stream)
283 }
284
285 pub async fn accept<E, E2>(
287 request: &mut Request<()>,
288 mut stream: RequestStream<C::BidiStream, B>,
289 ) -> Result<Option<WebTransportSession<C, B>>, h3::Error>
290 where
291 C: quic::Connection<B> + quic::SendDatagramExt<B, Error = E> + quic::RecvDatagramExt<Buf = B, Error = E2> + 'static,
292 B: Buf + 'static + Send + Sync,
293 C::AcceptError: Send + Sync,
294 C::BidiStream: Send + Sync,
295 C::RecvStream: Send + Sync,
296 C::OpenStreams: Send + Sync,
297 E: Into<h3::Error>,
298 E2: Into<h3::Error>,
299 {
300 let Some(can_upgrade) = request.extensions_mut().remove::<WebTransportCanUpgrade<C, B>>() else {
301 stream
302 .send_response(http::Response::builder().status(StatusCode::BAD_REQUEST).body(()).unwrap())
303 .await?;
304 stream.finish().await?;
305 return Ok(None);
306 };
307
308 WebTransportSession::new(stream, can_upgrade.webtransport_request_tx).await
309 }
310}