scuffle_h3_webtransport/
stream.rs

1use std::task::Poll;
2
3use bytes::{Buf, Bytes};
4use h3::quic;
5use h3::stream::BufRecvStream;
6use pin_project_lite::pin_project;
7use tokio::io::ReadBuf;
8
9pin_project! {
10    /// WebTransport receive stream
11    pub struct RecvStream<S,B> {
12        #[pin]
13        stream: BufRecvStream<S, B>,
14    }
15}
16
17impl<S, B> RecvStream<S, B> {
18    #[allow(missing_docs)]
19    pub fn new(stream: BufRecvStream<S, B>) -> Self {
20        Self { stream }
21    }
22}
23
24impl<S, B> quic::RecvStream for RecvStream<S, B>
25where
26    S: quic::RecvStream,
27    B: Buf,
28{
29    type Buf = Bytes;
30    type Error = S::Error;
31
32    fn poll_data(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
33        self.stream.poll_data(cx)
34    }
35
36    fn stop_sending(&mut self, error_code: u64) {
37        self.stream.stop_sending(error_code)
38    }
39
40    fn recv_id(&self) -> quic::StreamId {
41        self.stream.recv_id()
42    }
43}
44
45impl<S, B> futures_util::io::AsyncRead for RecvStream<S, B>
46where
47    BufRecvStream<S, B>: futures_util::io::AsyncRead,
48{
49    fn poll_read(
50        self: std::pin::Pin<&mut Self>,
51        cx: &mut std::task::Context<'_>,
52        buf: &mut [u8],
53    ) -> Poll<std::io::Result<usize>> {
54        let p = self.project();
55        p.stream.poll_read(cx, buf)
56    }
57}
58
59impl<S, B> tokio::io::AsyncRead for RecvStream<S, B>
60where
61    BufRecvStream<S, B>: tokio::io::AsyncRead,
62{
63    fn poll_read(
64        self: std::pin::Pin<&mut Self>,
65        cx: &mut std::task::Context<'_>,
66        buf: &mut ReadBuf<'_>,
67    ) -> Poll<std::io::Result<()>> {
68        let p = self.project();
69        p.stream.poll_read(cx, buf)
70    }
71}
72
73pin_project! {
74    /// WebTransport send stream
75    pub struct SendStream<S,B> {
76        #[pin]
77        stream: BufRecvStream<S ,B>,
78    }
79}
80
81impl<S, B> std::fmt::Debug for SendStream<S, B> {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("SendStream").field("stream", &self.stream).finish()
84    }
85}
86
87impl<S, B> SendStream<S, B> {
88    #[allow(missing_docs)]
89    pub(crate) fn new(stream: BufRecvStream<S, B>) -> Self {
90        Self { stream }
91    }
92}
93
94impl<S, B> quic::SendStreamUnframed<B> for SendStream<S, B>
95where
96    S: quic::SendStreamUnframed<B>,
97    B: Buf,
98{
99    fn poll_send<D: Buf>(&mut self, cx: &mut std::task::Context<'_>, buf: &mut D) -> Poll<Result<usize, Self::Error>> {
100        self.stream.poll_send(cx, buf)
101    }
102}
103
104impl<S, B> quic::SendStream<B> for SendStream<S, B>
105where
106    S: quic::SendStream<B>,
107    B: Buf,
108{
109    type Error = S::Error;
110
111    fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
112        self.stream.poll_finish(cx)
113    }
114
115    fn reset(&mut self, reset_code: u64) {
116        self.stream.reset(reset_code)
117    }
118
119    fn send_id(&self) -> quic::StreamId {
120        self.stream.send_id()
121    }
122
123    fn send_data<T: Into<h3::stream::WriteBuf<B>>>(&mut self, data: T) -> Result<(), Self::Error> {
124        self.stream.send_data(data)
125    }
126
127    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
128        self.stream.poll_ready(cx)
129    }
130}
131
132impl<S, B> futures_util::io::AsyncWrite for SendStream<S, B>
133where
134    BufRecvStream<S, B>: futures_util::io::AsyncWrite,
135{
136    fn poll_write(
137        self: std::pin::Pin<&mut Self>,
138        cx: &mut std::task::Context<'_>,
139        buf: &[u8],
140    ) -> Poll<std::io::Result<usize>> {
141        let p = self.project();
142        p.stream.poll_write(cx, buf)
143    }
144
145    fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<std::io::Result<()>> {
146        let p = self.project();
147        p.stream.poll_flush(cx)
148    }
149
150    fn poll_close(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<std::io::Result<()>> {
151        let p = self.project();
152        p.stream.poll_close(cx)
153    }
154}
155
156impl<S, B> tokio::io::AsyncWrite for SendStream<S, B>
157where
158    BufRecvStream<S, B>: tokio::io::AsyncWrite,
159{
160    fn poll_write(
161        self: std::pin::Pin<&mut Self>,
162        cx: &mut std::task::Context<'_>,
163        buf: &[u8],
164    ) -> Poll<std::io::Result<usize>> {
165        let p = self.project();
166        p.stream.poll_write(cx, buf)
167    }
168
169    fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<std::io::Result<()>> {
170        let p = self.project();
171        p.stream.poll_flush(cx)
172    }
173
174    fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<std::io::Result<()>> {
175        let p = self.project();
176        p.stream.poll_shutdown(cx)
177    }
178}
179
180pin_project! {
181    /// Combined send and receive stream.
182    ///
183    /// Can be split into a [`RecvStream`] and [`SendStream`] if the underlying QUIC implementation
184    /// supports it.
185    pub struct BidiStream<S, B> {
186        #[pin]
187        stream: BufRecvStream<S, B>,
188    }
189}
190
191impl<S, B> BidiStream<S, B> {
192    pub(crate) fn new(stream: BufRecvStream<S, B>) -> Self {
193        Self { stream }
194    }
195}
196
197impl<S, B> quic::SendStream<B> for BidiStream<S, B>
198where
199    S: quic::SendStream<B>,
200    B: Buf,
201{
202    type Error = S::Error;
203
204    fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
205        self.stream.poll_finish(cx)
206    }
207
208    fn reset(&mut self, reset_code: u64) {
209        self.stream.reset(reset_code)
210    }
211
212    fn send_id(&self) -> quic::StreamId {
213        self.stream.send_id()
214    }
215
216    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
217        self.stream.poll_ready(cx)
218    }
219
220    fn send_data<T: Into<h3::stream::WriteBuf<B>>>(&mut self, data: T) -> Result<(), Self::Error> {
221        self.stream.send_data(data)
222    }
223}
224
225impl<S, B> quic::SendStreamUnframed<B> for BidiStream<S, B>
226where
227    S: quic::SendStreamUnframed<B>,
228    B: Buf,
229{
230    fn poll_send<D: Buf>(&mut self, cx: &mut std::task::Context<'_>, buf: &mut D) -> Poll<Result<usize, Self::Error>> {
231        self.stream.poll_send(cx, buf)
232    }
233}
234
235impl<S: quic::RecvStream, B> quic::RecvStream for BidiStream<S, B> {
236    type Buf = Bytes;
237    type Error = S::Error;
238
239    fn poll_data(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
240        self.stream.poll_data(cx)
241    }
242
243    fn stop_sending(&mut self, error_code: u64) {
244        self.stream.stop_sending(error_code)
245    }
246
247    fn recv_id(&self) -> quic::StreamId {
248        self.stream.recv_id()
249    }
250}
251
252impl<S, B> quic::BidiStream<B> for BidiStream<S, B>
253where
254    S: quic::BidiStream<B>,
255    B: Buf,
256{
257    type RecvStream = RecvStream<S::RecvStream, B>;
258    type SendStream = SendStream<S::SendStream, B>;
259
260    fn split(self) -> (Self::SendStream, Self::RecvStream) {
261        let (send, recv) = self.stream.split();
262        (SendStream::new(send), RecvStream::new(recv))
263    }
264}
265
266impl<S, B> futures_util::io::AsyncRead for BidiStream<S, B>
267where
268    BufRecvStream<S, B>: futures_util::io::AsyncRead,
269{
270    fn poll_read(
271        self: std::pin::Pin<&mut Self>,
272        cx: &mut std::task::Context<'_>,
273        buf: &mut [u8],
274    ) -> Poll<std::io::Result<usize>> {
275        let p = self.project();
276        p.stream.poll_read(cx, buf)
277    }
278}
279
280impl<S, B> futures_util::io::AsyncWrite for BidiStream<S, B>
281where
282    BufRecvStream<S, B>: futures_util::io::AsyncWrite,
283{
284    fn poll_write(
285        self: std::pin::Pin<&mut Self>,
286        cx: &mut std::task::Context<'_>,
287        buf: &[u8],
288    ) -> Poll<std::io::Result<usize>> {
289        let p = self.project();
290        p.stream.poll_write(cx, buf)
291    }
292
293    fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<std::io::Result<()>> {
294        let p = self.project();
295        p.stream.poll_flush(cx)
296    }
297
298    fn poll_close(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<std::io::Result<()>> {
299        let p = self.project();
300        p.stream.poll_close(cx)
301    }
302}
303
304impl<S, B> tokio::io::AsyncRead for BidiStream<S, B>
305where
306    BufRecvStream<S, B>: tokio::io::AsyncRead,
307{
308    fn poll_read(
309        self: std::pin::Pin<&mut Self>,
310        cx: &mut std::task::Context<'_>,
311        buf: &mut ReadBuf<'_>,
312    ) -> Poll<std::io::Result<()>> {
313        let p = self.project();
314        p.stream.poll_read(cx, buf)
315    }
316}
317
318impl<S, B> tokio::io::AsyncWrite for BidiStream<S, B>
319where
320    BufRecvStream<S, B>: tokio::io::AsyncWrite,
321{
322    fn poll_write(
323        self: std::pin::Pin<&mut Self>,
324        cx: &mut std::task::Context<'_>,
325        buf: &[u8],
326    ) -> Poll<std::io::Result<usize>> {
327        let p = self.project();
328        p.stream.poll_write(cx, buf)
329    }
330
331    fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<std::io::Result<()>> {
332        let p = self.project();
333        p.stream.poll_flush(cx)
334    }
335
336    fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<std::io::Result<()>> {
337        let p = self.project();
338        p.stream.poll_shutdown(cx)
339    }
340}