1use std::io;
2
3use scuffle_bytes_util::BitReader;
4use utils::read_leb128;
5
6pub mod seq;
7mod utils;
8
9#[derive(Debug, Clone, PartialEq, Eq, Copy)]
12pub struct ObuHeader {
13 pub obu_type: ObuType,
14 pub size: Option<u64>,
15 pub extension_header: Option<ObuExtensionHeader>,
16}
17
/// Optional second header byte of an OBU, read when the first byte's
/// `obu_extension_flag` bit is set.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ObuExtensionHeader {
    /// The 3-bit `temporal_id` field (0..=7).
    pub temporal_id: u8,
    /// The 2-bit `spatial_id` field (0..=3).
    pub spatial_id: u8,
}
25
26impl ObuHeader {
27 pub fn parse(cursor: &mut impl io::Read) -> io::Result<Self> {
28 let mut bit_reader = BitReader::new(cursor);
29 let forbidden_bit = bit_reader.read_bit()?;
30 if forbidden_bit {
31 return Err(io::Error::new(io::ErrorKind::InvalidData, "obu_forbidden_bit is not 0"));
32 }
33
34 let obu_type = bit_reader.read_bits(4)?;
35 let extension_flag = bit_reader.read_bit()?;
36 let has_size_field = bit_reader.read_bit()?;
37
38 bit_reader.read_bit()?; let extension_header = if extension_flag {
41 let temporal_id = bit_reader.read_bits(3)?;
42 let spatial_id = bit_reader.read_bits(2)?;
43 bit_reader.read_bits(3)?; Some(ObuExtensionHeader {
45 temporal_id: temporal_id as u8,
46 spatial_id: spatial_id as u8,
47 })
48 } else {
49 None
50 };
51
52 let size = if has_size_field {
53 Some(read_leb128(&mut bit_reader)?)
55 } else {
56 None
57 };
58
59 if !bit_reader.is_aligned() {
60 return Err(io::Error::new(io::ErrorKind::InvalidData, "bit reader is not aligned"));
61 }
62
63 Ok(ObuHeader {
64 obu_type: ObuType::from(obu_type as u8),
65 size,
66 extension_header,
67 })
68 }
69}
70
/// The kind of an OBU, decoded from the 4-bit `obu_type` field of its header.
///
/// Codes with a dedicated meaning get a named variant. Any other code is kept verbatim in
/// [`ObuType::Reserved`], so converting a parsed value back to `u8` yields the original code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ObuType {
    /// `obu_type` 1.
    SequenceHeader,
    /// `obu_type` 2.
    TemporalDelimiter,
    /// `obu_type` 3.
    FrameHeader,
    /// `obu_type` 4.
    TileGroup,
    /// `obu_type` 5.
    Metadata,
    /// `obu_type` 6.
    Frame,
    /// `obu_type` 7.
    RedundantFrameHeader,
    /// `obu_type` 8.
    TileList,
    /// `obu_type` 15.
    Padding,
    /// Any code without a named variant (e.g. 0 or 9..=14); the raw value is preserved.
    Reserved(u8),
}

impl From<u8> for ObuType {
    /// Maps a raw `obu_type` code to its variant, falling back to [`ObuType::Reserved`].
    fn from(code: u8) -> Self {
        match code {
            1 => Self::SequenceHeader,
            2 => Self::TemporalDelimiter,
            3 => Self::FrameHeader,
            4 => Self::TileGroup,
            5 => Self::Metadata,
            6 => Self::Frame,
            7 => Self::RedundantFrameHeader,
            8 => Self::TileList,
            15 => Self::Padding,
            other => Self::Reserved(other),
        }
    }
}

