1use std::fmt::Debug;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use bytes::{Buf, Bytes};
6use http_body::Frame;
7
8#[derive(thiserror::Error, Debug)]
10pub enum IncomingBodyError {
11 #[error("hyper error: {0}")]
12 #[cfg(any(feature = "http1", feature = "http2"))]
13 #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
14 Hyper(#[from] hyper::Error),
15 #[error("quic error: {0}")]
16 #[cfg(feature = "http3")]
17 #[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
18 Quic(#[from] h3::Error),
19}
20
/// An incoming HTTP body, abstracting over the transport backend it came from.
///
/// Which variants exist depends on the enabled cargo features; with no backend
/// feature enabled the enum has no variants at all.
pub enum IncomingBody {
    /// Body delivered by hyper (HTTP/1 or HTTP/2).
    #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
    #[cfg(any(feature = "http1", feature = "http2"))]
    Hyper(hyper::body::Incoming),
    /// Body delivered over HTTP/3, read from an `h3_quinn::RecvStream`.
    #[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
    #[cfg(feature = "http3")]
    Quic(crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>),
}
33
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
impl From<hyper::body::Incoming> for IncomingBody {
    /// Wraps a hyper body in the [`IncomingBody::Hyper`] variant.
    fn from(incoming: hyper::body::Incoming) -> Self {
        Self::Hyper(incoming)
    }
}
41
#[cfg(feature = "http3")]
#[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
impl From<crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>> for IncomingBody {
    /// Wraps an HTTP/3 body in the [`IncomingBody::Quic`] variant.
    fn from(quic_body: crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>) -> Self {
        Self::Quic(quic_body)
    }
}
49
50impl http_body::Body for IncomingBody {
51 type Data = Bytes;
52 type Error = IncomingBodyError;
53
54 fn is_end_stream(&self) -> bool {
55 match self {
56 #[cfg(any(feature = "http1", feature = "http2"))]
57 IncomingBody::Hyper(body) => body.is_end_stream(),
58 #[cfg(feature = "http3")]
59 IncomingBody::Quic(body) => body.is_end_stream(),
60 #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
61 _ => false,
62 }
63 }
64
65 fn poll_frame(
66 self: std::pin::Pin<&mut Self>,
67 _cx: &mut std::task::Context<'_>,
68 ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
69 match self.get_mut() {
70 #[cfg(any(feature = "http1", feature = "http2"))]
71 IncomingBody::Hyper(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
72 #[cfg(feature = "http3")]
73 IncomingBody::Quic(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
74 #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
75 _ => std::task::Poll::Ready(None),
76 }
77 }
78
79 fn size_hint(&self) -> http_body::SizeHint {
80 match self {
81 #[cfg(any(feature = "http1", feature = "http2"))]
82 IncomingBody::Hyper(body) => body.size_hint(),
83 #[cfg(feature = "http3")]
84 IncomingBody::Quic(body) => body.size_hint(),
85 #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
86 _ => http_body::SizeHint::default(),
87 }
88 }
89}
90
pin_project_lite::pin_project! {
    /// Wraps an `http_body::Body` and reports the size of every data frame it
    /// yields to a [`Tracker`].
    pub struct TrackedBody<B, T> {
        // The wrapped body; structurally pinned because it is polled through
        // the pinned projection in `poll_frame`.
        #[pin]
        body: B,
        // Receives `Tracker::on_data` for each data frame the body yields.
        tracker: T,
    }
}
99
100impl<B, T> TrackedBody<B, T> {
101 pub fn new(body: B, tracker: T) -> Self {
102 Self { body, tracker }
103 }
104}
105
106#[derive(thiserror::Error)]
108pub enum TrackedBodyError<B, T>
109where
110 B: http_body::Body,
111 T: Tracker,
112{
113 #[error("body error: {0}")]
114 Body(B::Error),
115 #[error("tracker error: {0}")]
116 Tracker(T::Error),
117}
118
119impl<B, T> Debug for TrackedBodyError<B, T>
120where
121 B: http_body::Body,
122 B::Error: Debug,
123 T: Tracker,
124 T::Error: Debug,
125{
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 match self {
128 TrackedBodyError::Body(err) => f.debug_tuple("TrackedBodyError::Body").field(err).finish(),
129 TrackedBodyError::Tracker(err) => f.debug_tuple("TrackedBodyError::Tracker").field(err).finish(),
130 }
131 }
132}
133
/// Observer that is told about the data flowing through a [`TrackedBody`].
pub trait Tracker: Send + Sync + 'static {
    /// Error a tracker returns to reject a data frame.
    type Error;

    /// Called with the number of bytes in each data frame polled from the
    /// wrapped body.
    ///
    /// Returning `Err` makes the tracked body yield
    /// `TrackedBodyError::Tracker` in place of that frame. The default
    /// implementation ignores the size and always returns `Ok(())`.
    // `size` keeps its real name so rustdoc and implementors see it; the
    // default body has no use for it.
    #[allow(unused_variables)]
    fn on_data(&self, size: usize) -> Result<(), Self::Error> {
        Ok(())
    }
}
146
147impl<B, T> http_body::Body for TrackedBody<B, T>
148where
149 B: http_body::Body,
150 T: Tracker,
151{
152 type Data = B::Data;
153 type Error = TrackedBodyError<B, T>;
154
155 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
156 let this = self.project();
157
158 match this.body.poll_frame(cx) {
159 Poll::Pending => Poll::Pending,
160 Poll::Ready(frame) => {
161 if let Some(Ok(frame)) = &frame {
162 if let Some(data) = frame.data_ref() {
163 if let Err(err) = this.tracker.on_data(data.remaining()) {
164 return Poll::Ready(Some(Err(TrackedBodyError::Tracker(err))));
165 }
166 }
167 }
168
169 Poll::Ready(frame.transpose().map_err(TrackedBodyError::Body).transpose())
170 }
171 }
172 }
173
174 fn is_end_stream(&self) -> bool {
175 self.body.is_end_stream()
176 }
177
178 fn size_hint(&self) -> http_body::SizeHint {
179 self.body.size_hint()
180 }
181}
182
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use std::collections::VecDeque;
    use std::convert::Infallible;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    use bytes::Bytes;
    use http_body::{Body, Frame};

    use super::{TrackedBody, TrackedBodyError, Tracker};

    /// Tracker that accepts every frame; its error type is uninhabited.
    struct TestTracker;

    impl Tracker for TestTracker {
        type Error = Infallible;
    }

    /// Body that ends immediately and uses `()` as its error type.
    struct TestBody;

    impl Body for TestBody {
        type Data = Bytes;
        type Error = ();

        fn poll_frame(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            Poll::Ready(None)
        }
    }

    /// Body that yields each queued chunk as a data frame, then ends.
    struct ChunkBody(VecDeque<Bytes>);

    impl Body for ChunkBody {
        type Data = Bytes;
        type Error = Infallible;

        fn poll_frame(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            Poll::Ready(self.get_mut().0.pop_front().map(|chunk| Ok(Frame::data(chunk))))
        }
    }

    /// Tracker that adds every reported size to a shared counter and fails
    /// once the running total exceeds `limit`.
    struct LimitTracker {
        seen: Arc<AtomicUsize>,
        limit: usize,
    }

    impl Tracker for LimitTracker {
        type Error = &'static str;

        fn on_data(&self, size: usize) -> Result<(), Self::Error> {
            let total = self.seen.fetch_add(size, Ordering::SeqCst) + size;
            if total > self.limit {
                Err("limit exceeded")
            } else {
                Ok(())
            }
        }
    }

    /// Waker that does nothing; the test bodies never return `Pending`.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Polls `body` exactly once with a no-op waker.
    fn poll_once<B: Body + Unpin>(body: &mut B) -> Poll<Option<Result<Frame<B::Data>, B::Error>>> {
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);
        Pin::new(body).poll_frame(&mut cx)
    }

    #[test]
    fn tracked_body_error_debug() {
        let err = TrackedBodyError::<TestBody, TestTracker>::Body(());
        assert_eq!(format!("{:?}", err), "TrackedBodyError::Body(())");

        let err = TrackedBodyError::<TestBody, LimitTracker>::Tracker("limit exceeded");
        assert_eq!(format!("{:?}", err), "TrackedBodyError::Tracker(\"limit exceeded\")");
    }

    #[test]
    fn tracked_body_reports_sizes_and_tracker_errors() {
        let seen = Arc::new(AtomicUsize::new(0));
        let tracker = LimitTracker {
            seen: Arc::clone(&seen),
            limit: 4,
        };
        let chunks = VecDeque::from(vec![Bytes::from_static(b"abc"), Bytes::from_static(b"de")]);
        let mut body = TrackedBody::new(ChunkBody(chunks), tracker);

        // 3 bytes: under the limit, so the frame passes through unchanged.
        match poll_once(&mut body) {
            Poll::Ready(Some(Ok(frame))) => assert_eq!(frame.into_data().ok(), Some(Bytes::from_static(b"abc"))),
            other => panic!("unexpected poll result: {:?}", other),
        }

        // 3 + 2 = 5 bytes: over the limit, so the tracker error replaces the frame.
        match poll_once(&mut body) {
            Poll::Ready(Some(Err(TrackedBodyError::Tracker(err)))) => assert_eq!(err, "limit exceeded"),
            other => panic!("unexpected poll result: {:?}", other),
        }
        assert_eq!(seen.load(Ordering::SeqCst), 5);

        // The wrapped body is now exhausted.
        assert!(matches!(poll_once(&mut body), Poll::Ready(None)));
    }
}
215}