scuffle_http/
body.rs

1use std::fmt::Debug;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use bytes::{Buf, Bytes};
6use http_body::Frame;
7
8/// An error that can occur when reading the body of an incoming request.
9#[derive(thiserror::Error, Debug)]
10pub enum IncomingBodyError {
11    #[error("hyper error: {0}")]
12    #[cfg(any(feature = "http1", feature = "http2"))]
13    #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
14    Hyper(#[from] hyper::Error),
15    #[error("quic error: {0}")]
16    #[cfg(feature = "http3")]
17    #[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
18    Quic(#[from] h3::Error),
19}
20
/// The body of an incoming request.
///
/// One variant exists per enabled backend, which lets callers handle HTTP/1, HTTP/2 and HTTP/3
/// request bodies through a single type. The enum implements [`http_body::Body`] by delegating to
/// whichever body it wraps.
pub enum IncomingBody {
    /// A body received over HTTP/1 or HTTP/2, backed by hyper's incoming body.
    #[cfg(any(feature = "http1", feature = "http2"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
    Hyper(hyper::body::Incoming),
    /// A body received over HTTP/3, read from a QUIC receive stream.
    #[cfg(feature = "http3")]
    #[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
    Quic(crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>),
}
33
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
impl From<hyper::body::Incoming> for IncomingBody {
    /// Wraps a hyper body in the [`IncomingBody::Hyper`] variant.
    fn from(incoming: hyper::body::Incoming) -> Self {
        Self::Hyper(incoming)
    }
}
41
#[cfg(feature = "http3")]
#[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
impl From<crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>> for IncomingBody {
    /// Wraps an HTTP/3 body in the [`IncomingBody::Quic`] variant.
    fn from(incoming: crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>) -> Self {
        Self::Quic(incoming)
    }
}
49
50impl http_body::Body for IncomingBody {
51    type Data = Bytes;
52    type Error = IncomingBodyError;
53
54    fn is_end_stream(&self) -> bool {
55        match self {
56            #[cfg(any(feature = "http1", feature = "http2"))]
57            IncomingBody::Hyper(body) => body.is_end_stream(),
58            #[cfg(feature = "http3")]
59            IncomingBody::Quic(body) => body.is_end_stream(),
60            #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
61            _ => false,
62        }
63    }
64
65    fn poll_frame(
66        self: std::pin::Pin<&mut Self>,
67        _cx: &mut std::task::Context<'_>,
68    ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
69        match self.get_mut() {
70            #[cfg(any(feature = "http1", feature = "http2"))]
71            IncomingBody::Hyper(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
72            #[cfg(feature = "http3")]
73            IncomingBody::Quic(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
74            #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
75            _ => std::task::Poll::Ready(None),
76        }
77    }
78
79    fn size_hint(&self) -> http_body::SizeHint {
80        match self {
81            #[cfg(any(feature = "http1", feature = "http2"))]
82            IncomingBody::Hyper(body) => body.size_hint(),
83            #[cfg(feature = "http3")]
84            IncomingBody::Quic(body) => body.size_hint(),
85            #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
86            _ => http_body::SizeHint::default(),
87        }
88    }
89}
90
91pin_project_lite::pin_project! {
92    /// A wrapper around an HTTP body that tracks the size of the data that is read from it.
93    pub struct TrackedBody<B, T> {
94        #[pin]
95        body: B,
96        tracker: T,
97    }
98}
99
100impl<B, T> TrackedBody<B, T> {
101    pub fn new(body: B, tracker: T) -> Self {
102        Self { body, tracker }
103    }
104}
105
106/// An error that can occur when tracking the body of an incoming request.
107#[derive(thiserror::Error)]
108pub enum TrackedBodyError<B, T>
109where
110    B: http_body::Body,
111    T: Tracker,
112{
113    #[error("body error: {0}")]
114    Body(B::Error),
115    #[error("tracker error: {0}")]
116    Tracker(T::Error),
117}
118
119impl<B, T> Debug for TrackedBodyError<B, T>
120where
121    B: http_body::Body,
122    B::Error: Debug,
123    T: Tracker,
124    T::Error: Debug,
125{
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        match self {
128            TrackedBodyError::Body(err) => f.debug_tuple("TrackedBodyError::Body").field(err).finish(),
129            TrackedBodyError::Tracker(err) => f.debug_tuple("TrackedBodyError::Tracker").field(err).finish(),
130        }
131    }
132}
133
/// A trait for observing the data that is read from an HTTP body.
///
/// Implementors are notified once per data frame and can abort the read by returning an error.
pub trait Tracker: Send + Sync + 'static {
    /// The error returned when the tracker rejects a data frame.
    type Error;

    /// Called each time a data frame is read from the body.
    ///
    /// `size` is the number of bytes carried by the data frame that was just read. It is not the
    /// amount of data still left in the body.
    ///
    /// # Errors
    ///
    /// Return an error to reject the frame. The default implementation ignores `size` and always
    /// returns `Ok(())`.
    fn on_data(&self, size: usize) -> Result<(), Self::Error> {
        let _ = size;
        Ok(())
    }
}
146
147impl<B, T> http_body::Body for TrackedBody<B, T>
148where
149    B: http_body::Body,
150    T: Tracker,
151{
152    type Data = B::Data;
153    type Error = TrackedBodyError<B, T>;
154
155    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
156        let this = self.project();
157
158        match this.body.poll_frame(cx) {
159            Poll::Pending => Poll::Pending,
160            Poll::Ready(frame) => {
161                if let Some(Ok(frame)) = &frame {
162                    if let Some(data) = frame.data_ref() {
163                        if let Err(err) = this.tracker.on_data(data.remaining()) {
164                            return Poll::Ready(Some(Err(TrackedBodyError::Tracker(err))));
165                        }
166                    }
167                }
168
169                Poll::Ready(frame.transpose().map_err(TrackedBodyError::Body).transpose())
170            }
171        }
172    }
173
174    fn is_end_stream(&self) -> bool {
175        self.body.is_end_stream()
176    }
177
178    fn size_hint(&self) -> http_body::SizeHint {
179        self.body.size_hint()
180    }
181}
182
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use std::convert::Infallible;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use super::{TrackedBodyError, Tracker};

    // The hand-written `Debug` impl must print the full variant path followed by the inner error.
    #[test]
    fn tracked_body_error_debug() {
        // Tracker that relies entirely on the default `on_data`.
        struct NoopTracker;

        impl Tracker for NoopTracker {
            type Error = Infallible;
        }

        // Body that ends immediately; only its associated error type matters here.
        struct EmptyBody;

        impl http_body::Body for EmptyBody {
            type Data = bytes::Bytes;
            type Error = ();

            fn poll_frame(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
                Poll::Ready(None)
            }
        }

        let err: TrackedBodyError<EmptyBody, NoopTracker> = TrackedBodyError::Body(());
        let rendered = format!("{err:?}");
        assert_eq!(rendered, "TrackedBodyError::Body(())");
    }
}