scuffle_rtmp/session/
server_session.rs

1use std::borrow::Cow;
2use std::time::Duration;
3
4use bytes::BytesMut;
5use scuffle_amf0::Amf0Value;
6use scuffle_bytes_util::BytesCursorExt;
7use scuffle_future_ext::FutureExt;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::sync::oneshot;
10
11use super::define::RtmpCommand;
12use super::errors::SessionError;
13use crate::channels::{ChannelData, DataProducer, PublishRequest, UniqueID};
14use crate::chunk::{ChunkDecoder, ChunkEncoder, CHUNK_SIZE};
15use crate::handshake::{HandshakeServer, ServerHandshakeState};
16use crate::messages::{MessageParser, RtmpMessageData};
17use crate::netconnection::NetConnection;
18use crate::netstream::NetStreamWriter;
19use crate::protocol_control_messages::ProtocolControlMessagesWriter;
20use crate::user_control_messages::EventMessagesWriter;
21use crate::{handshake, PublishProducer};
22
23pub struct Session<S> {
24    /// When you connect via rtmp, you specify the app name in the url
25    /// For example: rtmp://localhost:1935/live/xyz
26    /// The app name is "live"
27    /// The next part of the url is the stream name (or the stream key) "xyz"
28    /// However the stream key is not required to be the same for each stream
29    /// you publish / play Traditionally we only publish a single stream per
30    /// RTMP connection, However we can publish multiple streams per RTMP
31    /// connection (using different stream keys) and or play multiple streams
32    /// per RTMP connection (using different stream keys) as per the RTMP spec.
33    app_name: Option<String>,
34
35    /// This is a unique id for this session
36    /// This is issued when the client connects to the server
37    uid: Option<UniqueID>,
38
39    /// Used to read and write data
40    io: S,
41
42    /// Buffer to read data into
43    read_buf: BytesMut,
44    /// Buffer to write data to
45    write_buf: Vec<u8>,
46
47    /// Sometimes when doing the handshake we read too much data,
48    /// this flag is used to indicate that we have data ready to parse and we
49    /// should not read more data from the stream
50    skip_read: bool,
51
52    /// This is used to read the data from the stream and convert it into rtmp
53    /// messages
54    chunk_decoder: ChunkDecoder,
55    /// This is used to convert rtmp messages into chunks
56    chunk_encoder: ChunkEncoder,
57
58    /// StreamID
59    stream_id: u32,
60
61    /// Data Producer
62    data_producer: DataProducer,
63
64    /// Is Publishing
65    is_publishing: bool,
66
67    /// when the publisher connects and tries to publish a stream, we need to
68    /// send a publish request to the server
69    publish_request_producer: PublishProducer,
70}
71
72impl<S> Session<S> {
73    pub fn new(io: S, data_producer: DataProducer, publish_request_producer: PublishProducer) -> Self {
74        Self {
75            uid: None,
76            app_name: None,
77            io,
78            skip_read: false,
79            chunk_decoder: ChunkDecoder::default(),
80            chunk_encoder: ChunkEncoder::default(),
81            read_buf: BytesMut::new(),
82            write_buf: Vec::new(),
83            data_producer,
84            stream_id: 0,
85            is_publishing: false,
86            publish_request_producer,
87        }
88    }
89
90    pub fn uid(&self) -> Option<UniqueID> {
91        self.uid
92    }
93}
94
95impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> Session<S> {
96    /// Run the session to completion
97    /// The result of the return value will be true if all publishers have
98    /// disconnected If any publishers are still connected, the result will be
99    /// false This can be used to detect non-graceful disconnects (ie. the
100    /// client crashed)
101    pub async fn run(&mut self) -> Result<bool, SessionError> {
102        let mut handshaker = HandshakeServer::default();
103        // Run the handshake to completion
104        while !self.do_handshake(&mut handshaker).await? {
105            self.flush().await?;
106        }
107
108        // Drop the handshaker, we don't need it anymore
109        // We can get rid of the memory that was allocated for it
110        drop(handshaker);
111
112        tracing::debug!("Handshake complete");
113
114        // Run the session to completion
115        while match self.do_ready().await {
116            Ok(v) => v,
117            Err(err) if err.is_client_closed() => {
118                // The client closed the connection
119                // We are done with the session
120                tracing::debug!("Client closed the connection");
121                false
122            }
123            Err(e) => {
124                return Err(e);
125            }
126        } {
127            self.flush().await?;
128        }
129
130        // We should technically check the stream_map here
131        // However most clients just disconnect without cleanly stopping the subscrition
132        // streams (play streams) So we just check that all publishers have disconnected
133        // cleanly
134        Ok(!self.is_publishing)
135    }
136
137    /// This is the first stage of the session
138    /// It is used to do the handshake with the client
139    /// The handshake is the first thing that happens when you connect to an
140    /// rtmp server
141    async fn do_handshake(&mut self, handshaker: &mut HandshakeServer) -> Result<bool, SessionError> {
142        // Read the handshake data + 1 byte for the version
143        const READ_SIZE: usize = handshake::RTMP_HANDSHAKE_SIZE + 1;
144        self.read_buf.reserve(READ_SIZE);
145
146        let mut bytes_read = 0;
147        while bytes_read < READ_SIZE {
148            let n = self
149                .io
150                .read_buf(&mut self.read_buf)
151                .with_timeout(Duration::from_secs(2))
152                .await??;
153            bytes_read += n;
154        }
155
156        let mut cursor = std::io::Cursor::new(self.read_buf.split().freeze());
157
158        handshaker.handshake(&mut cursor, &mut self.write_buf)?;
159
160        if handshaker.state() == ServerHandshakeState::Finish {
161            let over_read = cursor.extract_remaining();
162
163            if !over_read.is_empty() {
164                self.skip_read = true;
165                self.read_buf.extend_from_slice(&over_read);
166            }
167
168            self.send_set_chunk_size().await?;
169
170            // We are done with the handshake
171            // This causes the loop to exit
172            // And move onto the next stage of the session
173            Ok(true)
174        } else {
175            // We are not done with the handshake yet
176            // We need to read more data from the stream
177            // This causes the loop to continue
178            Ok(false)
179        }
180    }
181
182    /// This is the second stage of the session
183    /// It is used to read data from the stream and parse it into rtmp messages
184    /// We also send data to the client if they are playing a stream
185    async fn do_ready(&mut self) -> Result<bool, SessionError> {
186        // If we have data ready to parse, parse it
187        if self.skip_read {
188            self.skip_read = false;
189        } else {
190            self.read_buf.reserve(CHUNK_SIZE);
191
192            let n = self
193                .io
194                .read_buf(&mut self.read_buf)
195                .with_timeout(Duration::from_millis(2500))
196                .await??;
197
198            if n == 0 {
199                return Ok(false);
200            }
201        }
202
203        self.parse_chunks().await?;
204
205        Ok(true)
206    }
207
208    /// Parse data from the client into rtmp messages and process them
209    async fn parse_chunks(&mut self) -> Result<(), SessionError> {
210        while let Some(chunk) = self.chunk_decoder.read_chunk(&mut self.read_buf)? {
211            let timestamp = chunk.message_header.timestamp;
212            let msg_stream_id = chunk.message_header.msg_stream_id;
213
214            if let Some(msg) = MessageParser::parse(&chunk)? {
215                self.process_messages(msg, msg_stream_id, timestamp).await?;
216            }
217        }
218
219        Ok(())
220    }
221
222    /// Process rtmp messages
223    async fn process_messages(
224        &mut self,
225        rtmp_msg: RtmpMessageData<'_>,
226        stream_id: u32,
227        timestamp: u32,
228    ) -> Result<(), SessionError> {
229        match rtmp_msg {
230            RtmpMessageData::Amf0Command {
231                command_name,
232                transaction_id,
233                command_object,
234                others,
235            } => {
236                self.on_amf0_command_message(stream_id, command_name, transaction_id, command_object, others)
237                    .await?
238            }
239            RtmpMessageData::SetChunkSize { chunk_size } => {
240                self.on_set_chunk_size(chunk_size as usize)?;
241            }
242            RtmpMessageData::AudioData { data } => {
243                self.on_data(stream_id, ChannelData::Audio { timestamp, data }).await?;
244            }
245            RtmpMessageData::VideoData { data } => {
246                self.on_data(stream_id, ChannelData::Video { timestamp, data }).await?;
247            }
248            RtmpMessageData::AmfData { data } => {
249                self.on_data(stream_id, ChannelData::Metadata { timestamp, data }).await?;
250            }
251        }
252
253        Ok(())
254    }
255
256    /// Set the server chunk size to the client
257    async fn send_set_chunk_size(&mut self) -> Result<(), SessionError> {
258        ProtocolControlMessagesWriter::write_set_chunk_size(&self.chunk_encoder, &mut self.write_buf, CHUNK_SIZE as u32)?;
259        self.chunk_encoder.set_chunk_size(CHUNK_SIZE);
260
261        Ok(())
262    }
263
264    /// on_data is called when we receive a data message from the client (a
265    /// published_stream) Such as audio, video, or metadata
266    /// We then forward the data to the specified publisher
267    async fn on_data(&self, stream_id: u32, data: ChannelData) -> Result<(), SessionError> {
268        if stream_id != self.stream_id || !self.is_publishing {
269            return Err(SessionError::UnknownStreamID(stream_id));
270        };
271
272        if matches!(
273            self.data_producer.send(data).with_timeout(Duration::from_secs(2)).await,
274            Err(_) | Ok(Err(_))
275        ) {
276            tracing::debug!("Publisher dropped");
277            return Err(SessionError::PublisherDropped);
278        }
279
280        Ok(())
281    }
282
283    /// on_amf0_command_message is called when we receive an AMF0 command
284    /// message from the client We then handle the command message
285    async fn on_amf0_command_message(
286        &mut self,
287        stream_id: u32,
288        command_name: Amf0Value<'_>,
289        transaction_id: Amf0Value<'_>,
290        command_object: Amf0Value<'_>,
291        others: Vec<Amf0Value<'_>>,
292    ) -> Result<(), SessionError> {
293        let cmd = RtmpCommand::from(match command_name {
294            Amf0Value::String(ref s) => s,
295            _ => "",
296        });
297
298        let transaction_id = match transaction_id {
299            Amf0Value::Number(number) => number,
300            _ => 0.0,
301        };
302
303        let obj = match command_object {
304            Amf0Value::Object(obj) => obj,
305            _ => Cow::Owned(Vec::new()),
306        };
307
308        match cmd {
309            RtmpCommand::Connect => {
310                self.on_command_connect(transaction_id, stream_id, &obj, others).await?;
311            }
312            RtmpCommand::CreateStream => {
313                self.on_command_create_stream(transaction_id, stream_id, &obj, others).await?;
314            }
315            RtmpCommand::DeleteStream => {
316                self.on_command_delete_stream(transaction_id, stream_id, &obj, others).await?;
317            }
318            RtmpCommand::Play => {
319                return Err(SessionError::PlayNotSupported);
320            }
321            RtmpCommand::Publish => {
322                self.on_command_publish(transaction_id, stream_id, &obj, others).await?;
323            }
324            RtmpCommand::CloseStream | RtmpCommand::ReleaseStream => {
325                // Not sure what this is for
326            }
327            RtmpCommand::Unknown(_) => {}
328        }
329
330        Ok(())
331    }
332
333    /// on_set_chunk_size is called when we receive a set chunk size message
334    /// from the client We then update the chunk size of the unpacketizer
335    fn on_set_chunk_size(&mut self, chunk_size: usize) -> Result<(), SessionError> {
336        if self.chunk_decoder.update_max_chunk_size(chunk_size) {
337            Ok(())
338        } else {
339            Err(SessionError::InvalidChunkSize(chunk_size))
340        }
341    }
342
343    /// on_command_connect is called when we receive a amf0 command message with
344    /// the name "connect" We then handle the connect message
345    /// This is called when the client first connects to the server
346    async fn on_command_connect(
347        &mut self,
348        transaction_id: f64,
349        _stream_id: u32,
350        command_obj: &[(Cow<'_, str>, Amf0Value<'_>)],
351        _others: Vec<Amf0Value<'_>>,
352    ) -> Result<(), SessionError> {
353        ProtocolControlMessagesWriter::write_window_acknowledgement_size(
354            &self.chunk_encoder,
355            &mut self.write_buf,
356            CHUNK_SIZE as u32,
357        )?;
358
359        ProtocolControlMessagesWriter::write_set_peer_bandwidth(
360            &self.chunk_encoder,
361            &mut self.write_buf,
362            CHUNK_SIZE as u32,
363            2, // 2 = dynamic
364        )?;
365
366        let app_name = command_obj.iter().find(|(key, _)| key == "app");
367        let app_name = match app_name {
368            Some((_, Amf0Value::String(app))) => app,
369            _ => {
370                return Err(SessionError::NoAppName);
371            }
372        };
373
374        self.app_name = Some(app_name.to_string());
375
376        // The only AMF encoding supported by this server is AMF0
377        // So we ignore the objectEncoding value sent by the client
378        // and always use AMF0
379        // - OBS does not support AMF3 (https://github.com/obsproject/obs-studio/blob/1be1f51635ac85b3ad768a88b3265b192bd0bf18/plugins/obs-outputs/librtmp/rtmp.c#L1737)
380        // - Ffmpeg does not support AMF3 either (https://github.com/FFmpeg/FFmpeg/blob/c125860892e931d9b10f88ace73c91484815c3a8/libavformat/rtmpproto.c#L569)
381        // - NginxRTMP does not support AMF3 (https://github.com/arut/nginx-rtmp-module/issues/313)
382        // - SRS does not support AMF3 (https://github.com/ossrs/srs/blob/dcd02fe69cdbd7f401a7b8d139d95b522deb55b1/trunk/src/protocol/srs_protocol_rtmp_stack.cpp#L599)
383        // However, the new enhanced-rtmp-v1 spec from YouTube does encourage the use of AMF3 over AMF0 (https://github.com/veovera/enhanced-rtmp)
384        // We will eventually support this spec but for now we will stick to AMF0
385        NetConnection::write_connect_response(
386            &self.chunk_encoder,
387            &mut self.write_buf,
388            transaction_id,
389            "FMS/3,0,1,123", // flash version (this value is used by other media servers as well)
390            31.0,            // No idea what this is, but it is used by other media servers as well
391            "NetConnection.Connect.Success",
392            "status", // Again not sure what this is but other media servers use it.
393            "Connection Succeeded.",
394            0.0,
395        )?;
396
397        Ok(())
398    }
399
400    /// on_command_create_stream is called when we receive a amf0 command
401    /// message with the name "createStream" We then handle the createStream
402    /// message This is called when the client wants to create a stream
403    /// A NetStream is used to start publishing or playing a stream
404    async fn on_command_create_stream(
405        &mut self,
406        transaction_id: f64,
407        _stream_id: u32,
408        _command_obj: &[(Cow<'_, str>, Amf0Value<'_>)],
409        _others: Vec<Amf0Value<'_>>,
410    ) -> Result<(), SessionError> {
411        // 1.0 is the Stream ID of the stream we are creating
412        NetConnection::write_create_stream_response(&self.chunk_encoder, &mut self.write_buf, transaction_id, 1.0)?;
413
414        Ok(())
415    }
416
417    /// A delete stream message is unrelated to the NetConnection close method.
418    /// Delete stream is basically a way to tell the server that you are done
419    /// publishing or playing a stream. The server will then remove the stream
420    /// from its list of streams.
421    async fn on_command_delete_stream(
422        &mut self,
423        transaction_id: f64,
424        _stream_id: u32,
425        _command_obj: &[(Cow<'_, str>, Amf0Value<'_>)],
426        others: Vec<Amf0Value<'_>>,
427    ) -> Result<(), SessionError> {
428        let stream_id = match others.first() {
429            Some(Amf0Value::Number(stream_id)) => *stream_id,
430            _ => 0.0,
431        } as u32;
432
433        if self.stream_id == stream_id && self.is_publishing {
434            self.stream_id = 0;
435            self.is_publishing = false;
436        }
437
438        NetStreamWriter::write_on_status(
439            &self.chunk_encoder,
440            &mut self.write_buf,
441            transaction_id,
442            "status",
443            "NetStream.DeleteStream.Suceess",
444            "",
445        )?;
446
447        Ok(())
448    }
449
450    /// on_command_publish is called when we receive a amf0 command message with
451    /// the name "publish" publish commands are used to publish a stream to the
452    /// server ie. the user wants to start streaming to the server
453    async fn on_command_publish(
454        &mut self,
455        transaction_id: f64,
456        stream_id: u32,
457        _command_obj: &[(Cow<'_, str>, Amf0Value<'_>)],
458        others: Vec<Amf0Value<'_>>,
459    ) -> Result<(), SessionError> {
460        let stream_name = match others.first() {
461            Some(Amf0Value::String(val)) => val,
462            _ => {
463                return Err(SessionError::NoStreamName);
464            }
465        };
466
467        let Some(app_name) = &self.app_name else {
468            return Err(SessionError::NoAppName);
469        };
470
471        let (response, waiter) = oneshot::channel();
472
473        if self
474            .publish_request_producer
475            .send(PublishRequest {
476                app_name: app_name.clone(),
477                stream_name: stream_name.to_string(),
478                response,
479            })
480            .await
481            .is_err()
482        {
483            return Err(SessionError::PublishRequestDenied);
484        }
485
486        let Ok(uid) = waiter.await else {
487            return Err(SessionError::PublishRequestDenied);
488        };
489
490        self.uid = Some(uid);
491
492        self.is_publishing = true;
493        self.stream_id = stream_id;
494
495        EventMessagesWriter::write_stream_begin(&self.chunk_encoder, &mut self.write_buf, stream_id)?;
496
497        NetStreamWriter::write_on_status(
498            &self.chunk_encoder,
499            &mut self.write_buf,
500            transaction_id,
501            "status",
502            "NetStream.Publish.Start",
503            "",
504        )?;
505
506        Ok(())
507    }
508
509    async fn flush(&mut self) -> Result<(), SessionError> {
510        if !self.write_buf.is_empty() {
511            self.io
512                .write_all(self.write_buf.as_ref())
513                .with_timeout(Duration::from_secs(2))
514                .await??;
515            self.write_buf.clear();
516        }
517
518        Ok(())
519    }
520}