scuffle_av1/obu/
mod.rs

1use std::io;
2
3use scuffle_bytes_util::BitReader;
4use utils::read_leb128;
5
6pub mod seq;
7mod utils;
8
9/// OBU Header
10/// AV1-Spec-2 - 5.3.2
11#[derive(Debug, Clone, PartialEq, Eq, Copy)]
12pub struct ObuHeader {
13    pub obu_type: ObuType,
14    pub size: Option<u64>,
15    pub extension_header: Option<ObuExtensionHeader>,
16}
17
/// Optional extension that follows the first byte of an OBU header.
///
/// AV1-Spec-2 - 5.3.3
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ObuExtensionHeader {
    /// Temporal identifier, read from a 3-bit field.
    pub temporal_id: u8,
    /// Spatial identifier, read from a 2-bit field.
    pub spatial_id: u8,
}
25
26impl ObuHeader {
27    pub fn parse(cursor: &mut impl io::Read) -> io::Result<Self> {
28        let mut bit_reader = BitReader::new(cursor);
29        let forbidden_bit = bit_reader.read_bit()?;
30        if forbidden_bit {
31            return Err(io::Error::new(io::ErrorKind::InvalidData, "obu_forbidden_bit is not 0"));
32        }
33
34        let obu_type = bit_reader.read_bits(4)?;
35        let extension_flag = bit_reader.read_bit()?;
36        let has_size_field = bit_reader.read_bit()?;
37
38        bit_reader.read_bit()?; // reserved_1bit
39
40        let extension_header = if extension_flag {
41            let temporal_id = bit_reader.read_bits(3)?;
42            let spatial_id = bit_reader.read_bits(2)?;
43            bit_reader.read_bits(3)?; // reserved_3bits
44            Some(ObuExtensionHeader {
45                temporal_id: temporal_id as u8,
46                spatial_id: spatial_id as u8,
47            })
48        } else {
49            None
50        };
51
52        let size = if has_size_field {
53            // obu_size
54            Some(read_leb128(&mut bit_reader)?)
55        } else {
56            None
57        };
58
59        if !bit_reader.is_aligned() {
60            return Err(io::Error::new(io::ErrorKind::InvalidData, "bit reader is not aligned"));
61        }
62
63        Ok(ObuHeader {
64            obu_type: ObuType::from(obu_type as u8),
65            size,
66            extension_header,
67        })
68    }
69}
70
/// Kind of an OBU, as carried by the `obu_type` header field.
///
/// The numeric value of each variant is defined by the `From<u8>` and
/// `From<ObuType> for u8` conversions in this module.
///
/// AV1-Spec-2 - 6.2.2
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObuType {
    /// `obu_type` value 1.
    SequenceHeader,
    /// `obu_type` value 2.
    TemporalDelimiter,
    /// `obu_type` value 3.
    FrameHeader,
    /// `obu_type` value 4.
    TileGroup,
    /// `obu_type` value 5.
    Metadata,
    /// `obu_type` value 6.
    Frame,
    /// `obu_type` value 7.
    RedundantFrameHeader,
    /// `obu_type` value 8.
    TileList,
    /// `obu_type` value 15.
    Padding,
    /// Any value without a dedicated variant; carries the raw value.
    Reserved(u8),
}
86
87impl From<u8> for ObuType {
88    fn from(value: u8) -> Self {
89        match value {
90            1 => ObuType::SequenceHeader,
91            2 => ObuType::TemporalDelimiter,
92            3 => ObuType::FrameHeader,
93            4 => ObuType::TileGroup,
94            5 => ObuType::Metadata,
95            6 => ObuType::Frame,
96            7 => ObuType::RedundantFrameHeader,
97            8 => ObuType::TileList,
98            15 => ObuType::Padding,
99            _ => ObuType::Reserved(value),
100        }
101    }
102}
103
104impl From<ObuType> for u8 {
105    fn from(value: ObuType) -> Self {
106        match value {
107            ObuType::SequenceHeader => 1,
108            ObuType::TemporalDelimiter => 2,
109            ObuType::FrameHeader => 3,
110            ObuType::TileGroup => 4,
111            ObuType::Metadata => 5,
112            ObuType::Frame => 6,
113            ObuType::RedundantFrameHeader => 7,
114            ObuType::TileList => 8,
115            ObuType::Padding => 15,
116            ObuType::Reserved(value) => value,
117        }
118    }
119}
120
#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod tests {
    use bytes::Buf;

    use super::*;

    /// A header with a size field: one header byte plus a one-byte size.
    #[test]
    fn test_obu_header_parse() {
        // 0x0a ('\n') = 0b0_0001_0_1_0: type 1, no extension, size field present.
        // 0x0f is the size (15), which is also the number of bytes that follow.
        let data = b"\n\x0f\0\0\0j\xef\xbf\xe1\xbc\x02\x19\x90\x10\x10\x10@";
        let mut cursor = std::io::Cursor::new(data);
        let parsed = ObuHeader::parse(&mut cursor).unwrap();
        insta::assert_debug_snapshot!(parsed, @r"
        ObuHeader {
            obu_type: SequenceHeader,
            size: Some(
                15,
            ),
            extension_header: None,
        }
        ");

        assert_eq!(cursor.position(), 2);
        assert_eq!(cursor.remaining(), 15);
    }

    /// An all-zero header byte: no extension and no size field.
    #[test]
    fn test_obu_header_parse_no_size_field() {
        let data = b"\x00";
        let mut cursor = std::io::Cursor::new(data);
        let parsed = ObuHeader::parse(&mut cursor).unwrap();
        insta::assert_debug_snapshot!(parsed, @r"
        ObuHeader {
            obu_type: Reserved(
                0,
            ),
            size: None,
            extension_header: None,
        }
        ");

        assert_eq!(cursor.position(), 1);
        assert_eq!(cursor.remaining(), 0);
    }

    /// A header with the extension flag set and no size field.
    #[test]
    fn test_obu_header_parse_extension_header() {
        // First byte: only the extension flag is set.
        // Second byte: temporal_id = 0b110, spatial_id = 0b10, reserved = 0b000.
        let data = [0b0000_0100u8, 0b1101_0000u8];
        let mut cursor = std::io::Cursor::new(data);
        let parsed = ObuHeader::parse(&mut cursor).unwrap();
        insta::assert_debug_snapshot!(parsed, @r"
        ObuHeader {
            obu_type: Reserved(
                0,
            ),
            size: None,
            extension_header: Some(
                ObuExtensionHeader {
                    temporal_id: 6,
                    spatial_id: 2,
                },
            ),
        }
        ");

        assert_eq!(cursor.position(), 2);
        assert_eq!(cursor.remaining(), 0);
    }

    /// A set forbidden bit (top bit of the first byte) is rejected.
    #[test]
    fn test_obu_header_forbidden_bit_set() {
        let data = b"\xff\x0f\0\0\0j\xef\xbf\xe1\xbc\x02\x19\x90\x10\x10\x10@";
        let mut cursor = std::io::Cursor::new(data);
        let failure = ObuHeader::parse(&mut cursor).unwrap_err();
        insta::assert_debug_snapshot!(failure, @r#"
        Custom {
            kind: InvalidData,
            error: "obu_forbidden_bit is not 0",
        }
        "#);
    }

    /// Every variant converts to its number and back again.
    #[test]
    fn test_obu_to_from_u8() {
        let cases = [
            (ObuType::SequenceHeader, 1),
            (ObuType::TemporalDelimiter, 2),
            (ObuType::FrameHeader, 3),
            (ObuType::TileGroup, 4),
            (ObuType::Metadata, 5),
            (ObuType::Frame, 6),
            (ObuType::RedundantFrameHeader, 7),
            (ObuType::TileList, 8),
            (ObuType::Padding, 15),
            (ObuType::Reserved(0), 0),
            (ObuType::Reserved(100), 100),
        ];

        for &(expected_type, raw) in &cases {
            assert_eq!(u8::from(expected_type), raw);
            assert_eq!(ObuType::from(raw), expected_type);
        }
    }
}