impl From<ObuType> for u8 {
    /// Returns the raw `obu_type` code for a variant; `Reserved` yields its stored value.
    fn from(obu_type: ObuType) -> Self {
        match obu_type {
            ObuType::Reserved(code) => code,
            ObuType::SequenceHeader => 1,
            ObuType::TemporalDelimiter => 2,
            ObuType::FrameHeader => 3,
            ObuType::TileGroup => 4,
            ObuType::Metadata => 5,
            ObuType::Frame => 6,
            ObuType::RedundantFrameHeader => 7,
            ObuType::TileList => 8,
            ObuType::Padding => 15,
        }
    }
}
120
#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod tests {
    // `Buf` supplies `remaining()` on `std::io::Cursor`, used below to check how many
    // bytes the parser left unconsumed.
    use bytes::Buf;

    use super::*;

    // Header byte 0x0A = 0b0000_1010: forbidden bit 0, obu_type 1 (SequenceHeader),
    // extension flag 0, size flag 1, reserved bit 0. The next byte 0x0F is a one-byte
    // LEB128 size of 15, matching the 15 payload bytes that follow.
    #[test]
    fn test_obu_header_parse() {
        let mut cursor = std::io::Cursor::new(b"\n\x0f\0\0\0j\xef\xbf\xe1\xbc\x02\x19\x90\x10\x10\x10@");
        let header = ObuHeader::parse(&mut cursor).unwrap();
        insta::assert_debug_snapshot!(header, @r"
        ObuHeader {
            obu_type: SequenceHeader,
            size: Some(
                15,
            ),
            extension_header: None,
        }
        ");

        // Only the header byte and the size byte are consumed; the payload remains.
        assert_eq!(cursor.position(), 2);
        assert_eq!(cursor.remaining(), 15);
    }

    // A single zero byte: obu_type 0 (not a named type, so Reserved(0)), with no
    // extension header and no size field.
    #[test]
    fn test_obu_header_parse_no_size_field() {
        let mut cursor = std::io::Cursor::new(b"\x00");
        let header = ObuHeader::parse(&mut cursor).unwrap();
        insta::assert_debug_snapshot!(header, @r"
        ObuHeader {
            obu_type: Reserved(
                0,
            ),
            size: None,
            extension_header: None,
        }
        ");

        assert_eq!(cursor.position(), 1);
        assert_eq!(cursor.remaining(), 0);
    }

    // First byte sets only obu_extension_flag. The second byte 0b110_10_000 encodes
    // temporal_id = 6, spatial_id = 2 and three zero reserved bits.
    #[test]
    fn test_obu_header_parse_extension_header() {
        let mut cursor = std::io::Cursor::new([0b00000100, 0b11010000]);
        let header = ObuHeader::parse(&mut cursor).unwrap();
        insta::assert_debug_snapshot!(header, @r"
        ObuHeader {
            obu_type: Reserved(
                0,
            ),
            size: None,
            extension_header: Some(
                ObuExtensionHeader {
                    temporal_id: 6,
                    spatial_id: 2,
                },
            ),
        }
        ");

        assert_eq!(cursor.position(), 2);
        assert_eq!(cursor.remaining(), 0);
    }

    // 0xFF has its most significant bit (obu_forbidden_bit) set, so parsing must fail
    // with InvalidData before any other field is interpreted.
    #[test]
    fn test_obu_header_forbidden_bit_set() {
        let err = ObuHeader::parse(&mut std::io::Cursor::new(
            b"\xff\x0f\0\0\0j\xef\xbf\xe1\xbc\x02\x19\x90\x10\x10\x10@",
        ))
        .unwrap_err();
        insta::assert_debug_snapshot!(err, @r#"
        Custom {
            kind: InvalidData,
            error: "obu_forbidden_bit is not 0",
        }
        "#);
    }

    // Every named variant maps to its code and back. Unassigned codes round-trip
    // through Reserved unchanged.
    #[test]
    fn test_obu_to_from_u8() {
        let case = [
            (ObuType::SequenceHeader, 1),
            (ObuType::TemporalDelimiter, 2),
            (ObuType::FrameHeader, 3),
            (ObuType::TileGroup, 4),
            (ObuType::Metadata, 5),
            (ObuType::Frame, 6),
            (ObuType::RedundantFrameHeader, 7),
            (ObuType::TileList, 8),
            (ObuType::Padding, 15),
            (ObuType::Reserved(0), 0),
            (ObuType::Reserved(100), 100),
        ];

        for (obu_type, value) in case {
            assert_eq!(u8::from(obu_type), value);
            assert_eq!(ObuType::from(value), obu_type);
        }
    }
}